1
0
Files
ai-school/exp2/modified/predict.py

96 lines
3.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from pathlib import Path
import sys
import torch
import numpy
from PIL import Image, ImageFile
import matplotlib.pyplot as plt
from mnist import CNN
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
import gpu_utils
class PredictResult:
    """Wrap the raw CNN output produced for a single prediction.

    ``possibilities`` is the tensor returned by the model's forward pass.
    It is indexed below as ``[0][digit]`` and reduced over dimension 1, so it
    is presumably shaped (1, 10): one batch row, one score per digit
    (TODO confirm against CNN's final layer).
    """

    possibilities: torch.Tensor

    def __init__(self, possibilities: torch.Tensor):
        self.possibilities = possibilities

    def chosen_number(self) -> int:
        """Return the digit whose score is highest.

        The scores are reduced over dimension 1 and the index of the maximum
        is returned. ``.item()`` requires the reduced result to contain
        exactly one element, i.e. a batch of one.
        """
        best = self.possibilities.max(dim=1)
        return best.indices.item()

    def number_possibilities(self) -> list[float]:
        """Return the scores for digits 0-9 of the first batch row as floats.

        NOTE(review): these are the CNN's raw outputs. Whether they are
        normalized probabilities or unnormalized logits depends on CNN's last
        layer, which is not visible here.
        """
        first_row = self.possibilities[0]
        return [first_row[digit].item() for digit in range(10)]
class Predictor:
    """Load the trained CNN once and run single-image digit predictions.

    The weights are read from ``models/cnn.pth`` under the parent of this
    file's directory, onto the device chosen by ``gpu_utils``.
    """

    device: torch.device
    cnn: CNN

    def __init__(self):
        self.device = gpu_utils.get_gpu_device()
        self.cnn = CNN().to(self.device)
        # Load the saved model parameters. map_location remaps tensors that
        # were saved on another device (e.g. a CUDA checkpoint opened on a
        # CPU-only machine) onto the device selected above.
        file_path = Path(__file__).resolve().parent.parent / 'models' / 'cnn.pth'
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints (consider weights_only=True on torch>=1.13).
        state_dict = torch.load(file_path, map_location=self.device)
        self.cnn.load_state_dict(state_dict)
        # Put the model in inference mode so layers such as dropout or
        # batch-norm (if CNN has any) behave deterministically;
        # torch.no_grad() alone does not change their behaviour.
        self.cnn.eval()

    def _infer(self, pixels: torch.Tensor) -> PredictResult:
        """Run the CNN on one 28x28 float tensor already on ``self.device``."""
        # Add the grayscale-channel and batch dimensions:
        # (28, 28) -> (1, 1, 28, 28), i.e. a batch of one single-channel image.
        batch = pixels.unsqueeze(0).unsqueeze(0)
        with torch.no_grad():
            output = self.cnn(batch)
        return PredictResult(output)

    def predict_sketchpad(self, image: list[list[bool]]) -> PredictResult:
        """Predict the digit drawn on a 28x28 boolean sketchpad.

        Args:
            image: 28 rows of 28 booleans; True cells become 1.0, False 0.0.

        Raises:
            AssertionError: if ``image`` is not 2-D with shape 28x28.
        """
        pixels = torch.Tensor(image).float().to(self.device)
        assert pixels.dim() == 2, f'expected a 2-D sketchpad, got {pixels.dim()}-D'
        assert pixels.size(0) == 28
        assert pixels.size(1) == 28
        return self._infer(pixels)

    def predict_image(self, image: ImageFile.ImageFile) -> PredictResult:
        """Predict the digit in a PIL image of dark ink on a light background.

        The image is converted to 8-bit grayscale and resized to 28x28 if it
        has a different size. Pixel values are then scaled to [0, 1] and
        inverted, so ink is near 1.0 and the background near 0.0.
        """
        grayscale_image = image.convert('L')
        if grayscale_image.size != (28, 28):
            grayscale_image = grayscale_image.resize((28, 28))
        # numpy.array copies by default, so the array is writable (the view
        # PIL exposes may be read-only) and this works on every numpy
        # version, unlike numpy.reshape(..., copy=True), which needs 2.1+.
        numpy_data = numpy.array(grayscale_image, dtype=numpy.float32)
        data = torch.from_numpy(numpy_data).to(self.device)
        # Scale 0..255 to 0..1 and invert (white background -> 0, ink -> 1).
        data = 1.0 - data / 255.0
        return self._infer(data)
def main():
    """Predict every PNG in ``test_images`` and plot each with its label.

    ``test_images`` is looked up under the parent of this file's directory.
    Images are processed in sorted path order, and ``plt.show()`` is called
    once per image (with most backends this blocks until the window closes).
    """
    predictor = Predictor()
    test_dir = Path(__file__).resolve().parent.parent / 'test_images'
    # glob() order depends on the filesystem; sort for a reproducible run.
    for image_path in sorted(test_dir.glob('*.png')):
        if not image_path.is_file():
            continue
        print(f'Predicting {image_path} ...')
        # The context manager closes the file handle that Image.open keeps
        # open for lazy loading.
        with Image.open(image_path) as image:
            digit = predictor.predict_image(image).chosen_number()
            print(f'Predict digit: {digit}')
            plt.figure(f'Image - {image_path}')
            # cmap only affects single-channel images: grayscale digits are
            # drawn in gray instead of matplotlib's default colormap.
            plt.imshow(image, cmap='gray')
            plt.axis('on')
            plt.title(f'Predict digit: {digit}')
            plt.show()


if __name__ == "__main__":
    main()