1
0

update for change of exp2 and add exp3

This commit is contained in:
2025-11-30 16:24:32 +08:00
parent af890d899e
commit 48fcdfcc80
17 changed files with 859 additions and 124 deletions

View File

@@ -1,7 +1,10 @@
from pathlib import Path
import sys
import torch
from train import CNN
import numpy
from PIL import Image, ImageFile
import matplotlib.pyplot as plt
from mnist import CNN
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
import gpu_utils
@@ -36,7 +39,7 @@ class Predictor:
file_path = Path(__file__).resolve().parent.parent / 'models' / 'cnn.pth'
self.cnn.load_state_dict(torch.load(file_path))
def predict(self, image: list[list[bool]]) -> PredictResult:
def predict_sketchpad(self, image: list[list[bool]]) -> PredictResult:
input = torch.Tensor(image).float().to(self.device)
assert(input.dim() == 2)
assert(input.size(0) == 28)
@@ -51,4 +54,42 @@ class Predictor:
with torch.no_grad():
output = self.cnn(input)
return PredictResult(output)
def predict_image(self, image: ImageFile.ImageFile) -> PredictResult:
# 确保图像为灰度图像然后转换为numpy数组。
# 注意这里的numpy数组是只读的所以要先拷贝一份
grayscale_image = image.convert('L')
numpy_data = numpy.reshape(grayscale_image, (28, 28), copy=True)
# 转换到Tensor设置dtype并传到GPU上
data = torch.from_numpy(numpy_data).float().to(self.device)
# 归一化到255又因为图像输入是白底黑字需要做转换。
data.div_(255.0).sub_(1).mul_(-1)
# 同理,挤出维度并预测
input = data.unsqueeze(0).unsqueeze(0)
with torch.no_grad():
output = self.cnn(input)
return PredictResult(output)
def main():
predictor = Predictor()
# 遍历测试目录中的所有图片,并处理。
test_dir = Path(__file__).resolve().parent.parent / 'test_images'
for image_path in test_dir.glob('*.png'):
if image_path.is_file():
print(f'Predicting {image_path} ...')
image = Image.open(image_path)
rv = predictor.predict_image(image)
print(f'Predict digit: {rv.chosen_number()}')
plt.figure(f'Image - {image_path}')
plt.imshow(image)
plt.axis('on')
plt.title(f'Predict digit: {rv.chosen_number()}')
plt.show()
if __name__ == "__main__":
main()