first commit
This commit is contained in:
45
exp2/source/predict.py
Normal file
45
exp2/source/predict.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import tensorflow as tf
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from train import CNN
|
||||
|
||||
'''
|
||||
python 3.9
|
||||
tensorflow 2.0.0b0
|
||||
pillow(PIL) 4.3.0
|
||||
'''
|
||||
|
||||
|
||||
class Predict(object):
|
||||
def __init__(self):
|
||||
latest = tf.train.latest_checkpoint('./ckpt')
|
||||
self.cnn = CNN()
|
||||
# 恢复网络权重
|
||||
self.cnn.model.load_weights(latest)
|
||||
|
||||
def predict(self, image_path):
|
||||
# 以黑白方式读取图片
|
||||
img = Image.open(image_path).convert('L')
|
||||
img = np.reshape(img, (28, 28, 1)) / 255.
|
||||
x = np.array([1 - img])
|
||||
y = self.cnn.model.predict(x)
|
||||
|
||||
# 因为x只传入了一张图片,取y[0]即可
|
||||
# np.argmax()取得最大值的下标,即代表的数字
|
||||
print(image_path)
|
||||
# print(y[0])
|
||||
print(' -> Predict digit', np.argmax(y[0]))
|
||||
plt.figure("Image") # 图像窗口名称
|
||||
plt.imshow(img)
|
||||
plt.axis('on') # 关掉坐标轴为 off
|
||||
plt.title(np.argmax(y[0])) # 图像题目 # 必须有这个,要不然无法显示
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app = Predict()
|
||||
app.predict('./test_images/0.png')
|
||||
app.predict('./test_images/1.png')
|
||||
app.predict('./test_images/4.png')
|
||||
Reference in New Issue
Block a user