"""Run the trained MNIST CNN on hand-drawn sketches and on PNG images."""
from pathlib import Path
import sys

import torch
import numpy
from PIL import Image, ImageFile
import matplotlib.pyplot as plt

from mnist import CNN

sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
import gpu_utils  # noqa: E402  (requires the sys.path entry added above)

# Side length, in pixels, of the square single-channel input fed to the CNN.
IMAGE_SIZE = 28


class PredictResult:
    """Wraps the CNN output for a single input (batch size 1)."""

    # Raw model output. It is indexed as [0][digit] below, so presumably it has
    # shape (1, 10) -- TODO confirm against CNN.forward.
    possibilities: torch.Tensor

    def __init__(self, possibilities: torch.Tensor):
        self.possibilities = possibilities

    def chosen_number(self) -> int:
        """Return the digit with the highest score (index of the row maximum)."""
        return self.possibilities.argmax(dim=1).item()

    def number_possibilities(self) -> list[float]:
        """Return the score of each digit 0-9 as Python floats.

        NOTE(review): these are whatever CNN.forward returns; if the network
        does not end in a softmax they are logits, not probabilities.
        """
        return self.possibilities[0, :10].tolist()


class Predictor:
    """Loads the saved CNN weights and classifies 28x28 digit images."""

    device: torch.device
    cnn: CNN

    def __init__(self):
        self.device = gpu_utils.get_gpu_device()
        self.cnn = CNN().to(self.device)
        # Load the saved model parameters. map_location lets a checkpoint that
        # was saved on a GPU load on a CPU-only machine (and vice versa);
        # weights_only avoids unpickling arbitrary objects from the file.
        file_path = Path(__file__).resolve().parent.parent / 'models' / 'cnn.pth'
        state_dict = torch.load(file_path, map_location=self.device, weights_only=True)
        self.cnn.load_state_dict(state_dict)
        # Switch to inference mode so layers such as dropout / batch-norm
        # (if CNN has any) behave deterministically during prediction.
        self.cnn.eval()

    def _predict(self, pixels: torch.Tensor) -> PredictResult:
        """Run the CNN on one (IMAGE_SIZE, IMAGE_SIZE) float tensor."""
        # Add two leading dimensions: one for the grayscale channel and one for
        # the batch, i.e. a batch of size 1 with shape (1, 1, H, W).
        batch = pixels.unsqueeze(0).unsqueeze(0)
        with torch.no_grad():
            output = self.cnn(batch)
        return PredictResult(output)

    def predict_sketchpad(self, image: list[list[bool]]) -> PredictResult:
        """Classify a 28x28 grid of booleans.

        Each cell becomes 1.0 (True) or 0.0 (False); presumably True marks a
        drawn pixel, matching the inverted encoding used by predict_image.

        Raises:
            ValueError: if ``image`` is not a rectangular 28x28 grid.
        """
        pixels = torch.tensor(image, dtype=torch.float32, device=self.device)
        if pixels.shape != (IMAGE_SIZE, IMAGE_SIZE):
            raise ValueError(
                f'expected a {IMAGE_SIZE}x{IMAGE_SIZE} grid, '
                f'got shape {tuple(pixels.shape)}'
            )
        return self._predict(pixels)

    def predict_image(self, image: ImageFile.ImageFile) -> PredictResult:
        """Classify a picture of a dark digit on a light background.

        The image is converted to grayscale and, if it is not already
        28x28, resized to 28x28 before being fed to the network.
        """
        grayscale = image.convert('L')
        if grayscale.size != (IMAGE_SIZE, IMAGE_SIZE):
            grayscale = grayscale.resize((IMAGE_SIZE, IMAGE_SIZE))
        # numpy.array copies the pixel buffer, so the result is writable and
        # safe to hand to torch.from_numpy.
        pixels = torch.from_numpy(
            numpy.array(grayscale, dtype=numpy.float32)
        ).to(self.device)
        # Scale to [0, 1] and invert: the input is black ink on white, while
        # the network is fed the digit as high values on a zero background.
        pixels = 1.0 - pixels / 255.0
        return self._predict(pixels)


def main():
    """Predict every PNG in the project's test_images directory and show it."""
    predictor = Predictor()

    test_dir = Path(__file__).resolve().parent.parent / 'test_images'
    # sorted() gives a stable order; raw glob order depends on the filesystem.
    for image_path in sorted(test_dir.glob('*.png')):
        if not image_path.is_file():
            continue
        print(f'Predicting {image_path} ...')
        # The context manager closes the underlying file once we are done.
        with Image.open(image_path) as image:
            digit = predictor.predict_image(image).chosen_number()
            print(f'Predict digit: {digit}')

            plt.figure(f'Image - {image_path}')
            plt.imshow(image)
            plt.axis('on')
            plt.title(f'Predict digit: {digit}')
            plt.show()


if __name__ == "__main__":
    main()