1
0

finish exp2

This commit is contained in:
2025-12-03 09:39:41 +08:00
parent 2b6d6293e2
commit 1061780ea5
2 changed files with 23 additions and 14 deletions

View File

@@ -27,7 +27,11 @@ class PredictResult:
self.possibilities = possibilities
def chosen_number(self) -> int:
"""获取最终选定的数字"""
"""
获取最终选定的数字
:return: 以当前概率分布,推测的最终数字。
"""
# 输出出来是10个数字各自的可能性所以要选取最高可能性的那个对比
# 在dim=1上找最大的那个就选那个。dim=0是批次所以不管他。
return self.possibilities.argmax(1).item()
@@ -52,11 +56,12 @@ class Predictor:
file_path = Path(__file__).resolve().parent.parent / 'models' / 'cnn.pth'
self.model.load_state_dict(torch.load(file_path))
def __predict(self, in_data: torch.Tensor) -> PredictResult:
def __predict_tensor(self, in_data: torch.Tensor) -> PredictResult:
"""
其它预测函数都要使用的预测后端。其它预测函数将数据处理成Tensor然后传递给此函数进行实际预测。
:param in_data: 传入的tensor该tensor的shape必须是28x28dtype为float32。
:return: 预测结果。
"""
# 上传tensor到GPU
in_data = in_data.to(self.device)
@@ -75,20 +80,22 @@ class Predictor:
"""
以sketchpad的数据进行预测。
:param image: 该列表的shape必须为28x28
:param image: 该列表的shape必须为28x28
:return: 预测结果。
"""
input = torch.Tensor(image).float()
assert(input.dim() == 2)
assert(input.size(0) == 28)
assert(input.size(1) == 28)
return self.__predict(input)
return self.__predict_tensor(input)
def predict_image(self, image: ImageFile.ImageFile) -> PredictResult:
"""
以Pillow图像的数据进行预测。
:param image: Pillow图像数据。该图像必须为28x28大小。
:return: 预测结果。
"""
# 确保图像为灰度图像,以及宽高合适
grayscale_image = image.convert('L')
@@ -101,7 +108,7 @@ class Predictor:
# 归一化到255又因为图像输入是白底黑字需要做转换。
data.div_(255.0).sub_(1).mul_(-1)
return self.__predict(data)
return self.__predict_tensor(data)
def main():
predictor = Predictor()

View File

@@ -1,12 +1,14 @@
from pathlib import Path
BATCH_SIZE: int = 16
"""训练的batch size"""
DIRTY_DATASET_PATH: Path = Path(__file__).resolve().parent.parent / 'datasets' / 'poetry.txt'
"""脏的(未清洗的)古诗数据的路径"""
CLEAN_DATASET_PATH: Path = Path(__file__).resolve().parent.parent / 'datasets' / 'poetry.pickle'
"""干净的(已经清洗过的)古诗数据的路径"""
def get_saved_model_path() -> Path:
"""
获取训练完毕的模型进行保存的路径。
:return: 模型参数保存的路径。
"""
return Path(__file__).resolve().parent.parent / 'models' / 'rnn.pth'
SAVED_MODULE_PATH: Path = Path(__file__).resolve().parent.parent / 'models' / 'rnn.pth'
"""训练完毕的模型进行保存的路径"""
N_EPOCH: int = 10
"""训练时的epoch"""
N_BATCH_SIZE: int = 16
"""训练时的batch size"""