fix exp2 pytorch rewrite fatal train issue
This commit is contained in:
@@ -7,10 +7,11 @@ import ignite.engine
|
||||
import ignite.metrics
|
||||
from ignite.engine import Engine, Events
|
||||
from ignite.handlers.tqdm_logger import ProgressBar
|
||||
from mnist import CNN, MnistDataSource
|
||||
from dataset import MnistDataSource
|
||||
from model import Cnn
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
|
||||
import gpu_utils
|
||||
import pytorch_gpu_utils
|
||||
|
||||
|
||||
class Trainer:
|
||||
@@ -19,12 +20,12 @@ class Trainer:
|
||||
|
||||
device: torch.device
|
||||
data_source: MnistDataSource
|
||||
model: CNN
|
||||
model: Cnn
|
||||
|
||||
def __init__(self):
|
||||
self.device = gpu_utils.get_gpu_device()
|
||||
self.device = pytorch_gpu_utils.get_gpu_device()
|
||||
self.data_source = MnistDataSource(Trainer.N_BATCH_SIZE)
|
||||
self.model = CNN().to(self.device)
|
||||
self.model = Cnn().to(self.device)
|
||||
# 展示模型结构。批次为指定批次数量,通道只有一个灰度通道,大小28x28。
|
||||
torchinfo.summary(self.model, (Trainer.N_BATCH_SIZE, 1, 28, 28))
|
||||
|
||||
@@ -101,7 +102,7 @@ def main():
|
||||
|
||||
# device: torch.device
|
||||
# data_source: MnistDataSource
|
||||
# model: CNN
|
||||
# model: Cnn
|
||||
|
||||
# trainer: Engine
|
||||
# train_evaluator: Engine
|
||||
@@ -109,7 +110,7 @@ def main():
|
||||
|
||||
# def __init__(self):
|
||||
# self.device = gpu_utils.get_gpu_device()
|
||||
# self.model = CNN().to(self.device)
|
||||
# self.model = Cnn().to(self.device)
|
||||
# self.data_source = MnistDataSource(batch_size=N_BATCH_SIZE)
|
||||
# # 展示模型结构。批次为指定批次数量,通道只有一个灰度通道,大小28x28。
|
||||
# torchinfo.summary(self.model, (N_BATCH_SIZE, 1, 28, 28))
|
||||
@@ -188,5 +189,5 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
gpu_utils.print_gpu_availability()
|
||||
pytorch_gpu_utils.print_gpu_availability()
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user