1
0
Files
ai-school/exp2/source/predict.py

39 lines
1.2 KiB
Python
Raw Normal View History

2025-11-24 14:20:38 +08:00
import tensorflow as tf
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from train import CNN
class Predict(object):
def __init__(self):
latest = tf.train.latest_checkpoint('./ckpt')
self.cnn = CNN()
# 恢复网络权重
self.cnn.model.load_weights(latest)
def predict(self, image_path):
# 以黑白方式读取图片
img = Image.open(image_path).convert('L')
img = np.reshape(img, (28, 28, 1)) / 255.
x = np.array([1 - img])
y = self.cnn.model.predict(x)
# 因为x只传入了一张图片取y[0]即可
# np.argmax()取得最大值的下标,即代表的数字
print(image_path)
# print(y[0])
print(' -> Predict digit', np.argmax(y[0]))
plt.figure("Image") # 图像窗口名称
plt.imshow(img)
plt.axis('on') # 关掉坐标轴为 off
plt.title(np.argmax(y[0])) # 图像题目 # 必须有这个,要不然无法显示
plt.show()
if __name__ == "__main__":
app = Predict()
app.predict('./test_images/0.png')
app.predict('./test_images/1.png')
app.predict('./test_images/4.png')