finish exp2
This commit is contained in:
@@ -27,7 +27,11 @@ class PredictResult:
|
||||
self.possibilities = possibilities
|
||||
|
||||
def chosen_number(self) -> int:
|
||||
"""获取最终选定的数字"""
|
||||
"""
|
||||
获取最终选定的数字
|
||||
|
||||
:return: 以当前概率分布,推测的最终数字。
|
||||
"""
|
||||
# 输出出来是10个数字各自的可能性,所以要选取最高可能性的那个对比
|
||||
# 在dim=1上找最大的那个,就选那个。dim=0是批次所以不管他。
|
||||
return self.possibilities.argmax(1).item()
|
||||
@@ -52,11 +56,12 @@ class Predictor:
|
||||
file_path = Path(__file__).resolve().parent.parent / 'models' / 'cnn.pth'
|
||||
self.model.load_state_dict(torch.load(file_path))
|
||||
|
||||
def __predict(self, in_data: torch.Tensor) -> PredictResult:
|
||||
def __predict_tensor(self, in_data: torch.Tensor) -> PredictResult:
|
||||
"""
|
||||
其它预测函数都要使用的预测后端。其它预测函数将数据处理成Tensor,然后传递给此函数进行实际预测。
|
||||
|
||||
:param in_data: 传入的tensor,该tensor的shape必须是28x28,dtype为float32。
|
||||
:return: 预测结果。
|
||||
"""
|
||||
# 上传tensor到GPU
|
||||
in_data = in_data.to(self.device)
|
||||
@@ -75,20 +80,22 @@ class Predictor:
|
||||
"""
|
||||
以sketchpad的数据进行预测。
|
||||
|
||||
:param image: 该列表的shape必须为28x28
|
||||
:param image: 该列表的shape必须为28x28。
|
||||
:return: 预测结果。
|
||||
"""
|
||||
input = torch.Tensor(image).float()
|
||||
assert(input.dim() == 2)
|
||||
assert(input.size(0) == 28)
|
||||
assert(input.size(1) == 28)
|
||||
|
||||
return self.__predict(input)
|
||||
return self.__predict_tensor(input)
|
||||
|
||||
def predict_image(self, image: ImageFile.ImageFile) -> PredictResult:
|
||||
"""
|
||||
以Pillow图像的数据进行预测。
|
||||
|
||||
:param image: Pillow图像数据。该图像必须为28x28大小。
|
||||
:return: 预测结果。
|
||||
"""
|
||||
# 确保图像为灰度图像,以及宽高合适
|
||||
grayscale_image = image.convert('L')
|
||||
@@ -101,7 +108,7 @@ class Predictor:
|
||||
# 归一化到255,又因为图像输入是白底黑字,需要做转换。
|
||||
data.div_(255.0).sub_(1).mul_(-1)
|
||||
|
||||
return self.__predict(data)
|
||||
return self.__predict_tensor(data)
|
||||
|
||||
def main():
|
||||
predictor = Predictor()
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
from pathlib import Path
|
||||
|
||||
BATCH_SIZE: int = 16
|
||||
"""训练的batch size"""
|
||||
DIRTY_DATASET_PATH: Path = Path(__file__).resolve().parent.parent / 'datasets' / 'poetry.txt'
|
||||
"""脏的(未清洗的)古诗数据的路径"""
|
||||
CLEAN_DATASET_PATH: Path = Path(__file__).resolve().parent.parent / 'datasets' / 'poetry.pickle'
|
||||
"""干净的(已经清洗过的)古诗数据的路径"""
|
||||
|
||||
def get_saved_model_path() -> Path:
|
||||
"""
|
||||
获取训练完毕的模型进行保存的路径。
|
||||
SAVED_MODULE_PATH: Path = Path(__file__).resolve().parent.parent / 'models' / 'rnn.pth'
|
||||
"""训练完毕的模型进行保存的路径"""
|
||||
|
||||
:return: 模型参数保存的路径。
|
||||
"""
|
||||
return Path(__file__).resolve().parent.parent / 'models' / 'rnn.pth'
|
||||
N_EPOCH: int = 10
|
||||
"""训练时的epoch"""
|
||||
N_BATCH_SIZE: int = 16
|
||||
"""训练时的batch size"""
|
||||
|
||||
Reference in New Issue
Block a user