1
0
Files
ai-school/exp3/modified/predict.py
2025-12-06 19:56:55 +08:00

127 lines
4.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from pathlib import Path
import sys
import numpy
import torch
import torch.nn.functional as F
import settings
from dataset import Tokenizer, PoetryDataLoader
from model import Rnn
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
import gpu_utils
def generate_random_poetry(tokenizer: Tokenizer, model: Rnn, device: torch.device, s: str=''):
"""
随机生成一首诗
:param tokenizer: 分词器
:param model: 用于生成古诗的模型
:param s: 用于生成古诗的起始字符串,默认为空串
:return: 一个字符串,表示一首古诗
"""
# 将初始字符串转成token
token_ids = tokenizer.encode(s)
# 去掉结束标记[SEP]
token_ids = token_ids[:-1]
while len(token_ids) < settings.POETRY_MAX_LEN:
# 进行预测其中batch_size=1
input = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)
output: torch.Tensor = model(input.to(device))
# 计算最后一个字符的概率分布。
# 由于后续预测概率时,需要批次维度,所以方括号里第一项写:保留批次维度。
# 然后因为只有最后一个字符是预测的,其他字符都是辅助推断的,所以方括号第二项-1表示取这个最后一个字符。
# 最后,它的概率分布中不包含[PAD][UNK][CLS]的概率分布所以方括号第三项3:把这些东西删掉这些编号是Tokenizer在编译时写死的详细查看对应模块
possibilities = F.softmax(output[:, -1, 3:])
# 按照预测出的概率,随机选择一个词作为预测结果。
# 如果需要贪心则用argmax替代。
target_index = torch.multinomial(possibilities, num_samples=1)
# 记得把之前删除的维度加回来才是token id
target_id = target_index + 3
# 把target_id加入序列
token_ids.append(target_id)
# 如果target_id是[SEP],表示输出结束,需要退出
if target_id == 3: break
# 解码并返回结果
return tokenizer.decode(token_ids)
def generate_acrostic(tokenizer: Tokenizer, model: Rnn, device: torch.device, head: str):
"""
随机生成一首藏头诗
:param tokenizer: 分词器
:param model: 用于生成古诗的模型
:param head: 藏头诗的头
:return: 一个字符串,表示一首古诗
"""
# 使用空串初始化token_ids
token_ids = tokenizer.encode('')
# 去掉结束标记[SEP],只保留[CLS]
token_ids = token_ids[:-1]
# 标点符号,这里简单的只把逗号和句号作为标点
punctuations = ['', '']
punctuation_ids = {tokenizer.token_to_id(token) for token in punctuations}
# 缓存生成的诗的list
poetry: list[str] = []
# 对于藏头诗中的每一个字,都生成一个短句
for ch in head:
# 先记录下这个字
poetry.append(ch)
# 将藏头诗的字符转成token id
token_id = tokenizer.token_to_id(ch)
# 加入到列表中去
token_ids.append(token_id)
# 开始生成一个短句
while True:
# 与generate_random_poetry函数相同的方式不断地生成诗句的下一个字。
input = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)
output: torch.Tensor = model(input.to(device))
possibilities = F.softmax(output[:, -1, 3:])
target_index = torch.multinomial(possibilities, num_samples=1)
target_id = target_index + 3
# 把target_id加入序列
token_ids.append(target_id)
# 只有对应ID不是特殊符号的ID我们才把这个字符推入诗句中
if target_id > 3: poetry.append(tokenizer.id_to_token(target_id))
# 此外,与上面不同的是,当输出为标点符号时,我们退出当前循环,进而生成藏头诗的下一句。
if target_id in punctuation_ids: break
# 解码并返回结果
return ''.join(poetry)
class Predictor:
device: torch.device
data_loader: PoetryDataLoader
model: Rnn
def __init__(self):
self.device = gpu_utils.get_gpu_device()
self.data_loader = PoetryDataLoader(batch_size=settings.N_BATCH_SIZE)
self.model = Rnn(self.data_loader.get_vocab_size()).to(self.device)
# 加载保存好的模型参数
self.model.load_state_dict(torch.load(settings.SAVED_MODEL_PATH))
def generate_random_poetry(self):
"""随机生成一首诗"""
with torch.no_grad():
generate_random_poetry(self.data_loader.get_tokenizer(),
self.model,
self.device)
def generate_acrostic(self):
"""随机生成一首藏头诗"""
with torch.no_grad():
generate_acrostic(self.data_loader.get_tokenizer(),
self.model,
self.device)