1
0

finish exp2

This commit is contained in:
2025-12-03 09:39:41 +08:00
parent 2b6d6293e2
commit 1061780ea5
2 changed files with 23 additions and 14 deletions

View File

@@ -27,7 +27,11 @@ class PredictResult:
self.possibilities = possibilities self.possibilities = possibilities
def chosen_number(self) -> int: def chosen_number(self) -> int:
"""获取最终选定的数字""" """
获取最终选定的数字
:return: 以当前概率分布,推测的最终数字。
"""
# 输出出来是10个数字各自的可能性所以要选取最高可能性的那个对比 # 输出出来是10个数字各自的可能性所以要选取最高可能性的那个对比
# 在dim=1上找最大的那个就选那个。dim=0是批次所以不管他。 # 在dim=1上找最大的那个就选那个。dim=0是批次所以不管他。
return self.possibilities.argmax(1).item() return self.possibilities.argmax(1).item()
@@ -52,11 +56,12 @@ class Predictor:
file_path = Path(__file__).resolve().parent.parent / 'models' / 'cnn.pth' file_path = Path(__file__).resolve().parent.parent / 'models' / 'cnn.pth'
self.model.load_state_dict(torch.load(file_path)) self.model.load_state_dict(torch.load(file_path))
def __predict(self, in_data: torch.Tensor) -> PredictResult: def __predict_tensor(self, in_data: torch.Tensor) -> PredictResult:
""" """
其它预测函数都要使用的预测后端。其它预测函数将数据处理成Tensor然后传递给此函数进行实际预测。 其它预测函数都要使用的预测后端。其它预测函数将数据处理成Tensor然后传递给此函数进行实际预测。
:param in_data: 传入的tensor该tensor的shape必须是28x28dtype为float32。 :param in_data: 传入的tensor该tensor的shape必须是28x28dtype为float32。
:return: 预测结果。
""" """
# 上传tensor到GPU # 上传tensor到GPU
in_data = in_data.to(self.device) in_data = in_data.to(self.device)
@@ -75,20 +80,22 @@ class Predictor:
""" """
以sketchpad的数据进行预测。 以sketchpad的数据进行预测。
:param image: 该列表的shape必须为28x28 :param image: 该列表的shape必须为28x28
:return: 预测结果。
""" """
input = torch.Tensor(image).float() input = torch.Tensor(image).float()
assert(input.dim() == 2) assert(input.dim() == 2)
assert(input.size(0) == 28) assert(input.size(0) == 28)
assert(input.size(1) == 28) assert(input.size(1) == 28)
return self.__predict(input) return self.__predict_tensor(input)
def predict_image(self, image: ImageFile.ImageFile) -> PredictResult: def predict_image(self, image: ImageFile.ImageFile) -> PredictResult:
""" """
以Pillow图像的数据进行预测。 以Pillow图像的数据进行预测。
:param image: Pillow图像数据。该图像必须为28x28大小。 :param image: Pillow图像数据。该图像必须为28x28大小。
:return: 预测结果。
""" """
# 确保图像为灰度图像,以及宽高合适 # 确保图像为灰度图像,以及宽高合适
grayscale_image = image.convert('L') grayscale_image = image.convert('L')
@@ -101,7 +108,7 @@ class Predictor:
# 归一化到255又因为图像输入是白底黑字需要做转换。 # 归一化到255又因为图像输入是白底黑字需要做转换。
data.div_(255.0).sub_(1).mul_(-1) data.div_(255.0).sub_(1).mul_(-1)
return self.__predict(data) return self.__predict_tensor(data)
def main(): def main():
predictor = Predictor() predictor = Predictor()

View File

@@ -1,12 +1,14 @@
from pathlib import Path from pathlib import Path
BATCH_SIZE: int = 16 DIRTY_DATASET_PATH: Path = Path(__file__).resolve().parent.parent / 'datasets' / 'poetry.txt'
"""训练的batch size""" """脏的(未清洗的)古诗数据的路径"""
CLEAN_DATASET_PATH: Path = Path(__file__).resolve().parent.parent / 'datasets' / 'poetry.pickle'
"""干净的(已经清洗过的)古诗数据的路径"""
def get_saved_model_path() -> Path: SAVED_MODULE_PATH: Path = Path(__file__).resolve().parent.parent / 'models' / 'rnn.pth'
""" """训练完毕的模型进行保存的路径"""
获取训练完毕的模型进行保存的路径。
N_EPOCH: int = 10
:return: 模型参数保存的路径。 """训练时的epoch"""
""" N_BATCH_SIZE: int = 16
return Path(__file__).resolve().parent.parent / 'models' / 'rnn.pth' """训练时的batch size"""