87 lines
3.3 KiB
Python
87 lines
3.3 KiB
Python
|
||
import numpy as np
|
||
import settings
|
||
|
||
|
||
def generate_random_poetry(tokenizer, model, s=''):
|
||
"""
|
||
随机生成一首诗
|
||
:param tokenizer: 分词器
|
||
:param model: 用于生成古诗的模型
|
||
:param s: 用于生成古诗的起始字符串,默认为空串
|
||
:return: 一个字符串,表示一首古诗
|
||
"""
|
||
# 将初始字符串转成token
|
||
token_ids = tokenizer.encode(s)
|
||
# 去掉结束标记[SEP]
|
||
token_ids = token_ids[:-1]
|
||
while len(token_ids) < settings.MAX_LEN:
|
||
# 进行预测,只保留第一个样例(我们输入的样例数只有1)的、最后一个token的预测的、不包含[PAD][UNK][CLS]的概率分布
|
||
output = model(np.array([token_ids, ], dtype=np.int32))
|
||
_probas = output.numpy()[0, -1, 3:]
|
||
del output
|
||
# print(_probas)
|
||
# 按照出现概率,对所有token倒序排列
|
||
p_args = _probas.argsort()[::-1][:100]
|
||
# 排列后的概率顺序
|
||
p = _probas[p_args]
|
||
# 先对概率归一
|
||
p = p / sum(p)
|
||
# 再按照预测出的概率,随机选择一个词作为预测结果
|
||
target_index = np.random.choice(len(p), p=p)
|
||
target = p_args[target_index] + 3
|
||
# 保存
|
||
token_ids.append(target)
|
||
if target == 3:
|
||
break
|
||
return tokenizer.decode(token_ids)
|
||
|
||
|
||
def generate_acrostic(tokenizer, model, head):
|
||
"""
|
||
随机生成一首藏头诗
|
||
:param tokenizer: 分词器
|
||
:param model: 用于生成古诗的模型
|
||
:param head: 藏头诗的头
|
||
:return: 一个字符串,表示一首古诗
|
||
"""
|
||
# 使用空串初始化token_ids,加入[CLS]
|
||
token_ids = tokenizer.encode('')
|
||
token_ids = token_ids[:-1]
|
||
# 标点符号,这里简单的只把逗号和句号作为标点
|
||
punctuations = [',', '。']
|
||
punctuation_ids = {tokenizer.token_to_id(token) for token in punctuations}
|
||
# 缓存生成的诗的list
|
||
poetry = []
|
||
# 对于藏头诗中的每一个字,都生成一个短句
|
||
for ch in head:
|
||
# 先记录下这个字
|
||
poetry.append(ch)
|
||
# 将藏头诗的字符转成token id
|
||
token_id = tokenizer.token_to_id(ch)
|
||
# 加入到列表中去
|
||
token_ids.append(token_id)
|
||
# 开始生成一个短句
|
||
while True:
|
||
# 进行预测,只保留第一个样例(我们输入的样例数只有1)的、最后一个token的预测的、不包含[PAD][UNK][CLS]的概率分布
|
||
output = model(np.array([token_ids, ], dtype=np.int32))
|
||
_probas = output.numpy()[0, -1, 3:]
|
||
del output
|
||
# 按照出现概率,对所有token倒序排列
|
||
p_args = _probas.argsort()[::-1][:100]
|
||
# 排列后的概率顺序
|
||
p = _probas[p_args]
|
||
# 先对概率归一
|
||
p = p / sum(p)
|
||
# 再按照预测出的概率,随机选择一个词作为预测结果
|
||
target_index = np.random.choice(len(p), p=p)
|
||
target = p_args[target_index] + 3
|
||
# 保存
|
||
token_ids.append(target)
|
||
# 只有不是特殊字符时,才保存到poetry里面去
|
||
if target > 3:
|
||
poetry.append(tokenizer.id_to_token(target))
|
||
if target in punctuation_ids:
|
||
break
|
||
return ''.join(poetry)
|