1
0

Fix fatal training issue in the exp2 PyTorch rewrite

This commit is contained in:
2025-11-30 22:01:56 +08:00
parent 48fcdfcc80
commit 43b807679f
13 changed files with 738 additions and 112 deletions

View File

@@ -1,18 +1,21 @@
from pathlib import Path
import sys
import torch
import numpy
import torch
import torch.nn.functional as F
from PIL import Image, ImageFile
import matplotlib.pyplot as plt
from mnist import CNN
from model import Cnn
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
import gpu_utils
import pytorch_gpu_utils
class PredictResult:
"""预测的结果"""
possibilities: torch.Tensor
"""预测结果是每个数字不同的概率是经过softmax后的数值"""
def __init__(self, possibilities: torch.Tensor):
self.possibilities = possibilities
@@ -29,47 +32,54 @@ class PredictResult:
class Predictor:
    """Load the saved CNN weights once and classify 28x28 digit images.

    The ``predict_*`` methods only convert their input into a 28x28 float
    tensor; the actual forward pass is shared in :meth:`generic_predict`.
    """

    device: torch.device
    model: Cnn

    def __init__(self):
        """Build the model on the selected device and load its weights.

        The weights are read from ``models/cnn.pth`` located one directory
        above the directory containing this file.
        """
        self.device = pytorch_gpu_utils.get_gpu_device()
        self.model = Cnn().to(self.device)

        # Load the saved model parameters.
        file_path = Path(__file__).resolve().parent.parent / 'models' / 'cnn.pth'
        # map_location remaps the stored tensors onto the device chosen above,
        # so a checkpoint written on a GPU host still loads on a CPU-only one.
        state_dict = torch.load(file_path, map_location=self.device)
        self.model.load_state_dict(state_dict)

        # This class only runs inference, so leave training mode.
        # NOTE(review): Cnn's layers are not visible here; eval() matters if
        # it contains dropout or batch-norm layers and is harmless otherwise.
        self.model.eval()

    def generic_predict(self, in_data: torch.Tensor) -> PredictResult:
        """Run the model on one image and return the softmaxed result.

        This is the shared backend for the other ``predict_*`` methods, which
        prepare a tensor and hand it to this function.

        :param in_data: the input tensor; per the original author's contract
            its shape must be 28x28 and its dtype float32.
        :return: a ``PredictResult`` wrapping the softmax of the model output.
        """
        # Upload the tensor to the model's device.
        in_data = in_data.to(self.device)
        # Insert two leading dimensions: one for the batch and one for the
        # grayscale channel, i.e. a forward pass with batch size 1.
        in_data = in_data.unsqueeze(0).unsqueeze(0)
        # The model output is not softmaxed yet (per the original author's
        # note), so apply softmax over the last dimension here.
        with torch.no_grad():
            out_data = self.model(in_data)
            out_data = F.softmax(out_data, dim=-1)
        return PredictResult(out_data)

    def predict_sketchpad(self, image: list[list[bool]]) -> PredictResult:
        """Classify a sketchpad drawing given as a 28x28 grid of booleans.

        :param image: 28 rows of 28 booleans each.
        :return: the prediction produced by :meth:`generic_predict`.
        :raises AssertionError: if the grid is not exactly 28x28.
        """
        pixels = torch.tensor(image, dtype=torch.float32)
        # Kept as asserts so callers see the same exception type as before.
        assert pixels.dim() == 2
        assert pixels.size(0) == 28
        assert pixels.size(1) == 28
        return self.generic_predict(pixels)

    def predict_image(self, image: ImageFile.ImageFile) -> PredictResult:
        """Classify a digit stored in an image file.

        :param image: the loaded image; it must contain exactly 28x28 pixels,
            otherwise the reshape below raises ``ValueError``.
        :return: the prediction produced by :meth:`generic_predict`.
        """
        # Make sure the image is grayscale, then copy it into a writable
        # numpy array (the buffer exposed by the image itself is read-only).
        # numpy.array() copies by default, which avoids the ``copy=`` keyword
        # of numpy.reshape that only exists in newer NumPy releases.
        grayscale_image = image.convert('L')
        numpy_data = numpy.array(grayscale_image).reshape(28, 28)
        # Convert to a float tensor.
        data = torch.from_numpy(numpy_data).float()
        # Scale from [0, 255] to [0, 1] and invert, i.e. 1 - x / 255. The
        # original author notes the input is black-on-white, so inverting
        # makes the strokes the high values.
        data.div_(255.0).sub_(1).mul_(-1)
        return self.generic_predict(data)
def main():
predictor = Predictor()
@@ -91,5 +101,6 @@ def main():
if __name__ == "__main__":
    # Script entry point: call the GPU-availability helper first, then main().
    pytorch_gpu_utils.print_gpu_availability()
    main()