1
0

use ignite for exp2

This commit is contained in:
2025-12-02 23:07:27 +08:00
parent 43b807679f
commit 65c56e938c
15 changed files with 246 additions and 794 deletions

View File

@@ -8,26 +8,36 @@ import matplotlib.pyplot as plt
from model import Cnn
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
import pytorch_gpu_utils
import gpu_utils
class PredictResult:
"""预测的结果"""
possibilities: torch.Tensor
"""预测结果,是每个数字不同的概率是经过softmax后的数值"""
"""每个数字不同的概率"""
def __init__(self, possibilities: torch.Tensor):
"""
创建预测结果。
:param possibilities: 传入的tensor表示每个数字不同的概率是经过softmax后的数值。
其shape为二维。dim 0为batch应当只有一维dim 1为每个数字对应的概率。
"""
self.possibilities = possibilities
def chosen_number(self) -> int:
"""获取最终选定的数字"""
# 依然是找最大的那个index
_, prediction = self.possibilities.max(1)
return prediction.item()
# 输出出来是10个数字各自的可能性所以要选取最高可能性的那个对比
# 在dim=1上找最大的那个就选那个。dim=0是批次所以不管他。
return self.possibilities.argmax(1).item()
def number_possibilities(self) -> list[float]:
"""获取每个数字出现的概率"""
"""
获取每个数字出现的概率
:return: 返回一个具有10个元素的列表列表的每一项表示当前index所代表数字的概率。
"""
return list(self.possibilities[0][i].item() for i in range(10))
class Predictor:
@@ -35,14 +45,14 @@ class Predictor:
model: Cnn
def __init__(self):
self.device = pytorch_gpu_utils.get_gpu_device()
self.device = gpu_utils.get_gpu_device()
self.model = Cnn().to(self.device)
# 加载保存好的模型参数
file_path = Path(__file__).resolve().parent.parent / 'models' / 'cnn.pth'
self.model.load_state_dict(torch.load(file_path))
def generic_predict(self, in_data: torch.Tensor) -> PredictResult:
def __predict(self, in_data: torch.Tensor) -> PredictResult:
"""
其它预测函数都要使用的预测后端。其它预测函数将数据处理成Tensor然后传递给此函数进行实际预测。
@@ -62,24 +72,36 @@ class Predictor:
def predict_sketchpad(self, image: list[list[bool]]) -> PredictResult:
"""
以sketchpad的数据进行预测。
:param image: 该列表的shape必须为28x28
"""
input = torch.Tensor(image).float()
assert(input.dim() == 2)
assert(input.size(0) == 28)
assert(input.size(1) == 28)
return self.generic_predict(input)
return self.__predict(input)
def predict_image(self, image: ImageFile.ImageFile) -> PredictResult:
# 确保图像为灰度图像然后转换为numpy数组。
# 注意这里的numpy数组是只读的所以要先拷贝一份
"""
以Pillow图像的数据进行预测。
:param image: Pillow图像数据。该图像必须为28x28大小。
"""
# 确保图像为灰度图像,以及宽高合适
grayscale_image = image.convert('L')
assert(grayscale_image.width == 28)
assert(grayscale_image.height == 28)
# 转换为numpy数组。注意这里的numpy数组是只读的所以要先拷贝一份
numpy_data = numpy.reshape(grayscale_image, (28, 28), copy=True)
# 转换到Tensor设置dtype
data = torch.from_numpy(numpy_data).float()
# 归一化到255又因为图像输入是白底黑字需要做转换。
data.div_(255.0).sub_(1).mul_(-1)
return self.generic_predict(data)
return self.__predict(data)
def main():
predictor = Predictor()
@@ -101,6 +123,6 @@ def main():
if __name__ == "__main__":
pytorch_gpu_utils.print_gpu_availability()
gpu_utils.print_gpu_availability()
main()