1
0
Files
ai-school/dl-exp/exp2/source/train.py

57 lines
2.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from pathlib import Path

import tensorflow as tf
from tensorflow.keras import datasets, layers, models
class CNN:
    """Small convolutional network for 28x28 single-channel digit images.

    The assembled Keras ``Sequential`` model is printed via ``summary()``
    and stored on ``self.model``.
    """

    def __init__(self):
        network = models.Sequential([
            # Layer 1: 32 convolution kernels of size 3x3; 28x28 is the size
            # of the images to be trained on (one channel).
            layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
            layers.MaxPooling2D((2, 2)),
            # Layer 2: 64 convolution kernels of size 3x3.
            layers.Conv2D(64, (3, 3), activation='relu'),
            layers.MaxPooling2D((2, 2)),
            # Layer 3: 64 convolution kernels of size 3x3.
            layers.Conv2D(64, (3, 3), activation='relu'),
            # Classifier head: flatten, one hidden dense layer, then a
            # 10-way softmax output.
            layers.Flatten(),
            layers.Dense(64, activation='relu'),
            layers.Dense(10, activation='softmax'),
        ])
        network.summary()
        self.model = network
class DataSource(object):
    """Load the MNIST dataset and expose it as arrays ready for the CNN.

    After construction the instance carries ``train_images``,
    ``train_labels``, ``test_images`` and ``test_labels``. Images are
    reshaped to ``(N, 28, 28, 1)`` and scaled to the range 0-1; labels are
    passed through unchanged.
    """

    def __init__(self):
        # Where the MNIST archive is stored; if it does not exist it is
        # downloaded automatically by the loader.
        data_path = Path(__file__).resolve().parent.parent / 'datasets' / 'mnist.npz'
        # Make sure the target directory exists so a first-run download has
        # somewhere to write.
        data_path.parent.mkdir(parents=True, exist_ok=True)

        # Pass the path as a plain string, as is done for the checkpoint
        # path elsewhere in this file.
        (train_images, train_labels), (test_images, test_labels) = \
            datasets.mnist.load_data(path=str(data_path))

        self.train_images = self._prepare_images(train_images)
        self.train_labels = train_labels
        self.test_images = self._prepare_images(test_images)
        self.test_labels = test_labels

    @staticmethod
    def _prepare_images(images):
        """Return *images* as ``(N, 28, 28, 1)`` with pixel values in 0-1.

        The sample count is inferred (``-1``) instead of being hard-coded to
        60000 / 10000, so the same code works for any number of 28x28 images.
        """
        return images.reshape((-1, 28, 28, 1)) / 255.0
class Train:
    """Train the CNN on MNIST, checkpoint its weights and report test accuracy."""

    # Number of epochs between two weight checkpoints.
    CHECKPOINT_EVERY_EPOCHS = 5

    def __init__(self):
        self.cnn = CNN()
        self.data = DataSource()

    @staticmethod
    def _steps_per_epoch(num_samples, batch_size):
        """Return the number of batches in one epoch (a partial last batch counts)."""
        # Ceiling division without importing math.
        return -(-num_samples // batch_size)

    def train(self, epochs=5, batch_size=1000):
        """Compile, fit and evaluate the model.

        Args:
            epochs: Number of passes over the training set (default 5).
            batch_size: Number of samples per gradient update (default 1000).

        Weights are written to ``../models/cnn.ckpt`` (relative to this
        file's directory) every ``CHECKPOINT_EVERY_EPOCHS`` epochs, and the
        test accuracy is printed once training has finished.
        """
        check_path = Path(__file__).resolve().parent.parent / 'models' / 'cnn.ckpt'
        # Create the checkpoint directory up front so saving does not depend
        # on it already existing.
        check_path.parent.mkdir(parents=True, exist_ok=True)

        # The `period` keyword of ModelCheckpoint is deprecated in favour of
        # `save_freq`, which is expressed in batches; convert "every N
        # epochs" into the equivalent number of batches.
        save_freq = self.CHECKPOINT_EVERY_EPOCHS * self._steps_per_epoch(
            len(self.data.train_images), batch_size)
        save_model_cb = tf.keras.callbacks.ModelCheckpoint(
            str(check_path), save_weights_only=True, verbose=1,
            save_freq=save_freq)

        self.cnn.model.compile(optimizer='adam',
                               loss='sparse_categorical_crossentropy',
                               metrics=['accuracy'])
        self.cnn.model.fit(self.data.train_images, self.data.train_labels,
                           epochs=epochs, batch_size=batch_size,
                           callbacks=[save_model_cb])

        test_loss, test_acc = self.cnn.model.evaluate(
            self.data.test_images, self.data.test_labels)
        print("准确率: %.4f, 共测试了%d张图片 " % (test_acc, len(self.data.test_labels)))
if __name__ == "__main__":
    # Script entry point: build the model and data source, then run one
    # full train-and-evaluate cycle.
    Train().train()