1
0

finish exp3 predict code

This commit is contained in:
2025-12-06 19:56:55 +08:00
parent 45b60b269f
commit ee18246d51
4 changed files with 138 additions and 2 deletions

View File

@@ -6,6 +6,7 @@ import torch.nn.functional as F
from PIL import Image, ImageFile
import matplotlib.pyplot as plt
from model import Cnn
import settings
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
import gpu_utils
@@ -53,8 +54,7 @@ class Predictor:
self.model = Cnn().to(self.device)
# 加载保存好的模型参数
file_path = Path(__file__).resolve().parent.parent / 'models' / 'cnn.pth'
self.model.load_state_dict(torch.load(file_path))
self.model.load_state_dict(torch.load(settings.SAVED_MODEL_PATH))
def __predict_tensor(self, in_data: torch.Tensor) -> PredictResult:
"""