# Copied from a Git web viewer: ai-school/exp2/modified/train.py
# Last modified 2025-12-02 23:07:27 +08:00 · 190 lines · 7.1 KiB · Python

# NOTE: the web viewer flagged this file as containing "ambiguous Unicode characters"
# (characters that might be confused with others). This is presumably caused by the
# Chinese-language comments and strings, and is intentional.
from pathlib import Path
import sys
import typing
import torch
import torchinfo
import ignite.engine
import ignite.metrics
from ignite.engine import Engine, Events
from ignite.metrics import Accuracy, Loss
from ignite.handlers.tqdm_logger import ProgressBar
from dataset import MnistDataLoaders
from model import Cnn
import settings
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
import gpu_utils
# class Trainer:
# device: torch.device
# data_source: MnistDataLoaders
# model: Cnn
# def __init__(self):
# self.device = gpu_utils.get_gpu_device()
# self.data_source = MnistDataLoaders(Trainer.N_BATCH_SIZE)
# self.model = Cnn().to(self.device)
# # 展示模型结构。批次为指定批次数量通道只有一个灰度通道大小28x28。
# torchinfo.summary(self.model, (Trainer.N_BATCH_SIZE, 1, 28, 28))
# def train(self):
# optimizer = torch.optim.Adam(self.model.parameters(), eps=1e-7)
# # optimizer = torch.optim.AdamW(
# # self.model.parameters(),
# # lr=0.001, # 两者默认学习率都是 0.001
# # betas=(0.9, 0.999), # 两者默认值相同
# # eps=1e-07, # 【关键】匹配 TensorFlow 的默认 epsilon
# # weight_decay=0.0, # 两者默认都是 0
# # amsgrad=False # 两者默认都是 False
# # )
# loss_func = torch.nn.CrossEntropyLoss()
# for epoch in range(Trainer.N_EPOCH):
# self.model.train()
# batch_images: torch.Tensor
# batch_labels: torch.Tensor
# for batch_index, (batch_images, batch_labels) in enumerate(self.data_source.train_loader):
# gpu_images = batch_images.to(self.device)
# gpu_labels = batch_labels.to(self.device)
# prediction: torch.Tensor = self.model(gpu_images)
# loss: torch.Tensor = loss_func(prediction, gpu_labels)
# optimizer.zero_grad()
# loss.backward()
# optimizer.step()
# if batch_index % 100 == 0:
# literal_loss = loss.item()
# print(f'Epoch: {epoch+1}, Batch: {batch_index}, Loss: {literal_loss:.4f}')
# def save(self):
# file_dir_path = Path(__file__).resolve().parent.parent / 'models'
# file_dir_path.mkdir(parents=True, exist_ok=True)
# file_path = file_dir_path / 'cnn.pth'
# torch.save(self.model.state_dict(), file_path)
# print(f'模型已保存至:{file_path}')
# def test(self):
# self.model.eval()
# correct_sum = 0
# total_sum = 0
# with torch.no_grad():
# batch_images: torch.Tensor
# batch_labels: torch.Tensor
# for batch_images, batch_labels in self.data_source.test_loader:
# gpu_images = batch_images.to(self.device)
# gpu_labels = batch_labels.to(self.device)
# possibilities: torch.Tensor = self.model(gpu_images)
# # 输出出来是10个数字各自的可能性所以要选取最高可能性的那个对比
# # 在dim=1上找最大的那个就选那个。dim=0是批次所以不管他。
# _, prediction = possibilities.max(1)
# # 返回标签的个数作为这一批的总个数
# total_sum += gpu_labels.size(0)
# correct_sum += prediction.eq(gpu_labels).sum()
# test_acc = 100. * correct_sum / total_sum
# print(f"准确率: {test_acc:.4f}%,共测试了{total_sum}张图片")
# def main():
# trainer = Trainer()
# trainer.train()
# trainer.save()
# trainer.test()
class Trainer:
    """Core MNIST trainer built on PyTorch-Ignite.

    Wires together the training device, the ``Cnn`` model, the MNIST data
    loaders, an Adam optimizer with cross-entropy loss, a supervised training
    engine (with a tqdm progress bar showing loss and running accuracy), and a
    supervised evaluation engine for the test split. The model's parameters
    are saved automatically once training completes.
    """

    device: torch.device
    data_source: MnistDataLoaders
    model: Cnn
    trainer: Engine
    trainer_accuracy: Accuracy
    evaluator: Engine
    pbar: ProgressBar

    def __init__(self):
        # Training device, model and data loaders.
        self.device = gpu_utils.get_gpu_device()
        self.model = Cnn().to(self.device)
        self.data_source = MnistDataLoaders(batch_size=settings.N_BATCH_SIZE)
        # Print the model structure for one batch of single-channel (grayscale)
        # 28x28 images. The device is passed explicitly so torchinfo builds its
        # dummy input on the model's device instead of choosing one itself.
        torchinfo.summary(
            self.model,
            (settings.N_BATCH_SIZE, 1, 28, 28),
            device=self.device,
        )

        # Optimizer and loss. eps=1e-7 (PyTorch's default is 1e-8) matches
        # TensorFlow's default Adam epsilon.
        optimizer = torch.optim.Adam(self.model.parameters(), eps=1e-7)
        criterion = torch.nn.CrossEntropyLoss()

        # Training engine. Each iteration outputs a dict instead of the default
        # bare loss value, because the accuracy metric and the progress bar
        # below also need the targets and predictions. y_pred is detached so
        # engine.state.output does not keep autograd history alive until the
        # next iteration.
        self.trainer = ignite.engine.create_supervised_trainer(
            self.model, optimizer, criterion, self.device,
            output_transform=lambda x, y, y_pred, loss: {
                'loss': loss.item(),
                'y': y,
                'y_pred': y_pred.detach(),
            },
        )

        # Training accuracy, accumulated over the current epoch. attach()
        # registers a reset on EPOCH_STARTED and an update on every
        # ITERATION_COMPLETED, and the metric resets itself on construction,
        # so no manual reset() or extra EPOCH_STARTED handler is needed.
        self.trainer_accuracy = Accuracy(
            device=self.device,
            # Convert the engine output into the (y_pred, y) pair Accuracy expects.
            output_transform=lambda o: (o['y_pred'], o['y']),
        )
        self.trainer_accuracy.attach(self.trainer, 'accuracy')
        # attach() only writes state.metrics['accuracy'] at EPOCH_COMPLETED, so
        # without this the progress bar shows no accuracy during the first
        # epoch and a stale value from the previous epoch afterwards. This
        # handler must be registered after attach(), so the metric has already
        # been updated, and before the progress bar, so the bar reads the fresh
        # value. Ignite fires the handlers of one event in registration order.
        self.trainer.add_event_handler(
            Events.ITERATION_COMPLETED, self._publish_running_accuracy)

        # Progress bar: batch loss from the engine output plus running accuracy.
        self.pbar = ProgressBar(persist=True)
        self.pbar.attach(self.trainer,
                         metric_names=['accuracy'],
                         output_transform=lambda o: {"loss": o['loss']})

        # Save the model once training has finished.
        self.trainer.add_event_handler(
            Events.COMPLETED,
            lambda: self.save_model()
        )

        # Test-set metrics and the evaluation engine that computes them.
        evaluator_metrics = {
            "accuracy": ignite.metrics.Accuracy(device=self.device),
            "loss": ignite.metrics.Loss(criterion, device=self.device),
        }
        self.evaluator = ignite.engine.create_supervised_evaluator(
            self.model, metrics=evaluator_metrics, device=self.device)

    def _publish_running_accuracy(self, engine: Engine) -> None:
        """Store the epoch-to-date training accuracy in ``engine.state.metrics``."""
        engine.state.metrics['accuracy'] = self.trainer_accuracy.compute()

    def run(self):
        """Train for ``settings.N_EPOCH`` epochs, then evaluate on the test split.

        Prints the test accuracy (a fraction in [0, 1]) and the mean test loss.
        """
        self.trainer.run(self.data_source.train_loader, max_epochs=settings.N_EPOCH)
        self.evaluator.run(self.data_source.test_loader)
        metrics = self.evaluator.state.metrics
        print(f"Test Done. Accuracy: {metrics['accuracy']:.4f} Loss: {metrics['loss']:.4f}")

    def save_model(self):
        """Save only the model parameters (``state_dict``) to ``settings.SAVED_MODEL_PATH``."""
        # Make sure the destination directory exists.
        settings.SAVED_MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
        torch.save(self.model.state_dict(), settings.SAVED_MODEL_PATH)
        print(f'Model was saved into: {settings.SAVED_MODEL_PATH}')
def main():
    """Script entry point: build a Trainer, train it, then evaluate on the test set."""
    Trainer().run()
if __name__ == "__main__":
    # Presumably prints which GPU/accelerator is available (helper from the
    # shared gpu_utils module) before training starts.
    gpu_utils.print_gpu_availability()
    main()