1
0

fix exp3 loss function error

This commit is contained in:
2025-12-06 20:48:27 +08:00
parent ee18246d51
commit 7aa7ae3335
3 changed files with 42 additions and 17 deletions

View File

@@ -11,7 +11,7 @@ sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
import gpu_utils
def generate_random_poetry(tokenizer: Tokenizer, model: Rnn, device: torch.device, s: str=''):
def generate_random_poetry(tokenizer: Tokenizer, model: Rnn, device: torch.device, s: str='') -> str:
"""
随机生成一首诗
@@ -33,12 +33,12 @@ def generate_random_poetry(tokenizer: Tokenizer, model: Rnn, device: torch.devic
# 由于后续预测概率时,需要批次维度,所以方括号里第一项写:保留批次维度。
# 然后因为只有最后一个字符是预测的,其他字符都是辅助推断的,所以方括号第二项-1表示取这个最后一个字符。
# 最后,它的概率分布中不包含[PAD][UNK][CLS]的概率分布所以方括号第三项3:把这些东西删掉这些编号是Tokenizer在编译时写死的详细查看对应模块
possibilities = F.softmax(output[:, -1, 3:])
possibilities = F.softmax(output[:, -1, 3:], dim=-1)
# 按照预测出的概率,随机选择一个词作为预测结果。
# 如果需要贪心则用argmax替代。
target_index = torch.multinomial(possibilities, num_samples=1)
# 记得把之前删除的维度加回来才是token id
target_id = target_index + 3
target_id = target_index.item() + 3
# 把target_id加入序列
token_ids.append(target_id)
@@ -49,7 +49,7 @@ def generate_random_poetry(tokenizer: Tokenizer, model: Rnn, device: torch.devic
return tokenizer.decode(token_ids)
def generate_acrostic(tokenizer: Tokenizer, model: Rnn, device: torch.device, head: str):
def generate_acrostic(tokenizer: Tokenizer, model: Rnn, device: torch.device, head: str) -> str:
"""
随机生成一首藏头诗
@@ -83,9 +83,9 @@ def generate_acrostic(tokenizer: Tokenizer, model: Rnn, device: torch.device, he
# 与generate_random_poetry函数相同的方式不断地生成诗句的下一个字。
input = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)
output: torch.Tensor = model(input.to(device))
possibilities = F.softmax(output[:, -1, 3:])
possibilities = F.softmax(output[:, -1, 3:], dim=-1)
target_index = torch.multinomial(possibilities, num_samples=1)
target_id = target_index + 3
target_id = target_index.item() + 3
# 把target_id加入序列
token_ids.append(target_id)
@@ -110,17 +110,38 @@ class Predictor:
# 加载保存好的模型参数
self.model.load_state_dict(torch.load(settings.SAVED_MODEL_PATH))
self.model.eval()
def generate_random_poetry(self):
def generate_random_poetry(self, s: str = ''):
"""随机生成一首诗"""
with torch.no_grad():
generate_random_poetry(self.data_loader.get_tokenizer(),
print(generate_random_poetry(self.data_loader.get_tokenizer(),
self.model,
self.device)
self.device,
s))
def generate_acrostic(self):
def generate_acrostic(self, s: str):
"""随机生成一首藏头诗"""
with torch.no_grad():
generate_acrostic(self.data_loader.get_tokenizer(),
print(generate_acrostic(self.data_loader.get_tokenizer(),
self.model,
self.device)
self.device,
s))
def main():
predictor = Predictor()
# 随机生成一首诗
predictor.generate_random_poetry()
# 给出部分信息的情况下,随机生成剩余部分
predictor.generate_random_poetry('床前明月光,')
# 生成藏头诗
predictor.generate_acrostic('好好学习天天向上')
if __name__ == "__main__":
gpu_utils.print_gpu_availability()
main()