1
0

finish exp3 predict code

This commit is contained in:
2025-12-06 19:56:55 +08:00
parent 45b60b269f
commit ee18246d51
4 changed files with 138 additions and 2 deletions

View File

@@ -9,6 +9,7 @@ from ignite.engine import Engine, Events
from ignite.handlers.tqdm_logger import ProgressBar
from dataset import PoetryDataLoader
from model import Rnn
from predict import generate_random_poetry
import settings
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
@@ -44,6 +45,11 @@ class Trainer:
# 将训练器关联到进度条
self.pbar = ProgressBar(persist=True)
self.pbar.attach(self.trainer, output_transform=lambda loss: {"loss": loss})
# 每次epoch后作诗一首看看结果
self.trainer.add_event_handler(
Events.EPOCH_COMPLETED,
lambda: generate_random_poetry(self.data_loader.get_tokenizer(), self.model, )
)
def train_model(self):
# 训练模型