From 826cd263371fa5500e94de3a7bb4db9977da9b09 Mon Sep 17 00:00:00 2001 From: yyc12345 Date: Tue, 2 Dec 2025 23:12:18 +0800 Subject: [PATCH] remove in-time accuracy display --- exp2/modified/train.py | 54 ++++++++++++------------------------------ 1 file changed, 15 insertions(+), 39 deletions(-) diff --git a/exp2/modified/train.py b/exp2/modified/train.py index 2516504..58d931b 100644 --- a/exp2/modified/train.py +++ b/exp2/modified/train.py @@ -6,7 +6,6 @@ import torchinfo import ignite.engine import ignite.metrics from ignite.engine import Engine, Events -from ignite.metrics import Accuracy, Loss from ignite.handlers.tqdm_logger import ProgressBar from dataset import MnistDataLoaders from model import Cnn @@ -104,7 +103,6 @@ class Trainer: model: Cnn trainer: Engine - trainer_accuracy: Accuracy evaluator: Engine pbar: ProgressBar @@ -121,42 +119,15 @@ class Trainer: criterion = torch.nn.CrossEntropyLoss() # 创建训练器 self.trainer = ignite.engine.create_supervised_trainer( - self.model, optimizer, criterion, self.device, - # 输出转换为这种形式,因为后面的测量需要用其中一些参数。 - # 默认的输出只输出loss。 - output_transform=lambda x, y, y_pred, loss: { - 'loss': loss.item(), - 'y': y, - 'y_pred': y_pred - } - ) - # 设置训练器测量数据 - self.trainer_accuracy = Accuracy( - device=self.device, - # 转换为Accuracy需要的形式。 - output_transform=lambda o: (o['y_pred'], o['y']) - ) - self.trainer_accuracy.attach(self.trainer, 'accuracy') - # YYC MARK: 这里要手动reset一下,不然第一次运行没有accuracy - self.trainer_accuracy.reset() - # 每次epoch前重置accuracy - self.trainer.add_event_handler( - Events.EPOCH_STARTED, - lambda: self.trainer_accuracy.reset() - ) + self.model, optimizer, criterion, self.device) # 将训练器关联到进度条 self.pbar = ProgressBar(persist=True) - self.pbar.attach(self.trainer, - metric_names=['accuracy'], - output_transform=lambda o: {"loss": o['loss']}) - # 训练完毕后保存模型 - self.trainer.add_event_handler( - Events.COMPLETED, - lambda: self.save_model() - ) + self.pbar.attach(self.trainer, output_transform=lambda o: {"loss": o}) # 创建测试的评估器的评估量 evaluator_metrics = { + # 这个Accuracy要的是logits,而不是possibilities, + # 所以依然是不需要softmax处理后的结果。 "accuracy": ignite.metrics.Accuracy(device=self.device), "loss": ignite.metrics.Loss(criterion, device=self.device) } @@ -164,13 +135,9 @@ class Trainer: self.evaluator = ignite.engine.create_supervised_evaluator( self.model, metrics=evaluator_metrics, device=self.device) - def run(self): + def train_model(self): # 训练模型 self.trainer.run(self.data_source.train_loader, max_epochs=settings.N_EPOCH) - # 测试模型并输出结果 - self.evaluator.run(self.data_source.test_loader) - metrics = self.evaluator.state.metrics - print(f"Test Done. Accuracy: {metrics['accuracy']:.4f} Loss: {metrics['loss']:.4f}") def save_model(self): # 确保保存模型的文件夹存在。 @@ -179,9 +146,18 @@ class Trainer: torch.save(self.model.state_dict(), settings.SAVED_MODEL_PATH) print(f'Model was saved into: {settings.SAVED_MODEL_PATH}') + def test_model(self): + # 测试模型并输出结果 + self.evaluator.run(self.data_source.test_loader) + metrics = self.evaluator.state.metrics + print(f"Accuracy: {metrics['accuracy']:.4f} Loss: {metrics['loss']:.4f}") + + def main(): trainer = Trainer() - trainer.run() + trainer.train_model() + trainer.save_model() + trainer.test_model() if __name__ == "__main__":