From ee18246d51429d5453ecc411b40d6047654b8f97 Mon Sep 17 00:00:00 2001
From: yyc12345
Date: Sat, 6 Dec 2025 19:56:55 +0800
Subject: [PATCH] finish exp3 predict code

---
 exp2/modified/predict.py |   4 +-
 exp3/modified/dataset.py |   4 ++
 exp3/modified/predict.py | 143 +++++++++++++++++++++++++++++++++++++++
 exp3/modified/train.py   |  11 ++++
 4 files changed, 160 insertions(+), 2 deletions(-)

diff --git a/exp2/modified/predict.py b/exp2/modified/predict.py
index 6a70470..4b1b88d 100644
--- a/exp2/modified/predict.py
+++ b/exp2/modified/predict.py
@@ -6,6 +6,7 @@ import torch.nn.functional as F
 from PIL import Image, ImageFile
 import matplotlib.pyplot as plt
 from model import Cnn
+import settings
 
 sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
 import gpu_utils
@@ -53,8 +54,7 @@ class Predictor:
         self.model = Cnn().to(self.device)
 
         # 加载保存好的模型参数
-        file_path = Path(__file__).resolve().parent.parent / 'models' / 'cnn.pth'
-        self.model.load_state_dict(torch.load(file_path))
+        self.model.load_state_dict(torch.load(settings.SAVED_MODEL_PATH, map_location=self.device))
 
     def __predict_tensor(self, in_data: torch.Tensor) -> PredictResult:
         """
diff --git a/exp3/modified/dataset.py b/exp3/modified/dataset.py
index 4d06d8b..f4ced6d 100644
--- a/exp3/modified/dataset.py
+++ b/exp3/modified/dataset.py
@@ -233,6 +233,10 @@ class PoetryDataLoader:
     def get_vocab_size(self) -> int:
         """一个便捷的获取vocab_size的函数,避免层层调用"""
         return self.preprocessor.tokenizer.vocab_size
+
+    def get_tokenizer(self) -> Tokenizer:
+        """Convenience accessor for the Tokenizer that avoids chained attribute access."""
+        return self.preprocessor.tokenizer
 
     def __collect_fn(self, batch: list[list[int]]) -> tuple[torch.Tensor, torch.Tensor]:
         """
