41 lines
1.3 KiB
Python
41 lines
1.3 KiB
Python
from pathlib import Path
|
||
import tensorflow as tf
|
||
from PIL import Image
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
|
||
from train import CNN
|
||
|
||
class Predict(object):
|
||
def __init__(self):
|
||
latest = tf.train.latest_checkpoint('./ckpt')
|
||
self.cnn = CNN()
|
||
# 恢复网络权重
|
||
self.cnn.model.load_weights(latest)
|
||
|
||
def predict(self, image_path):
|
||
# 以黑白方式读取图片
|
||
img = Image.open(image_path).convert('L')
|
||
img = np.reshape(img, (28, 28, 1)) / 255.
|
||
x = np.array([1 - img])
|
||
y = self.cnn.model.predict(x)
|
||
|
||
# 因为x只传入了一张图片,取y[0]即可
|
||
# np.argmax()取得最大值的下标,即代表的数字
|
||
print(image_path)
|
||
# print(y[0])
|
||
print(' -> Predict digit', np.argmax(y[0]))
|
||
plt.figure("Image") # 图像窗口名称
|
||
plt.imshow(img)
|
||
plt.axis('on') # 关掉坐标轴为 off
|
||
plt.title(np.argmax(y[0])) # 图像题目 # 必须有这个,要不然无法显示
|
||
plt.show()
|
||
|
||
|
||
if __name__ == "__main__":
|
||
app = Predict()
|
||
images_dir = Path(__file__).resolve().parent.parent / 'test_images'
|
||
app.predict(images_dir / '0.png')
|
||
app.predict(images_dir / '1.png')
|
||
app.predict(images_dir / '4.png')
|