1
0

refactor: merge multiple project into one and create new project

This commit is contained in:
2026-04-07 08:30:41 +08:00
parent 7aa7ae3335
commit 6cb1a89751
49 changed files with 2932 additions and 4 deletions

View File

@@ -0,0 +1,40 @@
from pathlib import Path
import tensorflow as tf
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from train import CNN
class Predict(object):
def __init__(self):
latest = tf.train.latest_checkpoint('./ckpt')
self.cnn = CNN()
# 恢复网络权重
self.cnn.model.load_weights(latest)
def predict(self, image_path):
# 以黑白方式读取图片
img = Image.open(image_path).convert('L')
img = np.reshape(img, (28, 28, 1)) / 255.
x = np.array([1 - img])
y = self.cnn.model.predict(x)
# 因为x只传入了一张图片取y[0]即可
# np.argmax()取得最大值的下标,即代表的数字
print(image_path)
# print(y[0])
print(' -> Predict digit', np.argmax(y[0]))
plt.figure("Image") # 图像窗口名称
plt.imshow(img)
plt.axis('on') # 关掉坐标轴为 off
plt.title(np.argmax(y[0])) # 图像题目 # 必须有这个,要不然无法显示
plt.show()
if __name__ == "__main__":
app = Predict()
images_dir = Path(__file__).resolve().parent.parent / 'test_images'
app.predict(images_dir / '0.png')
app.predict(images_dir / '1.png')
app.predict(images_dir / '4.png')