fix exp3 loss function error
This commit is contained in:
@@ -264,8 +264,7 @@ class PoetryDataLoader:
|
||||
# 这么做是为了让RNN从输入推到输出(下一个字符)。
|
||||
# 此外,输出要做onehot编码
|
||||
input = torch.tensor(numpy_batch[:, :-1], dtype=torch.long)
|
||||
output = F.one_hot(torch.tensor(numpy_batch[:, 1:], dtype=torch.long),
|
||||
num_classes=self.preprocessor.tokenizer.vocab_size).float()
|
||||
output = torch.tensor(numpy_batch[:, 1:], dtype=torch.long)
|
||||
|
||||
# 返回结果
|
||||
return input, output
|
||||
|
||||
Reference in New Issue
Block a user