1
0

fix exp2 pytorch rewrite fatal train issue

This commit is contained in:
2025-11-30 22:01:56 +08:00
parent 48fcdfcc80
commit 43b807679f
13 changed files with 738 additions and 112 deletions

View File

@@ -7,10 +7,11 @@ import ignite.engine
import ignite.metrics
from ignite.engine import Engine, Events
from ignite.handlers.tqdm_logger import ProgressBar
from mnist import CNN, MnistDataSource
from dataset import MnistDataSource
from model import Cnn
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
import gpu_utils
import pytorch_gpu_utils
class Trainer:
@@ -19,12 +20,12 @@ class Trainer:
device: torch.device
data_source: MnistDataSource
model: CNN
model: Cnn
def __init__(self):
self.device = gpu_utils.get_gpu_device()
self.device = pytorch_gpu_utils.get_gpu_device()
self.data_source = MnistDataSource(Trainer.N_BATCH_SIZE)
self.model = CNN().to(self.device)
self.model = Cnn().to(self.device)
# 展示模型结构。批次为指定批次数量通道只有一个灰度通道大小28x28。
torchinfo.summary(self.model, (Trainer.N_BATCH_SIZE, 1, 28, 28))
@@ -101,7 +102,7 @@ def main():
# device: torch.device
# data_source: MnistDataSource
# model: CNN
# model: Cnn
# trainer: Engine
# train_evaluator: Engine
@@ -109,7 +110,7 @@ def main():
# def __init__(self):
# self.device = gpu_utils.get_gpu_device()
# self.model = CNN().to(self.device)
# self.model = Cnn().to(self.device)
# self.data_source = MnistDataSource(batch_size=N_BATCH_SIZE)
# # 展示模型结构。批次为指定批次数量通道只有一个灰度通道大小28x28。
# torchinfo.summary(self.model, (N_BATCH_SIZE, 1, 28, 28))
@@ -188,5 +189,5 @@ def main():
if __name__ == "__main__":
gpu_utils.print_gpu_availability()
pytorch_gpu_utils.print_gpu_availability()
main()