1
0
Files
ai-school/dl-exp/exp2/modified/train.py

86 lines
2.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from pathlib import Path
import sys
import typing
import torch
import torchinfo
import ignite.engine
import ignite.metrics
from ignite.engine import Engine, Events
from ignite.handlers.tqdm_logger import ProgressBar
from dataset import MnistDataLoaders
from model import Cnn
import settings
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
import gpu_utils
class Trainer:
    """Core trainer: builds the model, data loaders and ignite engines.

    Intended call order is ``train_model()`` -> ``save_model()`` ->
    ``test_model()``. Batch size, epoch count and the save path are read
    from the project-local ``settings`` module.
    """

    # Device the model and all batches are placed on (chosen by gpu_utils).
    device: torch.device
    # Supplies ``train_loader`` and ``test_loader``.
    data_source: MnistDataLoaders
    # The network being trained.
    model: Cnn
    # ignite engine that runs supervised training over the train loader.
    trainer: Engine
    # ignite engine that computes accuracy/loss over the test loader.
    evaluator: Engine
    # Progress bar attached to the training engine.
    pbar: ProgressBar

    def __init__(self):
        # Create the training device, the model (moved onto that device)
        # and the data loaders.
        self.device = gpu_utils.get_gpu_device()
        self.model = Cnn().to(self.device)
        self.data_source = MnistDataLoaders(batch_size=settings.N_BATCH_SIZE)

        # Print the model structure for a batch of N_BATCH_SIZE single-channel
        # (grayscale) 28x28 images. The device is passed explicitly so the
        # dummy input torchinfo builds sits on the same device as the model,
        # instead of relying on torchinfo's own device inference.
        torchinfo.summary(
            self.model,
            (settings.N_BATCH_SIZE, 1, 28, 28),
            device=self.device,
        )

        # Optimizer and loss function. eps is set explicitly; all other Adam
        # hyperparameters use their defaults.
        optimizer = torch.optim.Adam(self.model.parameters(), eps=1e-7)
        criterion = torch.nn.CrossEntropyLoss()

        # Supervised training engine built from model, optimizer and loss.
        self.trainer = ignite.engine.create_supervised_trainer(
            self.model, optimizer, criterion, self.device)

        # Attach a persistent progress bar that displays the trainer's
        # per-iteration output (the loss value) under the key "loss".
        self.pbar = ProgressBar(persist=True)
        self.pbar.attach(self.trainer, output_transform=lambda loss: {"loss": loss})

        # Metrics for the test evaluator. For multiclass scores ignite's
        # Accuracy takes the argmax over classes, so logits and probabilities
        # both work. The Loss metric re-applies CrossEntropyLoss, which
        # requires raw logits, so the model output must NOT be softmaxed.
        evaluator_metrics = {
            "accuracy": ignite.metrics.Accuracy(device=self.device),
            "loss": ignite.metrics.Loss(criterion, device=self.device),
        }

        # Evaluation engine that computes the metrics above.
        self.evaluator = ignite.engine.create_supervised_evaluator(
            self.model, metrics=evaluator_metrics, device=self.device)

    def train_model(self):
        """Train the model on the training loader for settings.N_EPOCH epochs."""
        self.trainer.run(self.data_source.train_loader, max_epochs=settings.N_EPOCH)

    def save_model(self):
        """Save the model's state_dict (parameters only) to settings.SAVED_MODEL_PATH.

        The parent directory is created first if it does not exist.
        """
        settings.SAVED_MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
        torch.save(self.model.state_dict(), settings.SAVED_MODEL_PATH)
        print(f'Model was saved into: {settings.SAVED_MODEL_PATH}')

    def test_model(self):
        """Evaluate the model on the test loader, print and return the metrics.

        Returns:
            The evaluator's metrics dict, holding "accuracy" and "loss"
            computed over the test set.
        """
        self.evaluator.run(self.data_source.test_loader)
        metrics = self.evaluator.state.metrics
        print(f"Accuracy: {metrics['accuracy']:.4f} Loss: {metrics['loss']:.4f}")
        return metrics
def main():
    """Run the whole pipeline: build the trainer, then train, save and test."""
    pipeline = Trainer()
    # Stages run strictly in this order: the weights that get saved are the
    # freshly trained ones, and testing evaluates that same model.
    for stage in (pipeline.train_model, pipeline.save_model, pipeline.test_model):
        stage()
if __name__ == "__main__":
    # Script entry point: report GPU availability first, then run the
    # train/save/test pipeline.
    gpu_utils.print_gpu_availability()
    main()