from pathlib import Path import sys import numpy import torch import torch.nn.functional as F import settings from dataset import Tokenizer, PoetryDataLoader from model import Rnn sys.path.append(str(Path(__file__).resolve().parent.parent.parent)) import gpu_utils def generate_random_poetry(tokenizer: Tokenizer, model: Rnn, device: torch.device, s: str='') -> str: """ 随机生成一首诗 :param tokenizer: 分词器 :param model: 用于生成古诗的模型 :param s: 用于生成古诗的起始字符串,默认为空串 :return: 一个字符串,表示一首古诗 """ # 将初始字符串转成token token_ids = tokenizer.encode(s) # 去掉结束标记[SEP] token_ids = token_ids[:-1] while len(token_ids) < settings.POETRY_MAX_LEN: # 进行预测,其中batch_size=1 input = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0) output: torch.Tensor = model(input.to(device)) # 计算最后一个字符的概率分布。 # 由于后续预测概率时,需要批次维度,所以方括号里第一项写:保留批次维度。 # 然后因为只有最后一个字符是预测的,其他字符都是辅助推断的,所以方括号第二项-1表示取这个最后一个字符。 # 最后,它的概率分布中不包含[PAD][UNK][CLS]的概率分布,所以方括号第三项3:把这些东西删掉(这些编号是Tokenizer在编译时写死的,详细查看对应模块)。 possibilities = F.softmax(output[:, -1, 3:], dim=-1) # 按照预测出的概率,随机选择一个词作为预测结果。 # 如果需要贪心,则用argmax替代。 target_index = torch.multinomial(possibilities, num_samples=1) # 记得把之前删除的维度加回来才是token id target_id = target_index.item() + 3 # 把target_id加入序列 token_ids.append(target_id) # 如果target_id是[SEP],表示输出结束,需要退出 if target_id == 3: break # 解码并返回结果 return tokenizer.decode(token_ids) def generate_acrostic(tokenizer: Tokenizer, model: Rnn, device: torch.device, head: str) -> str: """ 随机生成一首藏头诗 :param tokenizer: 分词器 :param model: 用于生成古诗的模型 :param head: 藏头诗的头 :return: 一个字符串,表示一首古诗 """ # 使用空串初始化token_ids token_ids = tokenizer.encode('') # 去掉结束标记[SEP],只保留[CLS] token_ids = token_ids[:-1] # 标点符号,这里简单的只把逗号和句号作为标点 punctuations = [',', '。'] punctuation_ids = {tokenizer.token_to_id(token) for token in punctuations} # 缓存生成的诗的list poetry: list[str] = [] # 对于藏头诗中的每一个字,都生成一个短句 for ch in head: # 先记录下这个字 poetry.append(ch) # 将藏头诗的字符转成token id token_id = tokenizer.token_to_id(ch) # 加入到列表中去 token_ids.append(token_id) # 开始生成一个短句 while True: # 与generate_random_poetry函数相同的方式,不断地生成诗句的下一个字。 input = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0) output: torch.Tensor = model(input.to(device)) possibilities = F.softmax(output[:, -1, 3:], dim=-1) target_index = torch.multinomial(possibilities, num_samples=1) target_id = target_index.item() + 3 # 把target_id加入序列 token_ids.append(target_id) # 只有对应ID不是特殊符号的ID,我们才把这个字符推入诗句中 if target_id > 3: poetry.append(tokenizer.id_to_token(target_id)) # 此外,与上面不同的是,当输出为标点符号时,我们退出当前循环,进而生成藏头诗的下一句。 if target_id in punctuation_ids: break # 解码并返回结果 return ''.join(poetry) class Predictor: device: torch.device data_loader: PoetryDataLoader model: Rnn def __init__(self): self.device = gpu_utils.get_gpu_device() self.data_loader = PoetryDataLoader(batch_size=settings.N_BATCH_SIZE) self.model = Rnn(self.data_loader.get_vocab_size()).to(self.device) # 加载保存好的模型参数 self.model.load_state_dict(torch.load(settings.SAVED_MODEL_PATH)) self.model.eval() def generate_random_poetry(self, s: str = ''): """随机生成一首诗""" with torch.no_grad(): print(generate_random_poetry(self.data_loader.get_tokenizer(), self.model, self.device, s)) def generate_acrostic(self, s: str): """随机生成一首藏头诗""" with torch.no_grad(): print(generate_acrostic(self.data_loader.get_tokenizer(), self.model, self.device, s)) def main(): predictor = Predictor() # 随机生成一首诗 predictor.generate_random_poetry() # 给出部分信息的情况下,随机生成剩余部分 predictor.generate_random_poetry('床前明月光,') # 生成藏头诗 predictor.generate_acrostic('好好学习天天向上') if __name__ == "__main__": gpu_utils.print_gpu_availability() main()