# Train a character-level LSTM model that generates Chinese poetry.
import tensorflow as tf

from dataset import PoetryDataGenerator, tokenizer, poetry
import settings
import utils
# Character-level language model: embedding -> two stacked LSTMs -> a
# per-timestep softmax over the whole vocabulary.
vocab_size = tokenizer.vocab_size

model = tf.keras.Sequential([
    # Variable-length sequences of token ids.
    tf.keras.layers.Input((None,)),
    tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=128),
    tf.keras.layers.LSTM(128, dropout=0.5, return_sequences=True),
    tf.keras.layers.LSTM(128, dropout=0.5, return_sequences=True),
    # One probability distribution over the vocabulary per time step.
    tf.keras.layers.TimeDistributed(
        tf.keras.layers.Dense(vocab_size, activation='softmax')),
])

model.summary()

# Targets are expected one-hot per time step (categorical, not sparse,
# cross-entropy) — presumably produced by PoetryDataGenerator; verify there.
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss=tf.keras.losses.categorical_crossentropy)
class Evaluate(tf.keras.callbacks.Callback):
    """Checkpoint the model whenever training loss improves, and print a
    few sample poems after every epoch to eyeball progress."""

    def __init__(self):
        super().__init__()
        # Lowest training loss seen so far; starts effectively at +inf.
        self.lowest = 1e10

    def on_epoch_end(self, epoch, logs=None):
        """Save the best model and emit sample poems.

        Keras may pass ``logs=None``; guard against that and a missing
        ``'loss'`` key instead of raising.
        """
        logs = logs or {}
        loss = logs.get('loss')
        if loss is not None and loss <= self.lowest:
            self.lowest = loss
            # Use self.model (set by Keras on every callback) rather than
            # the module-level global, so the callback stays reusable.
            self.model.save(settings.BEST_MODEL_PATH)
        # Show a few generated poems regardless of whether loss improved.
        print()
        for _ in range(settings.SHOW_NUM):
            print(utils.generate_random_poetry(tokenizer, self.model))
# Iterate the corpus in a fixed order each epoch (random=False).
data_generator = PoetryDataGenerator(poetry, random=False)

# Model.fit_generator was deprecated in TF 2.1 and removed later;
# Model.fit accepts Python generators directly with the same kwargs.
model.fit(data_generator.for_fit(),
          steps_per_epoch=data_generator.steps,
          epochs=settings.TRAIN_EPOCHS,
          callbacks=[Evaluate()])