fix exp3 loss function error
This commit is contained in:
@@ -31,7 +31,7 @@ class Trainer:
|
||||
self.device = gpu_utils.get_gpu_device()
|
||||
self.data_loader = PoetryDataLoader(batch_size=settings.N_BATCH_SIZE)
|
||||
self.model = Rnn(self.data_loader.get_vocab_size()).to(self.device)
|
||||
# 展示模型结构。批次为指定批次数量,通道只有一个灰度通道,大小28x28。
|
||||
# 展示模型结构。批次为指定批次数量,最大诗歌长度,同时输入一定是int32
|
||||
torchinfo.summary(self.model,
|
||||
(settings.N_BATCH_SIZE, settings.POETRY_MAX_LEN),
|
||||
dtypes=[torch.int32,])
|
||||
@@ -41,16 +41,21 @@ class Trainer:
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
# 创建训练器
|
||||
self.trainer = ignite.engine.create_supervised_trainer(
|
||||
self.model, optimizer, criterion, self.device)
|
||||
self.model, optimizer, criterion, self.device,
|
||||
# 由于PyTorch的交叉熵函数总是要求概率在dim=1,所以要调换一下维度才能传入。
|
||||
model_transform=lambda output: self.__adjust_for_loss(output))
|
||||
# 将训练器关联到进度条
|
||||
self.pbar = ProgressBar(persist=True)
|
||||
self.pbar.attach(self.trainer, output_transform=lambda loss: {"loss": loss})
|
||||
# 每次epoch后,作诗一首看看结果
|
||||
self.trainer.add_event_handler(
|
||||
Events.EPOCH_COMPLETED,
|
||||
lambda: generate_random_poetry(self.data_loader.get_tokenizer(), self.model, )
|
||||
lambda: print(generate_random_poetry(self.data_loader.get_tokenizer(), self.model, self.device))
|
||||
)
|
||||
|
||||
def __adjust_for_loss(self, output: torch.Tensor) -> torch.Tensor:
|
||||
return output.permute(0, 2, 1)
|
||||
|
||||
def train_model(self):
|
||||
# 训练模型
|
||||
self.trainer.run(self.data_loader.loader, max_epochs=settings.N_EPOCH)
|
||||
|
||||
Reference in New Issue
Block a user