1
0
Files
ai-school/dl-exp/exp3/modified/train.py

80 lines
2.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from pathlib import Path
import sys
import typing
import torch
import torchinfo
import ignite.engine
import ignite.metrics
from ignite.engine import Engine, Events
from ignite.handlers.tqdm_logger import ProgressBar
from dataset import PoetryDataLoader
from model import Rnn
from predict import generate_random_poetry
import settings
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
import gpu_utils
class Trainer:
    """Core trainer: wires the poetry data loader, the RNN model and an ignite training loop together."""

    device: torch.device
    data_loader: PoetryDataLoader
    model: Rnn
    trainer: Engine
    pbar: ProgressBar

    def __init__(self):
        # Create the training device, the data loader and the model (moved onto that device).
        self.device = gpu_utils.get_gpu_device()
        self.data_loader = PoetryDataLoader(batch_size=settings.N_BATCH_SIZE)
        self.model = Rnn(self.data_loader.get_vocab_size()).to(self.device)

        # Show the model structure for an input of shape (batch size, max poem length);
        # the input must be int32 token ids.
        # Pass ``device`` explicitly so torchinfo does not fall back to its own device
        # choice and move the model away from the device the trainer uses.
        torchinfo.summary(
            self.model,
            (settings.N_BATCH_SIZE, settings.POETRY_MAX_LEN),
            dtypes=[torch.int32],
            device=self.device,
        )

        # Optimizer and loss function.
        optimizer = torch.optim.Adam(self.model.parameters(), eps=1e-7)
        criterion = torch.nn.CrossEntropyLoss()

        # Create the supervised trainer. PyTorch's CrossEntropyLoss expects the class
        # scores on dim=1, so the model output is permuted before it reaches the loss.
        self.trainer = ignite.engine.create_supervised_trainer(
            self.model, optimizer, criterion, self.device,
            model_transform=self.__adjust_for_loss,
        )

        # Attach a persistent progress bar that displays the loss.
        self.pbar = ProgressBar(persist=True)
        self.pbar.attach(self.trainer, output_transform=lambda loss: {"loss": loss})

        # After every epoch, write one poem to eyeball how training is going.
        self.trainer.add_event_handler(Events.EPOCH_COMPLETED, self._print_sample_poetry)

    def __adjust_for_loss(self, output: torch.Tensor) -> torch.Tensor:
        """Swap the last two axes of a 3-D model output so class scores sit on dim=1.

        NOTE(review): assumes the model emits (batch, seq_len, vocab) logits — confirm in model.py.
        """
        return output.permute(0, 2, 1)

    def _print_sample_poetry(self) -> None:
        """Generate one random poem with the current model and print it."""
        # Sampling needs no gradients; disabling autograd avoids building a graph.
        # This is harmless if generate_random_poetry already disables it internally.
        with torch.no_grad():
            poem = generate_random_poetry(self.data_loader.get_tokenizer(), self.model, self.device)
        print(poem)

    def train_model(self):
        """Run the ignite training loop over the poetry data for ``settings.N_EPOCH`` epochs."""
        self.trainer.run(self.data_loader.loader, max_epochs=settings.N_EPOCH)

    def save_model(self):
        """Save only the model parameters (``state_dict``) to ``settings.SAVED_MODEL_PATH``."""
        # Make sure the folder that will hold the model exists.
        settings.SAVED_MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
        torch.save(self.model.state_dict(), settings.SAVED_MODEL_PATH)
        print(f'Model was saved into: {settings.SAVED_MODEL_PATH}')
def main():
    """Build a Trainer, train the poetry model, then persist its parameters."""
    session = Trainer()
    session.train_model()
    session.save_model()
if __name__ == "__main__":
    # Print GPU availability information (via gpu_utils) before the training run starts.
    gpu_utils.print_gpu_availability()
    main()