1
0
Files
ai-school/dl-exp/exp3/modified/predict.py

148 lines
5.5 KiB
Python
Raw Normal View History

2025-12-06 19:56:55 +08:00
from pathlib import Path
import sys
import numpy
import torch
import torch.nn.functional as F
import settings
from dataset import Tokenizer, PoetryDataLoader
from model import Rnn
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
import gpu_utils
2025-12-06 20:48:27 +08:00
def generate_random_poetry(tokenizer: Tokenizer, model: Rnn, device: torch.device, s: str='') -> str:
2025-12-06 19:56:55 +08:00
"""
随机生成一首诗
:param tokenizer: 分词器
:param model: 用于生成古诗的模型
:param s: 用于生成古诗的起始字符串默认为空串
:return: 一个字符串表示一首古诗
"""
# 将初始字符串转成token
token_ids = tokenizer.encode(s)
# 去掉结束标记[SEP]
token_ids = token_ids[:-1]
while len(token_ids) < settings.POETRY_MAX_LEN:
# 进行预测其中batch_size=1
input = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)
output: torch.Tensor = model(input.to(device))
# 计算最后一个字符的概率分布。
# 由于后续预测概率时,需要批次维度,所以方括号里第一项写:保留批次维度。
# 然后因为只有最后一个字符是预测的,其他字符都是辅助推断的,所以方括号第二项-1表示取这个最后一个字符。
# 最后,它的概率分布中不包含[PAD][UNK][CLS]的概率分布所以方括号第三项3:把这些东西删掉这些编号是Tokenizer在编译时写死的详细查看对应模块
2025-12-06 20:48:27 +08:00
possibilities = F.softmax(output[:, -1, 3:], dim=-1)
2025-12-06 19:56:55 +08:00
# 按照预测出的概率,随机选择一个词作为预测结果。
# 如果需要贪心则用argmax替代。
target_index = torch.multinomial(possibilities, num_samples=1)
# 记得把之前删除的维度加回来才是token id
2025-12-06 20:48:27 +08:00
target_id = target_index.item() + 3
2025-12-06 19:56:55 +08:00
# 把target_id加入序列
token_ids.append(target_id)
# 如果target_id是[SEP],表示输出结束,需要退出
if target_id == 3: break
# 解码并返回结果
return tokenizer.decode(token_ids)
2025-12-06 20:48:27 +08:00
def generate_acrostic(tokenizer: Tokenizer, model: Rnn, device: torch.device, head: str) -> str:
2025-12-06 19:56:55 +08:00
"""
随机生成一首藏头诗
:param tokenizer: 分词器
:param model: 用于生成古诗的模型
:param head: 藏头诗的头
:return: 一个字符串表示一首古诗
"""
# 使用空串初始化token_ids
token_ids = tokenizer.encode('')
# 去掉结束标记[SEP],只保留[CLS]
token_ids = token_ids[:-1]
# 标点符号,这里简单的只把逗号和句号作为标点
punctuations = ['', '']
punctuation_ids = {tokenizer.token_to_id(token) for token in punctuations}
# 缓存生成的诗的list
poetry: list[str] = []
# 对于藏头诗中的每一个字,都生成一个短句
for ch in head:
# 先记录下这个字
poetry.append(ch)
# 将藏头诗的字符转成token id
token_id = tokenizer.token_to_id(ch)
# 加入到列表中去
token_ids.append(token_id)
# 开始生成一个短句
while True:
# 与generate_random_poetry函数相同的方式不断地生成诗句的下一个字。
input = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)
output: torch.Tensor = model(input.to(device))
2025-12-06 20:48:27 +08:00
possibilities = F.softmax(output[:, -1, 3:], dim=-1)
2025-12-06 19:56:55 +08:00
target_index = torch.multinomial(possibilities, num_samples=1)
2025-12-06 20:48:27 +08:00
target_id = target_index.item() + 3
2025-12-06 19:56:55 +08:00
# 把target_id加入序列
token_ids.append(target_id)
# 只有对应ID不是特殊符号的ID我们才把这个字符推入诗句中
if target_id > 3: poetry.append(tokenizer.id_to_token(target_id))
# 此外,与上面不同的是,当输出为标点符号时,我们退出当前循环,进而生成藏头诗的下一句。
if target_id in punctuation_ids: break
# 解码并返回结果
return ''.join(poetry)
class Predictor:
device: torch.device
data_loader: PoetryDataLoader
model: Rnn
def __init__(self):
self.device = gpu_utils.get_gpu_device()
self.data_loader = PoetryDataLoader(batch_size=settings.N_BATCH_SIZE)
self.model = Rnn(self.data_loader.get_vocab_size()).to(self.device)
# 加载保存好的模型参数
self.model.load_state_dict(torch.load(settings.SAVED_MODEL_PATH))
2025-12-06 20:48:27 +08:00
self.model.eval()
2025-12-06 19:56:55 +08:00
2025-12-06 20:48:27 +08:00
def generate_random_poetry(self, s: str = ''):
2025-12-06 19:56:55 +08:00
"""随机生成一首诗"""
with torch.no_grad():
2025-12-06 20:48:27 +08:00
print(generate_random_poetry(self.data_loader.get_tokenizer(),
2025-12-06 19:56:55 +08:00
self.model,
2025-12-06 20:48:27 +08:00
self.device,
s))
2025-12-06 19:56:55 +08:00
2025-12-06 20:48:27 +08:00
def generate_acrostic(self, s: str):
2025-12-06 19:56:55 +08:00
"""随机生成一首藏头诗"""
with torch.no_grad():
2025-12-06 20:48:27 +08:00
print(generate_acrostic(self.data_loader.get_tokenizer(),
2025-12-06 19:56:55 +08:00
self.model,
2025-12-06 20:48:27 +08:00
self.device,
s))
def main():
predictor = Predictor()
# 随机生成一首诗
predictor.generate_random_poetry()
# 给出部分信息的情况下,随机生成剩余部分
predictor.generate_random_poetry('床前明月光,')
# 生成藏头诗
predictor.generate_acrostic('好好学习天天向上')
if __name__ == "__main__":
gpu_utils.print_gpu_availability()
main()