from pathlib import Path import tensorflow as tf from PIL import Image import numpy as np import matplotlib.pyplot as plt from train import CNN class Predict(object): def __init__(self): latest = tf.train.latest_checkpoint('./ckpt') self.cnn = CNN() # 恢复网络权重 self.cnn.model.load_weights(latest) def predict(self, image_path): # 以黑白方式读取图片 img = Image.open(image_path).convert('L') img = np.reshape(img, (28, 28, 1)) / 255. x = np.array([1 - img]) y = self.cnn.model.predict(x) # 因为x只传入了一张图片,取y[0]即可 # np.argmax()取得最大值的下标,即代表的数字 print(image_path) # print(y[0]) print(' -> Predict digit', np.argmax(y[0])) plt.figure("Image") # 图像窗口名称 plt.imshow(img) plt.axis('on') # 关掉坐标轴为 off plt.title(np.argmax(y[0])) # 图像题目 # 必须有这个,要不然无法显示 plt.show() if __name__ == "__main__": app = Predict() images_dir = Path(__file__).resolve().parent.parent / 'test_images' app.predict(images_dir / '0.png') app.predict(images_dir / '1.png') app.predict(images_dir / '4.png')