96 lines
3.1 KiB
Python
96 lines
3.1 KiB
Python
from pathlib import Path
|
||
import sys
|
||
import torch
|
||
import numpy
|
||
from PIL import Image, ImageFile
|
||
import matplotlib.pyplot as plt
|
||
from mnist import CNN
|
||
|
||
# Make the directory two levels above this file's directory importable so that the
# `gpu_utils` module imported next can be found (presumably the repository root — verify).
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
|
||
import gpu_utils
|
||
|
||
|
||
class PredictResult:
    """Wrap the raw network output for one image and expose convenience accessors."""

    # Output of ``CNN`` for a batch: one row per image, one column per class.
    # Predictor only ever feeds a batch of one image; presumably the shape is
    # (1, 10) — confirm against mnist.CNN.
    possibilities: torch.Tensor

    def __init__(self, possibilities: torch.Tensor):
        self.possibilities = possibilities

    def chosen_number(self) -> int:
        """Return the finally chosen digit: the index of the highest score.

        Only valid for a batch of one image, because ``.item()`` requires the
        per-row result to contain a single element.
        """
        # Still just "pick the index of the largest value", per row.
        return self.possibilities.argmax(dim=1).item()

    def number_possibilities(self) -> list[float]:
        """Return the score of every class (ten, for digits 0-9) for the first image.

        NOTE(review): these are exactly what ``CNN`` emits; they are only
        probabilities if the network ends with a softmax — confirm in mnist.CNN.

        Returns:
            A list of Python floats, one per output class.
        """
        # A single .tolist() does one device->host transfer instead of one
        # synchronising .item() call per class.
        return self.possibilities[0].tolist()
|
||
|
||
class Predictor:
    """Load the trained MNIST ``CNN`` once and run single-image predictions with it."""

    # Device returned by gpu_utils.get_gpu_device(); the model and inputs live here.
    device: torch.device
    # The network, loaded from ../models/cnn.pth and switched to evaluation mode.
    cnn: CNN

    # Height and width (in pixels) of the images the network is fed.
    _INPUT_SIZE = 28

    def __init__(self):
        self.device = gpu_utils.get_gpu_device()
        self.cnn = CNN().to(self.device)

        # Load the saved model parameters. map_location remaps the stored tensors
        # onto this machine's device, so a checkpoint saved from a GPU still loads
        # on a CPU-only machine.
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints (weights_only=True is available on torch >= 1.13).
        file_path = Path(__file__).resolve().parent.parent / 'models' / 'cnn.pth'
        self.cnn.load_state_dict(torch.load(file_path, map_location=self.device))
        # Inference mode for layers such as dropout / batch-norm (if CNN has any);
        # without this their training-time behaviour leaks into predictions.
        self.cnn.eval()

    def predict_sketchpad(self, image: list[list[bool]]) -> PredictResult:
        """Predict the digit drawn on a 28x28 sketchpad.

        Args:
            image: 28 rows of 28 booleans; True cells become 1.0, False cells 0.0.

        Returns:
            PredictResult wrapping the network output for this single image.

        Raises:
            ValueError: if ``image`` is not a 28x28 grid.
        """
        data = torch.tensor(image, dtype=torch.float32, device=self.device)
        expected = (self._INPUT_SIZE, self._INPUT_SIZE)
        # Explicit check instead of assert, so it still runs under ``python -O``.
        if tuple(data.shape) != expected:
            raise ValueError(
                f'sketchpad image must be {expected[0]}x{expected[1]}, '
                f'got shape {tuple(data.shape)}'
            )
        return self._predict_tensor(data)

    def predict_image(self, image: ImageFile.ImageFile) -> PredictResult:
        """Predict the digit in a 28x28 picture of a dark digit on a light background.

        Args:
            image: A PIL image of any mode; it is converted to 8-bit grayscale.

        Returns:
            PredictResult wrapping the network output for this single image.

        Raises:
            ValueError: if the image is not 28x28 pixels.
        """
        # Make sure the image is single-channel grayscale ('L').
        grayscale_image = image.convert('L')
        expected = (self._INPUT_SIZE, self._INPUT_SIZE)
        # Reject other sizes instead of reshaping: a reshape would silently
        # scramble e.g. a 14x56 image that also has 784 pixels.
        if grayscale_image.size != expected:
            raise ValueError(
                f'image must be {expected[0]}x{expected[1]} pixels, '
                f'got {grayscale_image.size[0]}x{grayscale_image.size[1]}'
            )
        # numpy.array copies the pixels into a fresh, writable (height, width)
        # array. A view of the PIL buffer would be read-only, which torch.from_numpy
        # warns about. This also avoids numpy.reshape(copy=...), which only
        # exists on NumPy >= 2.1.
        pixels = numpy.array(grayscale_image)
        # Convert to a float tensor and move it to the model's device.
        data = torch.from_numpy(pixels).float().to(self.device)
        # Scale to [0, 1] and invert: the input images are black digits on a
        # white background, so they need converting (1 - x / 255).
        data.div_(255.0).sub_(1).mul_(-1)
        return self._predict_tensor(data)

    def _predict_tensor(self, data: torch.Tensor) -> PredictResult:
        """Run the network on one 2-D (height, width) image tensor and wrap the output."""
        # Add two leading dimensions: first the grayscale channel, then the batch.
        # This is a forward pass with batch size 1: (H, W) -> (1, 1, H, W).
        batch = data.unsqueeze(0).unsqueeze(0)
        with torch.no_grad():
            output = self.cnn(batch)
        return PredictResult(output)
|
||
|
||
def main():
    """Predict every ``*.png`` in ../test_images and show each with its predicted digit."""
    predictor = Predictor()

    # Walk all images in the test directory and process each one. sorted() gives
    # a reproducible order; raw glob order depends on the filesystem.
    test_dir = Path(__file__).resolve().parent.parent / 'test_images'
    for image_path in sorted(test_dir.glob('*.png')):
        if not image_path.is_file():
            continue
        print(f'Predicting {image_path} ...')
        # The context manager closes the file handle that Image.open keeps open.
        # All uses of the image (prediction and plotting) stay inside it.
        with Image.open(image_path) as image:
            digit = predictor.predict_image(image).chosen_number()

            print(f'Predict digit: {digit}')
            plt.figure(f'Image - {image_path}')
            plt.imshow(image)
            plt.axis('on')
            plt.title(f'Predict digit: {digit}')
            plt.show()
|
||
|
||
|
||
if __name__ == "__main__":
    # Run the demo only when this file is executed directly, not when it is imported.
    main()
|
||
|