2025-11-24 21:02:44 +08:00
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
import sys
|
|
|
|
|
|
import torch
|
|
|
|
|
|
from train import CNN
|
|
|
|
|
|
|
|
|
|
|
|
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
|
|
|
|
|
|
import gpu_utils
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PredictResult:
    """Wrap the network output produced for a single-image prediction.

    ``possibilities`` is the tensor returned by the model for a batch of one
    image; it is indexed as ``[batch, class]`` by the methods below.
    """

    # Model output, indexed as [batch, class]; only row 0 is read.
    possibilities: torch.Tensor

    def __init__(self, possibilities: torch.Tensor):
        self.possibilities = possibilities

    def chosen_number(self) -> int:
        """Return the finally chosen digit: the index of the highest-scoring class."""
        # Still "find the index of the largest value", taken over the class
        # dimension. .item() requires the batch to contain exactly one image.
        return self.possibilities.argmax(dim=1).item()

    def number_possibilities(self) -> list[float]:
        """Return the score of every class for the image, in class order.

        NOTE(review): these are probabilities only if ``CNN`` ends in a
        softmax; otherwise they are raw scores (logits) — confirm in train.CNN.
        """
        # One .tolist() copies the whole row to Python in a single transfer,
        # instead of one .item() (and one device sync) per class, and does not
        # hard-code the number of classes.
        return self.possibilities[0].tolist()
|
|
|
|
|
|
|
|
|
|
|
|
class Predictor:
    """Load the trained CNN weights from disk and classify single 28x28 images."""

    # Device returned by gpu_utils.get_gpu_device(); model and inputs live here.
    device: torch.device
    # The network, with parameters loaded from <parent.parent>/models/cnn.pth.
    cnn: CNN

    def __init__(self):
        self.device = gpu_utils.get_gpu_device()
        self.cnn = CNN().to(self.device)

        # Load the saved model parameters. map_location remaps tensors that
        # were saved on a different device (e.g. a CUDA checkpoint loaded on a
        # CPU-only machine) onto the device selected above; without it such a
        # checkpoint fails to deserialize.
        file_path = Path(__file__).resolve().parent.parent / 'models' / 'cnn.pth'
        state_dict = torch.load(file_path, map_location=self.device)
        self.cnn.load_state_dict(state_dict)

        # Put the network in inference mode so layers such as dropout or
        # batch-norm (if CNN has any) do not behave as in training.
        self.cnn.eval()

    def predict(self, image: list[list[bool]]) -> PredictResult:
        """Classify one image and return the wrapped network output.

        Args:
            image: 28 rows of 28 pixel values; truthy pixels become 1.0,
                falsy pixels 0.0.

        Returns:
            A PredictResult wrapping the model output for a batch of one.

        Raises:
            ValueError: if ``image`` is not exactly 28x28.
        """
        pixels = torch.tensor(image, dtype=torch.float32, device=self.device)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if pixels.shape != (28, 28):
            raise ValueError(
                f"expected a 28x28 image, got shape {tuple(pixels.shape)}"
            )

        # Add two leading dimensions to match the model's expected input:
        # first the grayscale channel, then the batch (batch size = 1),
        # giving shape (1, 1, 28, 28).
        batch = pixels.unsqueeze(0).unsqueeze(0)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            output = self.cnn(batch)
        return PredictResult(output)
|
|
|
|
|
|
|