diff --git a/exp3/modified/predict.py b/exp3/modified/predict.py
index e69de29..19120ce 100644
--- a/exp3/modified/predict.py
+++ b/exp3/modified/predict.py
@@ -0,0 +1,143 @@
+from pathlib import Path
+import sys
+import numpy
+import torch
+import torch.nn.functional as F
+import settings
+from dataset import Tokenizer, PoetryDataLoader
+from model import Rnn
+
+sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
+import gpu_utils
+
+# Token id of the end marker [SEP]. Tokenizer hard-codes [PAD], [UNK] and [CLS]
+# as the ids below it (see the dataset module), so sampling starts at this id.
+_SEP_ID = 3
+
+
+def _sample_next_token_id(model: Rnn, device: torch.device, token_ids: list[int]) -> int:
+    """
+    Sample the id of the token that follows ``token_ids``.
+
+    :param model: model used to generate poetry
+    :param device: device the model lives on
+    :param token_ids: ids of the tokens generated so far
+    :return: the sampled token id as a plain ``int``
+    """
+    # Run the model on a batch that holds the single sequence.
+    batch = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)
+    output: torch.Tensor = model(batch.to(device))
+    # Only the last position is the prediction; the earlier ones are context.
+    # The leading slice keeps the batch dimension for the sampling call below.
+    probabilities = F.softmax(output[:, -1, _SEP_ID:], dim=-1)
+    # Draw one token according to the predicted distribution
+    # (use argmax here instead for greedy decoding).
+    sampled_index = torch.multinomial(probabilities, num_samples=1)
+    # Convert to int and undo the offset introduced by the slice above.
+    return int(sampled_index.item()) + _SEP_ID
+
+
+@torch.no_grad()
+def generate_random_poetry(tokenizer: Tokenizer, model: Rnn, device: torch.device, s: str='') -> str:
+    """
+    Generate a random poem.
+
+    :param tokenizer: tokenizer
+    :param model: model used to generate poetry
+    :param device: device the model lives on
+    :param s: string the poem starts with, empty by default
+    :return: a string holding the poem
+    """
+    # Encode the prefix and drop the trailing end marker [SEP].
+    token_ids = tokenizer.encode(s)[:-1]
+    while len(token_ids) < settings.POETRY_MAX_LEN:
+        target_id = _sample_next_token_id(model, device, token_ids)
+        token_ids.append(target_id)
+        # [SEP] marks the end of the poem.
+        if target_id == _SEP_ID:
+            break
+
+    return tokenizer.decode(token_ids)
+
+
+@torch.no_grad()
+def generate_acrostic(tokenizer: Tokenizer, model: Rnn, device: torch.device, head: str) -> str:
+    """
+    Generate a random acrostic poem.
+
+    :param tokenizer: tokenizer
+    :param model: model used to generate poetry
+    :param device: device the model lives on
+    :param head: characters that start each line of the poem
+    :return: a string holding the poem
+    """
+    # Start from the encoding of the empty string without its end marker
+    # [SEP], so that only [CLS] is left.
+    token_ids = tokenizer.encode('')[:-1]
+
+    # Punctuation that ends a line; kept simple: only comma and full stop.
+    punctuations = [',', '。']
+    punctuation_ids = {tokenizer.token_to_id(token) for token in punctuations}
+
+    # Pieces of the generated poem.
+    poetry: list[str] = []
+    # Generate one line for every character of the head.
+    for ch in head:
+        poetry.append(ch)
+        token_ids.append(tokenizer.token_to_id(ch))
+
+        # The length cap keeps a model that never emits punctuation from
+        # looping forever.
+        while len(token_ids) < settings.POETRY_MAX_LEN:
+            target_id = _sample_next_token_id(model, device, token_ids)
+            token_ids.append(target_id)
+            # Special tokens are not part of the poem text.
+            if target_id > _SEP_ID:
+                poetry.append(tokenizer.id_to_token(target_id))
+            # Punctuation ends this line; move on to the next head character.
+            if target_id in punctuation_ids:
+                break
+
+    return ''.join(poetry)
+
+
+class Predictor:
+    device: torch.device
+    data_loader: PoetryDataLoader
+    model: Rnn
+
+    def __init__(self):
+        self.device = gpu_utils.get_gpu_device()
+        self.data_loader = PoetryDataLoader(batch_size=settings.N_BATCH_SIZE)
+        self.model = Rnn(self.data_loader.get_vocab_size()).to(self.device)
+
+        # Load the saved parameters; map_location lets a checkpoint saved on
+        # another device be loaded onto this one.
+        self.model.load_state_dict(
+            torch.load(settings.SAVED_MODEL_PATH, map_location=self.device))
+        # This class only runs inference, so use evaluation mode.
+        self.model.eval()
+
+    def generate_random_poetry(self, s: str='') -> str:
+        """
+        Generate a random poem.
+
+        :param s: string the poem starts with, empty by default
+        :return: a string holding the poem
+        """
+        return generate_random_poetry(self.data_loader.get_tokenizer(),
+                                      self.model,
+                                      self.device,
+                                      s)
+
+    def generate_acrostic(self, head: str) -> str:
+        """
+        Generate a random acrostic poem.
+
+        :param head: characters that start each line of the poem
+        :return: a string holding the poem
+        """
+        return generate_acrostic(self.data_loader.get_tokenizer(),
+                                 self.model,
+                                 self.device,
+                                 head)
diff --git a/exp3/modified/train.py b/exp3/modified/train.py
index d019dad..2f006f2 100644
--- a/exp3/modified/train.py
+++ b/exp3/modified/train.py
@@ -9,6 +9,7 @@ from ignite.engine import Engine, Events
 from ignite.handlers.tqdm_logger import ProgressBar
 from dataset import PoetryDataLoader
 from model import Rnn
+from predict import generate_random_poetry
 import settings
 
 sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
@@ -44,6 +45,16 @@ class Trainer:
         # 将训练器关联到进度条
         self.pbar = ProgressBar(persist=True)
         self.pbar.attach(self.trainer, output_transform=lambda loss: {"loss": loss})
+        # After every epoch, compose a poem and log it to show how training
+        # is going. The device is read from the model's own parameters.
+        self.trainer.add_event_handler(
+            Events.EPOCH_COMPLETED,
+            lambda: self.pbar.log_message(generate_random_poetry(
+                self.data_loader.get_tokenizer(),
+                self.model,
+                next(self.model.parameters()).device,
+            )),
+        )
 
     def train_model(self):
         # 训练模型