finish exp3 predict code
This commit is contained in:
@@ -6,6 +6,7 @@ import torch.nn.functional as F
|
||||
from PIL import Image, ImageFile
|
||||
import matplotlib.pyplot as plt
|
||||
from model import Cnn
|
||||
import settings
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
|
||||
import gpu_utils
|
||||
@@ -53,8 +54,7 @@ class Predictor:
|
||||
self.model = Cnn().to(self.device)
|
||||
|
||||
# 加载保存好的模型参数
|
||||
file_path = Path(__file__).resolve().parent.parent / 'models' / 'cnn.pth'
|
||||
self.model.load_state_dict(torch.load(file_path))
|
||||
self.model.load_state_dict(torch.load(settings.SAVED_MODEL_PATH))
|
||||
|
||||
def __predict_tensor(self, in_data: torch.Tensor) -> PredictResult:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user