1
0

use ignite for exp2

This commit is contained in:
2025-12-02 23:07:27 +08:00
parent 43b807679f
commit 65c56e938c
15 changed files with 246 additions and 794 deletions

View File

@@ -1,12 +1,6 @@
from pathlib import Path
import sys
import tensorflow as tf
import keras
from keras import datasets, layers, models
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
import tensorflow_gpu_util
from tensor.keras import datasets, layers, models
class CNN(object):
def __init__(self):
@@ -46,7 +40,7 @@ class Train:
def train(self):
check_path = Path(__file__).resolve().parent.parent / 'models' / 'cnn.ckpt'
# period 每隔5epoch保存一次
save_model_cb = keras.callbacks.ModelCheckpoint(
save_model_cb = tf.keras.callbacks.ModelCheckpoint(
str(check_path), save_weights_only=True, verbose=1, period=5)
self.cnn.model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
@@ -58,6 +52,5 @@ class Train:
print("准确率: %.4f, 共测试了%d张图片 " % (test_acc, len(self.data.test_labels)))
if __name__ == "__main__":
tensorflow_gpu_util.print_gpu_availability()
#app = Train()
#app.train()
app = Train()
app.train()