diff --git a/exp2/modified/predict.py b/exp2/modified/predict.py index a56ecc3..e730bd2 100644 --- a/exp2/modified/predict.py +++ b/exp2/modified/predict.py @@ -27,7 +27,11 @@ class PredictResult: self.possibilities = possibilities def chosen_number(self) -> int: - """获取最终选定的数字""" + """ + 获取最终选定的数字 + + :return: 以当前概率分布,推测的最终数字。 + """ # 输出出来是10个数字各自的可能性,所以要选取最高可能性的那个对比 # 在dim=1上找最大的那个,就选那个。dim=0是批次所以不管他。 return self.possibilities.argmax(1).item() @@ -52,11 +56,12 @@ class Predictor: file_path = Path(__file__).resolve().parent.parent / 'models' / 'cnn.pth' self.model.load_state_dict(torch.load(file_path)) - def __predict(self, in_data: torch.Tensor) -> PredictResult: + def __predict_tensor(self, in_data: torch.Tensor) -> PredictResult: """ 其它预测函数都要使用的预测后端。其它预测函数将数据处理成Tensor,然后传递给此函数进行实际预测。 :param in_data: 传入的tensor,该tensor的shape必须是28x28,dtype为float32。 + :return: 预测结果。 """ # 上传tensor到GPU in_data = in_data.to(self.device) @@ -75,20 +80,22 @@ class Predictor: """ 以sketchpad的数据进行预测。 - :param image: 该列表的shape必须为28x28 + :param image: 该列表的shape必须为28x28。 + :return: 预测结果。 """ input = torch.Tensor(image).float() assert(input.dim() == 2) assert(input.size(0) == 28) assert(input.size(1) == 28) - return self.__predict(input) + return self.__predict_tensor(input) def predict_image(self, image: ImageFile.ImageFile) -> PredictResult: """ 以Pillow图像的数据进行预测。 :param image: Pillow图像数据。该图像必须为28x28大小。 + :return: 预测结果。 """ # 确保图像为灰度图像,以及宽高合适 grayscale_image = image.convert('L') @@ -101,7 +108,7 @@ class Predictor: # 归一化到255,又因为图像输入是白底黑字,需要做转换。 data.div_(255.0).sub_(1).mul_(-1) - return self.__predict(data) + return self.__predict_tensor(data) def main(): predictor = Predictor() diff --git a/exp3/modified/settings.py b/exp3/modified/settings.py index d0a692f..c7d3c68 100644 --- a/exp3/modified/settings.py +++ b/exp3/modified/settings.py @@ -1,12 +1,14 @@ from pathlib import Path -BATCH_SIZE: int = 16 -"""训练的batch size""" +DIRTY_DATASET_PATH: Path = Path(__file__).resolve().parent.parent / 'datasets' / 'poetry.txt' +"""脏的(未清洗的)古诗数据的路径""" +CLEAN_DATASET_PATH: Path = Path(__file__).resolve().parent.parent / 'datasets' / 'poetry.pickle' +"""干净的(已经清洗过的)古诗数据的路径""" -def get_saved_model_path() -> Path: - """ - 获取训练完毕的模型进行保存的路径。 - - :return: 模型参数保存的路径。 - """ - return Path(__file__).resolve().parent.parent / 'models' / 'rnn.pth' +SAVED_MODULE_PATH: Path = Path(__file__).resolve().parent.parent / 'models' / 'rnn.pth' +"""训练完毕的模型进行保存的路径""" + +N_EPOCH: int = 10 +"""训练时的epoch""" +N_BATCH_SIZE: int = 16 +"""训练时的batch size"""