1
0

fix exp3 loss function error

This commit is contained in:
2025-12-06 20:48:27 +08:00
parent ee18246d51
commit 7aa7ae3335
3 changed files with 42 additions and 17 deletions

View File

@@ -31,7 +31,7 @@ class Trainer:
self.device = gpu_utils.get_gpu_device()
self.data_loader = PoetryDataLoader(batch_size=settings.N_BATCH_SIZE)
self.model = Rnn(self.data_loader.get_vocab_size()).to(self.device)
# 展示模型结构。批次为指定批次数量,通道只有一个灰度通道大小28x28。
# 展示模型结构。批次为指定批次数量,最大诗歌长度同时输入一定是int32
torchinfo.summary(self.model,
(settings.N_BATCH_SIZE, settings.POETRY_MAX_LEN),
dtypes=[torch.int32,])
@@ -41,16 +41,21 @@ class Trainer:
criterion = torch.nn.CrossEntropyLoss()
# 创建训练器
self.trainer = ignite.engine.create_supervised_trainer(
self.model, optimizer, criterion, self.device)
self.model, optimizer, criterion, self.device,
# 由于PyTorch的交叉熵函数总是要求概率在dim=1所以要调换一下维度才能传入。
model_transform=lambda output: self.__adjust_for_loss(output))
# 将训练器关联到进度条
self.pbar = ProgressBar(persist=True)
self.pbar.attach(self.trainer, output_transform=lambda loss: {"loss": loss})
# 每次epoch后作诗一首看看结果
self.trainer.add_event_handler(
Events.EPOCH_COMPLETED,
lambda: generate_random_poetry(self.data_loader.get_tokenizer(), self.model, )
lambda: print(generate_random_poetry(self.data_loader.get_tokenizer(), self.model, self.device))
)
def __adjust_for_loss(self, output: torch.Tensor) -> torch.Tensor:
return output.permute(0, 2, 1)
def train_model(self):
# 训练模型
self.trainer.run(self.data_loader.loader, max_epochs=settings.N_EPOCH)