1
0

fix exp3 loss function error

This commit is contained in:
2025-12-06 20:48:27 +08:00
parent ee18246d51
commit 7aa7ae3335
3 changed files with 42 additions and 17 deletions

View File

@@ -264,8 +264,7 @@ class PoetryDataLoader:
# 这么做是为了让RNN从输入推到输出下一个字符
# 此外输出要做onehot编码
input = torch.tensor(numpy_batch[:, :-1], dtype=torch.long)
output = F.one_hot(torch.tensor(numpy_batch[:, 1:], dtype=torch.long),
num_classes=self.preprocessor.tokenizer.vocab_size).float()
output = torch.tensor(numpy_batch[:, 1:], dtype=torch.long)
# 返回结果
return input, output