1
0
Files
ai-school/exp2/modified/train.py

193 lines
7.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from pathlib import Path
import sys
import typing
import torch
import torchinfo
import ignite.engine
import ignite.metrics
from ignite.engine import Engine, Events
from ignite.handlers.tqdm_logger import ProgressBar
from mnist import CNN, MnistDataSource
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
import gpu_utils
class Trainer:
    """Train, save and evaluate the CNN model on the MNIST data source.

    Typical use is ``train()`` followed by ``save()`` and ``test()``.
    """

    # Number of full passes over the training loader.
    N_EPOCH: typing.ClassVar[int] = 5
    # Batch size passed to the data source and used for the model summary.
    N_BATCH_SIZE: typing.ClassVar[int] = 1000

    # Device obtained from gpu_utils; the model and every batch are moved to it.
    device: torch.device
    # Provides the ``train_loader`` and ``test_loader`` iterables used below.
    data_source: MnistDataSource
    # The network being trained.
    model: CNN

    def __init__(self):
        """Pick the device, build the data source and model, and print a model summary."""
        self.device = gpu_utils.get_gpu_device()
        self.data_source = MnistDataSource(Trainer.N_BATCH_SIZE)
        self.model = CNN().to(self.device)
        # Show the model structure. Input size: the configured batch size,
        # a single grayscale channel, 28x28 pixels.
        torchinfo.summary(self.model, (Trainer.N_BATCH_SIZE, 1, 28, 28))

    def train(self):
        """Train the model for ``N_EPOCH`` epochs with Adam and cross-entropy loss.

        The loss of every 100th batch is printed as a progress indicator.
        """
        # eps=1e-7 is deliberate: it matches TensorFlow's default Adam epsilon.
        # All other hyperparameters are left at their defaults.
        optimizer = torch.optim.Adam(self.model.parameters(), eps=1e-7)
        loss_func = torch.nn.CrossEntropyLoss()

        for epoch in range(Trainer.N_EPOCH):
            # Put the model (back) into training mode at the start of each epoch.
            self.model.train()

            batch_images: torch.Tensor
            batch_labels: torch.Tensor
            for batch_index, (batch_images, batch_labels) in enumerate(self.data_source.train_loader):
                gpu_images = batch_images.to(self.device)
                gpu_labels = batch_labels.to(self.device)

                prediction: torch.Tensor = self.model(gpu_images)
                loss: torch.Tensor = loss_func(prediction, gpu_labels)

                # Clear gradients from the previous step before back-propagating.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if batch_index % 100 == 0:
                    # .item() is only called when logging, not on every batch.
                    literal_loss = loss.item()
                    print(f'Epoch: {epoch+1}, Batch: {batch_index}, Loss: {literal_loss:.4f}')

    def save(self):
        """Save the model's state dict to ``models/cnn.pth`` next to this file's parent directory."""
        file_dir_path = Path(__file__).resolve().parent.parent / 'models'
        file_dir_path.mkdir(parents=True, exist_ok=True)
        file_path = file_dir_path / 'cnn.pth'
        torch.save(self.model.state_dict(), file_path)
        print(f'模型已保存至:{file_path}')

    def test(self):
        """Evaluate the model on the test loader and print the accuracy.

        Returns:
            The accuracy as a percentage (Python float). ``0.0`` is reported
            when the test loader yields no samples, instead of dividing by zero.
        """
        self.model.eval()
        correct_sum = 0
        total_sum = 0

        with torch.no_grad():
            for batch_images, batch_labels in self.data_source.test_loader:
                gpu_images = batch_images.to(self.device)
                gpu_labels = batch_labels.to(self.device)

                logits: torch.Tensor = self.model(gpu_images)
                # The model outputs one score per digit class; pick the
                # highest-scoring class along dim=1 (dim=0 is the batch).
                prediction = logits.argmax(dim=1)

                # The number of labels is the number of samples in this batch.
                total_sum += gpu_labels.size(0)
                # Convert to a Python int so the counter never becomes a
                # device tensor.
                correct_sum += int(prediction.eq(gpu_labels).sum().item())

        test_acc = 100.0 * correct_sum / total_sum if total_sum else 0.0
        print(f"准确率: {test_acc:.4f}%,共测试了{total_sum}张图片")
        return test_acc
def main():
    """Run the whole pipeline: train the model, save its weights, then evaluate it."""
    pipeline = Trainer()
    pipeline.train()
    pipeline.save()
    pipeline.test()
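# NOTE: everything from here down to the `__main__` guard is a disabled,
# alternative implementation of Trainer built on pytorch-ignite engines.
# It is commented out and never executed.
# N_EPOCH: int = 5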
# N_BATCH_SIZE: int = 1000
# N_LOG_INTERVAL: int = 10
# class Trainer:
# device: torch.device
# data_source: MnistDataSource
# model: CNN
# trainer: Engine
# train_evaluator: Engine
# test_evaluator: Engine
# def __init__(self):
# self.device = gpu_utils.get_gpu_device()
# self.model = CNN().to(self.device)
# self.data_source = MnistDataSource(batch_size=N_BATCH_SIZE)
# # Show the model structure. Input size: the configured batch size, a single grayscale channel, 28x28 pixels.
# torchinfo.summary(self.model, (N_BATCH_SIZE, 1, 28, 28))
# #optimizer = torch.optim.Adam(self.model.parameters(), eps=1e-7)
# optimizer = torch.optim.AdamW(
# self.model.parameters(),
# lr=0.001, # both use a default learning rate of 0.001
# betas=(0.9, 0.999), # same default value in both
# eps=1e-07, # [KEY] matches TensorFlow's default epsilon
# weight_decay=0.0, # both default to 0
# amsgrad=False # both default to False
# )
# criterion = torch.nn.CrossEntropyLoss()
# self.trainer = ignite.engine.create_supervised_trainer(
# self.model, optimizer, criterion, self.device
# )
# eval_metrics = {
# "accuracy": ignite.metrics.Accuracy(device=self.device),
# "loss": ignite.metrics.Loss(criterion, device=self.device)
# }
# self.train_evaluator = ignite.engine.create_supervised_evaluator(
# self.model, metrics=eval_metrics, device=self.device)
# self.test_evaluator = ignite.engine.create_supervised_evaluator(
# self.model, metrics=eval_metrics, device=self.device)
# self.trainer.add_event_handler(
# Events.ITERATION_COMPLETED(every=N_LOG_INTERVAL),
# lambda engine: self.log_intrain_loss(engine)
# )
# self.trainer.add_event_handler(
# Events.EPOCH_COMPLETED,
# lambda trainer: self.log_train_results(trainer)
# )
# self.trainer.add_event_handler(
# Events.COMPLETED,
# lambda _: self.log_test_results()
# )
# self.trainer.add_event_handler(
# Events.COMPLETED,
# lambda _: self.save_model()
# )
# progressbar = ProgressBar()
# progressbar.attach(self.trainer)
# def run(self):
# self.trainer.run(self.data_source.train_loader, max_epochs=N_EPOCH)
# def log_intrain_loss(self, engine: Engine):
# print(f"Epoch: {engine.state.epoch}, Loss: {engine.state.output:.4f}\r", end="")
# def log_train_results(self, trainer: Engine):
# self.train_evaluator.run(self.data_source.train_loader)
# metrics = self.train_evaluator.state.metrics
# print()
# print(f"Training - Epoch: {trainer.state.epoch}, Avg Accuracy: {metrics['accuracy']:.4f}, Avg Loss: {metrics['loss']:.4f}")
# def log_test_results(self):
# self.test_evaluator.run(self.data_source.test_loader)
# metrics = self.test_evaluator.state.metrics
# print(f"Test - Avg Accuracy: {metrics['accuracy']:.4f} Avg Loss: {metrics['loss']:.4f}")
# def save_model(self):
# file_dir_path = Path(__file__).resolve().parent.parent / 'models'
# file_dir_path.mkdir(parents=True, exist_ok=True)
# file_path = file_dir_path / 'cnn.pth'
# torch.save(self.model.state_dict(), file_path)
# print(f'Model was saved into: {file_path}')
# def main():
# trainer = Trainer()
# trainer.run()
# Script entry point: only runs when the file is executed directly, not when imported.
if __name__ == "__main__":
    # Call the gpu_utils helper first (presumably reports whether a GPU is usable —
    # confirm in gpu_utils), then run the train/save/test pipeline.
    gpu_utils.print_gpu_availability()
    main()