use ignite for exp2
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from pathlib import Path
|
||||
import tensorflow as tf
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
@@ -33,6 +34,7 @@ class Predict(object):
|
||||
|
||||
if __name__ == "__main__":
|
||||
app = Predict()
|
||||
app.predict('./test_images/0.png')
|
||||
app.predict('./test_images/1.png')
|
||||
app.predict('./test_images/4.png')
|
||||
images_dir = Path(__file__).resolve().parent.parent / 'test_images'
|
||||
app.predict(images_dir / '0.png')
|
||||
app.predict(images_dir / '1.png')
|
||||
app.predict(images_dir / '4.png')
|
||||
|
||||
@@ -1,12 +1,6 @@
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import tensorflow as tf
|
||||
import keras
|
||||
from keras import datasets, layers, models
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
|
||||
import tensorflow_gpu_util
|
||||
|
||||
from tensor.keras import datasets, layers, models
|
||||
|
||||
class CNN(object):
|
||||
def __init__(self):
|
||||
@@ -46,7 +40,7 @@ class Train:
|
||||
def train(self):
|
||||
check_path = Path(__file__).resolve().parent.parent / 'models' / 'cnn.ckpt'
|
||||
# period 每隔5epoch保存一次
|
||||
save_model_cb = keras.callbacks.ModelCheckpoint(
|
||||
save_model_cb = tf.keras.callbacks.ModelCheckpoint(
|
||||
str(check_path), save_weights_only=True, verbose=1, period=5)
|
||||
self.cnn.model.compile(optimizer='adam',
|
||||
loss='sparse_categorical_crossentropy',
|
||||
@@ -58,6 +52,5 @@ class Train:
|
||||
print("准确率: %.4f, 共测试了%d张图片 " % (test_acc, len(self.data.test_labels)))
|
||||
|
||||
if __name__ == "__main__":
|
||||
tensorflow_gpu_util.print_gpu_availability()
|
||||
#app = Train()
|
||||
#app.train()
|
||||
app = Train()
|
||||
app.train()
|
||||
|
||||
Reference in New Issue
Block a user