1
0

refactor: merge multiple project into one and create new project

This commit is contained in:
2026-04-07 08:30:41 +08:00
parent 7aa7ae3335
commit 6cb1a89751
49 changed files with 2932 additions and 4 deletions

View File

@@ -0,0 +1,86 @@
import numpy as np
import settings
def generate_random_poetry(tokenizer, model, s=''):
"""
随机生成一首诗
:param tokenizer: 分词器
:param model: 用于生成古诗的模型
:param s: 用于生成古诗的起始字符串,默认为空串
:return: 一个字符串,表示一首古诗
"""
# 将初始字符串转成token
token_ids = tokenizer.encode(s)
# 去掉结束标记[SEP]
token_ids = token_ids[:-1]
while len(token_ids) < settings.MAX_LEN:
# 进行预测只保留第一个样例我们输入的样例数只有1的、最后一个token的预测的、不包含[PAD][UNK][CLS]的概率分布
output = model(np.array([token_ids, ], dtype=np.int32))
_probas = output.numpy()[0, -1, 3:]
del output
# print(_probas)
# 按照出现概率对所有token倒序排列
p_args = _probas.argsort()[::-1][:100]
# 排列后的概率顺序
p = _probas[p_args]
# 先对概率归一
p = p / sum(p)
# 再按照预测出的概率,随机选择一个词作为预测结果
target_index = np.random.choice(len(p), p=p)
target = p_args[target_index] + 3
# 保存
token_ids.append(target)
if target == 3:
break
return tokenizer.decode(token_ids)
def generate_acrostic(tokenizer, model, head):
"""
随机生成一首藏头诗
:param tokenizer: 分词器
:param model: 用于生成古诗的模型
:param head: 藏头诗的头
:return: 一个字符串,表示一首古诗
"""
# 使用空串初始化token_ids加入[CLS]
token_ids = tokenizer.encode('')
token_ids = token_ids[:-1]
# 标点符号,这里简单的只把逗号和句号作为标点
punctuations = ['', '']
punctuation_ids = {tokenizer.token_to_id(token) for token in punctuations}
# 缓存生成的诗的list
poetry = []
# 对于藏头诗中的每一个字,都生成一个短句
for ch in head:
# 先记录下这个字
poetry.append(ch)
# 将藏头诗的字符转成token id
token_id = tokenizer.token_to_id(ch)
# 加入到列表中去
token_ids.append(token_id)
# 开始生成一个短句
while True:
# 进行预测只保留第一个样例我们输入的样例数只有1的、最后一个token的预测的、不包含[PAD][UNK][CLS]的概率分布
output = model(np.array([token_ids, ], dtype=np.int32))
_probas = output.numpy()[0, -1, 3:]
del output
# 按照出现概率对所有token倒序排列
p_args = _probas.argsort()[::-1][:100]
# 排列后的概率顺序
p = _probas[p_args]
# 先对概率归一
p = p / sum(p)
# 再按照预测出的概率,随机选择一个词作为预测结果
target_index = np.random.choice(len(p), p=p)
target = p_args[target_index] + 3
# 保存
token_ids.append(target)
# 只有不是特殊字符时才保存到poetry里面去
if target > 3:
poetry.append(tokenizer.id_to_token(target))
if target in punctuation_ids:
break
return ''.join(poetry)