finish exp2
This commit is contained in:
@@ -27,7 +27,11 @@ class PredictResult:
|
|||||||
self.possibilities = possibilities
|
self.possibilities = possibilities
|
||||||
|
|
||||||
def chosen_number(self) -> int:
|
def chosen_number(self) -> int:
|
||||||
"""获取最终选定的数字"""
|
"""
|
||||||
|
获取最终选定的数字
|
||||||
|
|
||||||
|
:return: 以当前概率分布,推测的最终数字。
|
||||||
|
"""
|
||||||
# 输出出来是10个数字各自的可能性,所以要选取最高可能性的那个对比
|
# 输出出来是10个数字各自的可能性,所以要选取最高可能性的那个对比
|
||||||
# 在dim=1上找最大的那个,就选那个。dim=0是批次所以不管他。
|
# 在dim=1上找最大的那个,就选那个。dim=0是批次所以不管他。
|
||||||
return self.possibilities.argmax(1).item()
|
return self.possibilities.argmax(1).item()
|
||||||
@@ -52,11 +56,12 @@ class Predictor:
|
|||||||
file_path = Path(__file__).resolve().parent.parent / 'models' / 'cnn.pth'
|
file_path = Path(__file__).resolve().parent.parent / 'models' / 'cnn.pth'
|
||||||
self.model.load_state_dict(torch.load(file_path))
|
self.model.load_state_dict(torch.load(file_path))
|
||||||
|
|
||||||
def __predict(self, in_data: torch.Tensor) -> PredictResult:
|
def __predict_tensor(self, in_data: torch.Tensor) -> PredictResult:
|
||||||
"""
|
"""
|
||||||
其它预测函数都要使用的预测后端。其它预测函数将数据处理成Tensor,然后传递给此函数进行实际预测。
|
其它预测函数都要使用的预测后端。其它预测函数将数据处理成Tensor,然后传递给此函数进行实际预测。
|
||||||
|
|
||||||
:param in_data: 传入的tensor,该tensor的shape必须是28x28,dtype为float32。
|
:param in_data: 传入的tensor,该tensor的shape必须是28x28,dtype为float32。
|
||||||
|
:return: 预测结果。
|
||||||
"""
|
"""
|
||||||
# 上传tensor到GPU
|
# 上传tensor到GPU
|
||||||
in_data = in_data.to(self.device)
|
in_data = in_data.to(self.device)
|
||||||
@@ -75,20 +80,22 @@ class Predictor:
|
|||||||
"""
|
"""
|
||||||
以sketchpad的数据进行预测。
|
以sketchpad的数据进行预测。
|
||||||
|
|
||||||
:param image: 该列表的shape必须为28x28
|
:param image: 该列表的shape必须为28x28。
|
||||||
|
:return: 预测结果。
|
||||||
"""
|
"""
|
||||||
input = torch.Tensor(image).float()
|
input = torch.Tensor(image).float()
|
||||||
assert(input.dim() == 2)
|
assert(input.dim() == 2)
|
||||||
assert(input.size(0) == 28)
|
assert(input.size(0) == 28)
|
||||||
assert(input.size(1) == 28)
|
assert(input.size(1) == 28)
|
||||||
|
|
||||||
return self.__predict(input)
|
return self.__predict_tensor(input)
|
||||||
|
|
||||||
def predict_image(self, image: ImageFile.ImageFile) -> PredictResult:
|
def predict_image(self, image: ImageFile.ImageFile) -> PredictResult:
|
||||||
"""
|
"""
|
||||||
以Pillow图像的数据进行预测。
|
以Pillow图像的数据进行预测。
|
||||||
|
|
||||||
:param image: Pillow图像数据。该图像必须为28x28大小。
|
:param image: Pillow图像数据。该图像必须为28x28大小。
|
||||||
|
:return: 预测结果。
|
||||||
"""
|
"""
|
||||||
# 确保图像为灰度图像,以及宽高合适
|
# 确保图像为灰度图像,以及宽高合适
|
||||||
grayscale_image = image.convert('L')
|
grayscale_image = image.convert('L')
|
||||||
@@ -101,7 +108,7 @@ class Predictor:
|
|||||||
# 归一化到255,又因为图像输入是白底黑字,需要做转换。
|
# 归一化到255,又因为图像输入是白底黑字,需要做转换。
|
||||||
data.div_(255.0).sub_(1).mul_(-1)
|
data.div_(255.0).sub_(1).mul_(-1)
|
||||||
|
|
||||||
return self.__predict(data)
|
return self.__predict_tensor(data)
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
predictor = Predictor()
|
predictor = Predictor()
|
||||||
|
|||||||
@@ -1,12 +1,14 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
BATCH_SIZE: int = 16
|
DIRTY_DATASET_PATH: Path = Path(__file__).resolve().parent.parent / 'datasets' / 'poetry.txt'
|
||||||
"""训练的batch size"""
|
"""脏的(未清洗的)古诗数据的路径"""
|
||||||
|
CLEAN_DATASET_PATH: Path = Path(__file__).resolve().parent.parent / 'datasets' / 'poetry.pickle'
|
||||||
|
"""干净的(已经清洗过的)古诗数据的路径"""
|
||||||
|
|
||||||
def get_saved_model_path() -> Path:
|
SAVED_MODULE_PATH: Path = Path(__file__).resolve().parent.parent / 'models' / 'rnn.pth'
|
||||||
"""
|
"""训练完毕的模型进行保存的路径"""
|
||||||
获取训练完毕的模型进行保存的路径。
|
|
||||||
|
N_EPOCH: int = 10
|
||||||
:return: 模型参数保存的路径。
|
"""训练时的epoch"""
|
||||||
"""
|
N_BATCH_SIZE: int = 16
|
||||||
return Path(__file__).resolve().parent.parent / 'models' / 'rnn.pth'
|
"""训练时的batch size"""
|
||||||
|
|||||||
Reference in New Issue
Block a user