update for change of exp2 and add exp3
This commit is contained in:
35
exp3/source/train.py
Normal file
35
exp3/source/train.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import tensorflow as tf
|
||||
from dataset import PoetryDataGenerator, tokenizer, poetry
|
||||
import settings
|
||||
import utils
|
||||
|
||||
model = tf.keras.Sequential([
|
||||
tf.keras.layers.Input((None,)),
|
||||
tf.keras.layers.Embedding(input_dim=tokenizer.vocab_size, output_dim=128),
|
||||
tf.keras.layers.LSTM(128, dropout=0.5, return_sequences=True),
|
||||
tf.keras.layers.LSTM(128, dropout=0.5, return_sequences=True),
|
||||
tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(tokenizer.vocab_size, activation='softmax')),
|
||||
|
||||
])
|
||||
model.summary()
|
||||
model.compile(optimizer=tf.keras.optimizers.Adam(), loss=tf.keras.losses.categorical_crossentropy)
|
||||
|
||||
class Evaluate(tf.keras.callbacks.Callback):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.lowest = 1e10
|
||||
|
||||
def on_epoch_end(self, epoch, logs=None):
|
||||
if logs['loss'] <= self.lowest:
|
||||
self.lowest = logs['loss']
|
||||
model.save(settings.BEST_MODEL_PATH)
|
||||
print()
|
||||
for i in range(settings.SHOW_NUM):
|
||||
print(utils.generate_random_poetry(tokenizer, model))
|
||||
|
||||
data_generator = PoetryDataGenerator(poetry, random=False)
|
||||
model.fit_generator(data_generator.for_fit(),
|
||||
steps_per_epoch=data_generator.steps,
|
||||
epochs=settings.TRAIN_EPOCHS,
|
||||
callbacks=[Evaluate()])
|
||||
Reference in New Issue
Block a user