# ai-school/exp2/modified/train.py
# Repository-viewer header: 2025-11-24 14:20:38 +08:00, 140 lines, 4.4 KiB, Python.
# Viewer notice: this file contains Unicode characters that might be confused
# with other characters. If this is intentional, the warning can be ignored.
from pathlib import Path
import sys
import numpy
import torch
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import torch.nn.functional as F
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
import gpu_utils
class CNN(torch.nn.Module):
    """Convolutional neural network classifier for 28x28 single-channel images.

    Architecture: three 3x3 convolutions (32, 64 and 64 filters) with a 2x2
    max pooling after each of the first two, followed by two fully connected
    layers. The forward pass returns per-class probabilities (softmax) over
    10 classes.
    """

    def __init__(self):
        super().__init__()
        # Pooling uses floor mode (ceil_mode=False). This is what Keras'
        # default "valid" padding does, and it is what the fully connected
        # layer below is sized for. With ceil_mode=True the second pooling
        # would turn 11x11 into 6x6 (not 5x5), the last convolution would
        # yield 4x4, and fc1 would fail with a shape mismatch.
        self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=(3, 3))
        self.pool1 = torch.nn.MaxPool2d(kernel_size=(2, 2), ceil_mode=False)
        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=(3, 3))
        self.pool2 = torch.nn.MaxPool2d(kernel_size=(2, 2), ceil_mode=False)
        self.conv3 = torch.nn.Conv2d(64, 64, kernel_size=(3, 3))
        self.flatten = torch.nn.Flatten()

        # Spatial size through the network for a 28x28 input:
        #   conv1 -> 26x26, pool1 -> 13x13,
        #   conv2 -> 11x11, pool2 -> 5x5,
        #   conv3 -> 3x3.
        # The last convolution has 64 filters, hence 64 * 3 * 3 features.
        self.fc1 = torch.nn.Linear(64 * 3 * 3, 64)
        self.fc2 = torch.nn.Linear(64, 10)

    def forward(self, x):
        """Return class probabilities for a batch of images.

        Args:
            x: Float tensor of shape (N, 1, 28, 28).

        Returns:
            Tensor of shape (N, 10); each row is a softmax distribution.
        """
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # NOTE: the output is already normalised with softmax, so any loss
        # applied to it must expect probabilities, not raw logits.
        return F.softmax(x, dim=1)
class MnistDataset(Dataset):
    """Custom dataset pairing a tensor of samples with a tensor of labels.

    Both tensors are indexed along their first dimension; item ``i`` is the
    tuple ``(x_data[i], y_data[i])``.
    """

    shape: int
    x_data: torch.Tensor
    y_data: torch.Tensor

    def __init__(self, x_data: torch.Tensor, y_data: torch.Tensor):
        """Store the sample and label tensors.

        Args:
            x_data: Sample tensor; its first dimension is the sample count.
            y_data: Label tensor; its first dimension is the sample count.

        Raises:
            ValueError: If the two tensors hold a different number of items.
        """
        x_len = x_data.shape[0]
        y_len = y_data.shape[0]
        # Explicit check instead of `assert`, which is removed under `-O`.
        if x_len != y_len:
            raise ValueError(
                f'x_data and y_data must have the same length, '
                f'got {x_len} and {y_len}')
        self.shape = x_len
        self.x_data = x_data
        self.y_data = y_data

    def __getitem__(self, index):
        """Return the ``(sample, label)`` pair at ``index``."""
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        """Return the number of samples in the dataset."""
        return self.shape
class DataSource:
    """Reads MNIST from ``datasets/mnist.npz`` and exposes train/test loaders.

    The archive is looked up two directories above this file and must contain
    the arrays ``x_train``, ``y_train``, ``x_test`` and ``y_test``.

    Attributes:
        train_data: Shuffled loader over the training split.
        test_data: Unshuffled loader over the test split.
    """

    train_data: DataLoader
    test_data: DataLoader

    def __init__(self, batch_size: int):
        """Load the archive and build both loaders.

        Args:
            batch_size: Number of samples per batch for both loaders.
        """
        datasets_path = (Path(__file__).resolve().parent.parent
                         / 'datasets' / 'mnist.npz')

        # The context manager closes the underlying archive file once the
        # arrays have been read into memory.
        with numpy.load(datasets_path) as datasets:
            # Per the original notes: 60000 training and 10000 test images of
            # 28x28 pixels; the label arrays are one-dimensional.
            train_images = self._prepare_images(datasets['x_train'])
            train_label = self._prepare_labels(datasets['y_train'])
            test_images = self._prepare_images(datasets['x_test'])
            test_label = self._prepare_labels(datasets['y_test'])

        train_dataset = MnistDataset(train_images, train_label)
        test_dataset = MnistDataset(test_images, test_label)

        self.train_data = DataLoader(dataset=train_dataset,
                                     batch_size=batch_size,
                                     shuffle=True)
        self.test_data = DataLoader(dataset=test_dataset,
                                    batch_size=batch_size,
                                    shuffle=False)

    @staticmethod
    def _prepare_images(images) -> torch.Tensor:
        """Convert an (N, H, W) pixel array to a normalised (N, 1, H, W) tensor."""
        # Cast to float first: in-place true division on an integer tensor
        # raises, and the convolution layers need floating point input.
        tensor = torch.from_numpy(images).float() / 255.0
        # torch.nn.Conv2d expects the channel axis at dim 1 (N, C, H, W).
        # `unsqueeze` is not in-place, so its result must be kept.
        return tensor.unsqueeze(1)

    @staticmethod
    def _prepare_labels(labels) -> torch.Tensor:
        """Convert a label array to the int64 tensor the loss functions require."""
        return torch.from_numpy(labels).long()
def main():
    """Train the CNN on the MNIST training split and print the loss per batch.

    Runs a fixed number of epochs with the Adam optimizer on the device
    returned by ``gpu_utils.get_gpu_device()``.
    """
    n_epoch = 5
    n_batch_size = 25
    device = gpu_utils.get_gpu_device()

    data_source = DataSource(n_batch_size)
    cnn = CNN().to(device)
    optimizer = torch.optim.Adam(cnn.parameters())
    # CNN.forward already returns softmax probabilities. CrossEntropyLoss
    # expects raw logits and would apply softmax a second time, flattening
    # the gradients. Use the negative log-likelihood of the log-probabilities
    # instead.
    loss_func = torch.nn.NLLLoss()

    for epoch in range(n_epoch):
        cnn.train()

        batch_images: torch.Tensor
        batch_labels: torch.Tensor
        for batch_index, (batch_images, batch_labels) in enumerate(data_source.train_data):
            gpu_images = batch_images.to(device)
            # Class-index targets must be int64 for the loss.
            gpu_labels = batch_labels.to(device=device, dtype=torch.long)

            optimizer.zero_grad()
            prediction: torch.Tensor = cnn(gpu_images)
            # Clamp before the log so a probability of exactly 0 cannot
            # produce -inf.
            log_probs = torch.log(torch.clamp(prediction, min=1e-12))
            loss: torch.Tensor = loss_func(log_probs, gpu_labels)
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            print(f'Epoch: {epoch+1}, Batch: {batch_index}, Loss: {loss_value:.4f}')
# Script entry point: report GPU availability first, then start training.
if __name__ == "__main__":
    gpu_utils.print_gpu_availability()
    main()