2025-11-24 21:02:44 +08:00
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
import sys
|
2025-11-30 16:24:32 +08:00
|
|
|
|
import numpy
|
2025-11-30 22:01:56 +08:00
|
|
|
|
import torch
|
|
|
|
|
|
import torch.nn.functional as F
|
2025-11-30 16:24:32 +08:00
|
|
|
|
from PIL import Image, ImageFile
|
|
|
|
|
|
import matplotlib.pyplot as plt
|
2025-11-30 22:01:56 +08:00
|
|
|
|
from model import Cnn
|
2025-12-06 19:56:55 +08:00
|
|
|
|
import settings
|
2025-11-24 21:02:44 +08:00
|
|
|
|
|
|
|
|
|
|
# Append the directory two levels above this file's directory to the import path.
# NOTE(review): presumably this is where the shared `gpu_utils` module (imported right
# after this line) lives — confirm against the repository layout.
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
|
2025-12-02 23:07:27 +08:00
|
|
|
|
import gpu_utils
|
2025-11-24 21:02:44 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PredictResult:
    """Result of a prediction: the probability of each digit for a single input."""

    # Probability of each digit, already passed through softmax.
    # 2-D: dim 0 is the batch (expected size 1), dim 1 holds one probability per digit.
    possibilities: torch.Tensor

    def __init__(self, possibilities: torch.Tensor):
        """
        Create a prediction result.

        :param possibilities: tensor holding the probability of each digit, already passed
            through softmax. It must be 2-D: dim 0 is the batch and should have size 1;
            dim 1 holds the probability of each digit.
        """
        self.possibilities = possibilities

    def chosen_number(self) -> int:
        """
        Get the finally chosen digit.

        :return: the digit with the highest probability under the current distribution.
        """
        # The model yields one probability per digit, so pick the most likely one:
        # argmax over dim 1 (digits). Dim 0 is the batch of size 1, so .item() gives an int.
        return int(self.possibilities.argmax(dim=1).item())

    def number_possibilities(self) -> list[float]:
        """
        Get the probability of each digit.

        :return: a list with one element per class (10 for the digit model); element i is
            the probability of digit i.
        """
        # One tolist() call converts the whole row at once instead of a separate .item()
        # (and, for a GPU tensor, a separate device sync) per element, and it does not
        # hard-code the number of classes.
        return self.possibilities[0].tolist()
|
|
|
|
|
|
|
|
|
|
|
|
class Predictor:
    """Predicts handwritten digits with the trained Cnn model."""

    # Device the model lives on; every input tensor is moved here before inference.
    device: torch.device
    # The network; its parameters are loaded from settings.SAVED_MODEL_PATH.
    model: Cnn

    # Required height and width of every input image.
    _IMAGE_SIZE = 28

    def __init__(self):
        """Select the device, build the model and load the saved parameters into it."""
        self.device = gpu_utils.get_gpu_device()
        self.model = Cnn().to(self.device)
        # Load the saved parameters. map_location remaps the stored tensors onto the
        # selected device, so a checkpoint written on a GPU also loads on a CPU-only
        # machine. weights_only restricts unpickling to tensors and plain containers,
        # which is all a state_dict needs.
        state_dict = torch.load(
            settings.SAVED_MODEL_PATH,
            map_location=self.device,
            weights_only=True,
        )
        self.model.load_state_dict(state_dict)
        # Switch to evaluation mode so layers such as dropout / batch-norm (if Cnn has any)
        # behave deterministically during prediction.
        self.model.eval()

    def __predict_tensor(self, in_data: torch.Tensor) -> PredictResult:
        """
        Prediction backend shared by the other predict_* methods. They convert their input
        into a tensor and hand it to this method for the actual inference.

        :param in_data: a 28x28 tensor with dtype float32.
        :return: the prediction result.
        """
        # Move to the model's device and add two leading dimensions, the channel
        # (grayscale) and the batch: (28, 28) -> (1, 1, 28, 28), i.e. batch size 1.
        batch = in_data.to(self.device).unsqueeze(0).unsqueeze(0)
        with torch.no_grad():
            logits = self.model(batch)
        # The model outputs values without softmax, so apply it to get probabilities.
        return PredictResult(F.softmax(logits, dim=-1))

    def predict_sketchpad(self, image: list[list[bool]]) -> PredictResult:
        """
        Predict from sketchpad data.

        :param image: a 28x28 nested list; True cells become 1.0, False cells 0.0.
        :return: the prediction result.
        :raises ValueError: if the data is not 28x28.
        """
        pixels = torch.tensor(image, dtype=torch.float32)
        size = self._IMAGE_SIZE
        # Explicit check instead of assert, so validation survives `python -O`.
        if tuple(pixels.shape) != (size, size):
            raise ValueError(
                f'sketchpad data must be {size}x{size}, got shape {tuple(pixels.shape)}'
            )
        return self.__predict_tensor(pixels)

    def predict_image(self, image: Image.Image) -> PredictResult:
        """
        Predict from a Pillow image.

        The image is converted to grayscale and inverted, because the input is expected to
        be black digits on a white background.

        :param image: Pillow image data; it must be 28x28.
        :return: the prediction result.
        :raises ValueError: if the image is not 28x28.
        """
        grayscale_image = image.convert('L')
        size = self._IMAGE_SIZE
        # Pillow's size is (width, height).
        if grayscale_image.size != (size, size):
            raise ValueError(
                f'image must be {size}x{size}, got {grayscale_image.width}x{grayscale_image.height}'
            )
        # numpy.array copies the pixel data into a fresh, writable float32 array that the
        # tensor can safely share.
        pixels = numpy.array(grayscale_image, dtype=numpy.float32)
        data = torch.from_numpy(pixels)
        # Scale [0, 255] down to [0, 1] and invert: white background -> 0, black strokes -> 1.
        data = 1.0 - data / 255.0
        return self.__predict_tensor(data)
|
2025-11-30 16:24:32 +08:00
|
|
|
|
|
|
|
|
|
|
def main():
    """
    Demo: predict every PNG in the ``test_images`` directory next to this file's parent
    directory, print the predicted digit and display each image with it as the title.
    """
    predictor = Predictor()

    test_dir = Path(__file__).resolve().parent.parent / 'test_images'
    # sorted() gives a reproducible processing order; raw glob order depends on the filesystem.
    for image_path in sorted(test_dir.glob('*.png')):
        if not image_path.is_file():
            continue
        print(f'Predicting {image_path} ...')
        # The with-block closes the underlying file handle once this image is done.
        with Image.open(image_path) as image:
            digit = predictor.predict_image(image).chosen_number()
            print(f'Predict digit: {digit}')

            plt.figure(f'Image - {image_path}')
            # cmap only affects single-channel images (drawn in gray instead of the
            # default colormap); RGB(A) images are drawn with their own colors.
            plt.imshow(image, cmap='gray')
            plt.axis('on')
            plt.title(f'Predict digit: {digit}')
            plt.show()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: report GPU availability via the project's gpu_utils helper,
    # then run the image-prediction demo.
    gpu_utils.print_gpu_availability()
    main()
|
|
|
|
|
|
|