1
0

fix exp2 pytorch rewrite fatal train issue

This commit is contained in:
2025-11-30 22:01:56 +08:00
parent 48fcdfcc80
commit 43b807679f
13 changed files with 738 additions and 112 deletions

View File

@@ -1,11 +1,13 @@
import os
from pathlib import Path
import sys
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import keras
from keras import datasets, layers, models
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
import tensorflow_gpu_util
'''
python 3.9
tensorflow 2.0.0b0
'''
class CNN(object):
def __init__(self):
model = models.Sequential()
@@ -26,8 +28,7 @@ class CNN(object):
class DataSource(object):
def __init__(self):
# mnist数据集存储的位置如何不存在将自动下载
data_path = os.path.abspath(os.path.dirname(
__file__)) + '.'
data_path = Path(__file__).resolve().parent.parent / 'datasets' / 'mnist.npz'
(train_images, train_labels), (test_images,
test_labels) = datasets.mnist.load_data(path=data_path)
# 6万张训练图片1万张测试图片
@@ -43,19 +44,20 @@ class Train:
self.cnn = CNN()
self.data = DataSource()
def train(self):
check_path = './ckpt/cp-{epoch:04d}.ckpt'
check_path = Path(__file__).resolve().parent.parent / 'models' / 'cnn.ckpt'
# period 每隔5epoch保存一次
save_model_cb = tf.keras.callbacks.ModelCheckpoint(
check_path, save_weights_only=True, verbose=1, period=5)
save_model_cb = keras.callbacks.ModelCheckpoint(
str(check_path), save_weights_only=True, verbose=1, period=5)
self.cnn.model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
self.cnn.model.fit(self.data.train_images, self.data.train_labels,
epochs=5, callbacks=[save_model_cb])
epochs=5, batch_size=1000, callbacks=[save_model_cb])
test_loss, test_acc = self.cnn.model.evaluate(
self.data.test_images, self.data.test_labels)
print("准确率: %.4f共测试了%d张图片 " % (test_acc, len(self.data.test_labels)))
print("准确率: %.4f, 共测试了%d张图片 " % (test_acc, len(self.data.test_labels)))
if __name__ == "__main__":
app = Train()
app.train()
tensorflow_gpu_util.print_gpu_availability()
#app = Train()
#app.train()