"""Training script for the RNN poetry model: train, sample each epoch, save weights."""
from pathlib import Path
import sys
import typing

import torch
import torchinfo
import ignite.engine
import ignite.metrics
from ignite.engine import Engine, Events
from ignite.handlers.tqdm_logger import ProgressBar

from dataset import PoetryDataLoader
from model import Rnn
from predict import generate_random_poetry
import settings

# Make the directory two levels above this script's folder importable so the
# shared `gpu_utils` module can be found. The import below must come after it.
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
import gpu_utils  # noqa: E402


class Trainer:
    """Core trainer: wires the data loader, RNN model and ignite engine together."""

    device: torch.device
    data_loader: PoetryDataLoader
    model: Rnn
    trainer: Engine
    pbar: ProgressBar

    def __init__(self):
        # Create the training device, the data loader and the model.
        self.device = gpu_utils.get_gpu_device()
        self.data_loader = PoetryDataLoader(batch_size=settings.N_BATCH_SIZE)
        self.model = Rnn(self.data_loader.get_vocab_size()).to(self.device)

        # Print the model structure for an int32 token-id input of shape
        # (N_BATCH_SIZE, POETRY_MAX_LEN).
        torchinfo.summary(
            self.model,
            (settings.N_BATCH_SIZE, settings.POETRY_MAX_LEN),
            dtypes=[torch.int32],
        )

        # Optimizer and loss function.
        optimizer = torch.optim.Adam(self.model.parameters(), eps=1e-7)
        criterion = torch.nn.CrossEntropyLoss()

        # Build the supervised training engine.
        self.trainer = ignite.engine.create_supervised_trainer(
            self.model, optimizer, criterion, self.device)

        # Attach a progress bar that shows the per-iteration loss.
        self.pbar = ProgressBar(persist=True)
        self.pbar.attach(self.trainer, output_transform=lambda loss: {"loss": loss})

        # After every epoch, write a poem to eyeball training progress.
        self.trainer.add_event_handler(
            Events.EPOCH_COMPLETED, self._generate_sample_poetry)

    def _generate_sample_poetry(self):
        """Generate one random poem with the current weights (epoch-end hook).

        Sampling runs in eval mode and without autograd, so no gradient graph
        is built and train-only layer behaviour (e.g. dropout, if the model
        has any) is disabled. The previous train/eval mode is restored
        afterwards so training continues exactly as before.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                generate_random_poetry(self.data_loader.get_tokenizer(), self.model)
        finally:
            self.model.train(was_training)

    def train_model(self):
        """Run training for `settings.N_EPOCH` epochs over the poetry data loader."""
        self.trainer.run(self.data_loader.loader, max_epochs=settings.N_EPOCH)

    def save_model(self):
        """Save only the model parameters (state_dict) to `settings.SAVED_MODEL_PATH`."""
        # Make sure the folder that will hold the model file exists.
        settings.SAVED_MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
        # Save the parameters only, not the whole pickled module.
        torch.save(self.model.state_dict(), settings.SAVED_MODEL_PATH)
        print(f'Model was saved into: {settings.SAVED_MODEL_PATH}')


def main():
    """Build the trainer, train the model, then save its weights."""
    trainer = Trainer()
    trainer.train_model()
    trainer.save_model()


if __name__ == "__main__":
    gpu_utils.print_gpu_availability()
    main()