finish exp3 predict code
This commit is contained in:
@@ -9,6 +9,7 @@ from ignite.engine import Engine, Events
|
||||
from ignite.handlers.tqdm_logger import ProgressBar
|
||||
from dataset import PoetryDataLoader
|
||||
from model import Rnn
|
||||
from predict import generate_random_poetry
|
||||
import settings
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
|
||||
@@ -44,6 +45,11 @@ class Trainer:
|
||||
# 将训练器关联到进度条
|
||||
self.pbar = ProgressBar(persist=True)
|
||||
self.pbar.attach(self.trainer, output_transform=lambda loss: {"loss": loss})
|
||||
# 每次epoch后,作诗一首看看结果
|
||||
self.trainer.add_event_handler(
|
||||
Events.EPOCH_COMPLETED,
|
||||
lambda: generate_random_poetry(self.data_loader.get_tokenizer(), self.model, )
|
||||
)
|
||||
|
||||
def train_model(self):
|
||||
# 训练模型
|
||||
|
||||
Reference in New Issue
Block a user