1
0
Files
ai-school/dl-exp/exp2/source/predict.py

41 lines
1.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from pathlib import Path
import tensorflow as tf
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from train import CNN
class Predict(object):
    """Restore a trained ``CNN`` from a checkpoint and classify digit images."""

    # Side length, in pixels, of the square input that ``predict`` reshapes
    # every image to before handing it to the model.
    INPUT_SIDE = 28

    def __init__(self, ckpt_dir='./ckpt'):
        """Build the network and load its most recent checkpoint.

        Args:
            ckpt_dir: Directory searched for the latest checkpoint. The
                default ``'./ckpt'`` is relative to the current working
                directory.

        Raises:
            FileNotFoundError: If no checkpoint is found in ``ckpt_dir``.
        """
        latest = tf.train.latest_checkpoint(ckpt_dir)
        if latest is None:
            # latest_checkpoint() signals "nothing found" with None; fail
            # here with a clear message instead of inside load_weights().
            raise FileNotFoundError(
                f"No checkpoint found in {str(ckpt_dir)!r}"
            )
        self.cnn = CNN()
        # Restore the network weights.
        self.cnn.model.load_weights(latest)

    def predict(self, image_path):
        """Classify one image file, print the result and display the image.

        Args:
            image_path: Path of the image file to classify.

        Returns:
            int: Index of the largest model output for this image, i.e. the
            predicted digit.
        """
        side = self.INPUT_SIDE

        # Read the picture in greyscale mode. The context manager releases
        # the underlying file as soon as the converted copy exists.
        with Image.open(image_path) as raw:
            grey = raw.convert('L')
        if grey.size != (side, side):
            # The reshape below only works for side x side pixels, so bring
            # any other image to that size first.
            grey = grey.resize((side, side))

        # Scale the 0..255 pixel values to 0..1 and add a channel axis.
        img = np.reshape(grey, (side, side, 1)) / 255.
        # The model is fed the inverted image as a batch of one.
        # NOTE(review): presumably the training data is light-on-dark while
        # the test pictures are dark-on-light -- confirm against train.py.
        x = np.array([1 - img])
        y = self.cnn.model.predict(x)

        # Only one picture was passed in, so y[0] holds its scores.
        # np.argmax() gives the index of the highest score, i.e. the digit.
        digit = int(np.argmax(y[0]))
        print(image_path)
        print(' -> Predict digit', digit)

        plt.figure("Image")  # name of the figure window
        # Show the (non-inverted) picture as a 2-D array, the shape imshow
        # documents for scalar image data.
        plt.imshow(img[:, :, 0])
        plt.axis('on')  # use 'off' to hide the axes
        plt.title(digit)  # figure title: the predicted digit
        plt.show()
        return digit
if __name__ == "__main__":
    # The sample digits live in the ``test_images`` folder, one level above
    # the directory that contains this script.
    predictor = Predict()
    sample_dir = Path(__file__).resolve().parent.parent / 'test_images'
    for file_name in ('0.png', '1.png', '4.png'):
        predictor.predict(sample_dir / file_name)