"""Train the poetry-generation RNN with PyTorch-Ignite and save its weights."""

from pathlib import Path
import sys
import typing

import torch
import torchinfo
import ignite.engine
import ignite.metrics
from ignite.engine import Engine, Events
from ignite.handlers.tqdm_logger import ProgressBar

from dataset import PoetryDataLoader
from model import Rnn
from predict import generate_random_poetry
import settings

# The directory three levels above this file is appended to sys.path before
# gpu_utils is imported (presumably that is where gpu_utils lives), so the
# order of these two statements must not be swapped.
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
import gpu_utils


class Trainer:
    """Core trainer: owns the device, data loader, model and Ignite engine."""

    device: torch.device
    data_loader: PoetryDataLoader
    model: Rnn
    trainer: Engine
    pbar: ProgressBar

    def __init__(self):
        """Build the model and data pipeline and wire up the training engine."""
        # Training device, data loader, and the model moved onto that device.
        self.device = gpu_utils.get_gpu_device()
        self.data_loader = PoetryDataLoader(batch_size=settings.N_BATCH_SIZE)
        vocab_size = self.data_loader.get_vocab_size()
        self.model = Rnn(vocab_size).to(self.device)

        # Print the model structure. The summary input is shaped
        # (batch size, maximum poem length) and is fed as int32.
        summary_input_shape = (settings.N_BATCH_SIZE, settings.POETRY_MAX_LEN)
        torchinfo.summary(
            self.model,
            summary_input_shape,
            dtypes=[torch.int32],
        )

        # Optimizer and loss function.
        adam = torch.optim.Adam(self.model.parameters(), eps=1e-7)
        cross_entropy = torch.nn.CrossEntropyLoss()

        # Build the Ignite training engine. PyTorch's cross-entropy loss
        # expects the class scores on dim=1, so the model output has its
        # dimensions rearranged before it reaches the loss.
        self.trainer = ignite.engine.create_supervised_trainer(
            self.model,
            adam,
            cross_entropy,
            self.device,
            model_transform=self.__adjust_for_loss,
        )

        # Attach a progress bar that displays the engine output as "loss".
        def loss_as_metrics(loss):
            return {"loss": loss}

        self.pbar = ProgressBar(persist=True)
        self.pbar.attach(self.trainer, output_transform=loss_as_metrics)

        # After every epoch, compose one poem to eyeball the current quality.
        def print_sample_poem():
            poem = generate_random_poetry(
                self.data_loader.get_tokenizer(),
                self.model,
                self.device,
            )
            print(poem)

        self.trainer.add_event_handler(Events.EPOCH_COMPLETED, print_sample_poem)

    def __adjust_for_loss(self, output: torch.Tensor) -> torch.Tensor:
        """Return ``output`` with its last two of three dimensions swapped."""
        return output.permute(0, 2, 1)

    def train_model(self):
        """Run the training engine over the data loader for N_EPOCH epochs."""
        batches = self.data_loader.loader
        self.trainer.run(batches, max_epochs=settings.N_EPOCH)

    def save_model(self):
        """Write the model's state dict to ``settings.SAVED_MODEL_PATH``."""
        target = settings.SAVED_MODEL_PATH

        # Make sure the destination folder exists.
        target.parent.mkdir(parents=True, exist_ok=True)

        # Only the parameters (state dict) are saved, not the whole module.
        weights = self.model.state_dict()
        torch.save(weights, target)
        print(f'Model was saved into: {target}')


def main():
    """Train the model from scratch and persist its parameters."""
    poetry_trainer = Trainer()
    poetry_trainer.train_model()
    poetry_trainer.save_model()


if __name__ == "__main__":
    gpu_utils.print_gpu_availability()
    main()