1
0
Files
ai-school/exp2/modified/train.py
2025-11-24 21:02:44 +08:00

189 lines
6.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from pathlib import Path
import sys
import typing
import numpy
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import v2 as tvtrans
import matplotlib.pyplot as plt
import torch.nn.functional as F
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
import gpu_utils
class CNN(torch.nn.Module):
    """Convolutional network for 28x28 single-channel MNIST digits.

    Returns raw class logits, one score per digit (10 total); apply
    softmax externally if calibrated probabilities are needed.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=(3, 3))
        self.pool1 = torch.nn.MaxPool2d(kernel_size=(2, 2))
        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=(3, 3))
        self.pool2 = torch.nn.MaxPool2d(kernel_size=(2, 2))
        self.conv3 = torch.nn.Conv2d(64, 64, kernel_size=(3, 3))
        self.flatten = torch.nn.Flatten()
        # Spatial sizes: 28x28 -> conv1 -> 26x26 -> pool1 -> 13x13
        # -> conv2 -> 11x11 -> pool2 -> 5x5 -> conv3 -> 3x3.
        # The final conv layer has 64 output channels, hence 64*3*3 inputs.
        self.fc1 = torch.nn.Linear(64 * 3 * 3, 64)
        self.fc2 = torch.nn.Linear(64, 10)

    def forward(self, x):
        """Map a (N, 1, 28, 28) batch to (N, 10) class logits.

        BUGFIX: the original applied F.softmax(dim=1) here while training
        used torch.nn.CrossEntropyLoss, which already applies log_softmax
        internally; softmax-ing twice flattens gradients and hurts
        training. Return raw logits instead. The argmax-based accuracy in
        the test loop is unaffected because softmax is monotonic per row.
        """
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = F.relu(self.conv3(x))
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        return self.fc2(x)
class MnistDataset(Dataset):
    """Dataset over in-memory MNIST image/label arrays.

    Images are kept as a numpy array and converted per-sample by
    ``transform``; labels are converted to a tensor once up front.
    """

    # Number of samples (shared first-dimension length of images/labels).
    shape: int
    # Per-sample image transform (e.g. a torchvision v2 Compose).
    transform: typing.Callable
    images_data: numpy.ndarray
    labels_data: torch.Tensor

    def __init__(self, images: numpy.ndarray, labels: numpy.ndarray, transform: typing.Callable):
        # BUGFIX: numpy.ndarray.size is an int attribute, not a method;
        # the original `images.size(0)` raised TypeError. Use shape[0].
        images_len: int = images.shape[0]
        labels_len: int = labels.shape[0]
        assert images_len == labels_len, 'images/labels length mismatch'
        self.shape = images_len
        self.images_data = images
        self.labels_data = torch.from_numpy(labels)
        self.transform = transform

    def __getitem__(self, index):
        """Return (transformed image, label tensor) for sample *index*."""
        return self.transform(self.images_data[index]), self.labels_data[index]

    def __len__(self):
        return self.shape
class DataSource:
    """Loads MNIST from a local .npz file into train/test DataLoaders."""

    train_data: DataLoader
    test_data: DataLoader

    def __init__(self, batch_size: int):
        """Build the two DataLoaders with the given batch size.

        Expects ``../datasets/mnist.npz`` relative to this file, with keys
        x_train/y_train (60000x28x28) and x_test/y_test (10000x28x28);
        labels are 1-D. Images are white digits on a black background.
        """
        datasets_path = Path(__file__).resolve().parent.parent / 'datasets' / 'mnist.npz'
        datasets = numpy.load(datasets_path)
        train_images = datasets['x_train']
        train_labels = datasets['y_train']
        test_images = datasets['x_test']
        test_labels = datasets['y_test']
        # Convert uint8 HxW arrays to float32 tensors scaled to [0, 1];
        # ToImage also supplies the leading channel dimension expected by
        # the network's first Conv2d (1 input channel).
        trans = tvtrans.Compose([
            tvtrans.ToImage(),
            tvtrans.ToDtype(torch.float32, scale=True),
        ])
        train_dataset = MnistDataset(train_images,
                                     train_labels,
                                     transform=trans)
        test_dataset = MnistDataset(test_images, test_labels, transform=trans)
        # BUGFIX: training batches should be reshuffled every epoch so the
        # optimizer does not see samples in the same fixed order each pass;
        # the original passed shuffle=False for both loaders. Test data
        # stays unshuffled (order is irrelevant for accuracy counting).
        self.train_data = DataLoader(dataset=train_dataset,
                                     batch_size=batch_size,
                                     shuffle=True)
        self.test_data = DataLoader(dataset=test_dataset,
                                    batch_size=batch_size,
                                    shuffle=False)
class Trainer:
    """Trains, saves, and evaluates the CNN on MNIST."""

    # Number of passes over the training set.
    N_EPOCH: typing.ClassVar[int] = 5
    # Samples per mini-batch.
    N_BATCH_SIZE: typing.ClassVar[int] = 1000

    device: torch.device
    data_source: DataSource
    cnn: CNN

    def __init__(self):
        self.device = gpu_utils.get_gpu_device()
        self.data_source = DataSource(Trainer.N_BATCH_SIZE)
        self.cnn = CNN().to(self.device)

    def train(self):
        """Run the Adam / cross-entropy training loop.

        Logs the current batch loss every 100 batches.
        """
        optimizer = torch.optim.Adam(self.cnn.parameters())
        # CrossEntropyLoss applies log_softmax internally, so the model is
        # expected to output raw logits.
        loss_func = torch.nn.CrossEntropyLoss()
        for epoch in range(Trainer.N_EPOCH):
            self.cnn.train()
            batch_images: torch.Tensor
            batch_labels: torch.Tensor
            for batch_index, (batch_images, batch_labels) in enumerate(self.data_source.train_data):
                gpu_images = batch_images.to(self.device)
                gpu_labels = batch_labels.to(self.device)
                optimizer.zero_grad()
                prediction: torch.Tensor = self.cnn(gpu_images)
                loss: torch.Tensor = loss_func(prediction, gpu_labels)
                loss.backward()
                optimizer.step()
                if batch_index % 100 == 0:
                    literal_loss = loss.item()
                    print(f'Epoch: {epoch+1}, Batch: {batch_index}, Loss: {literal_loss:.4f}')

    def save(self):
        """Persist the model weights to ../models/cnn.pth (dir created if absent)."""
        file_dir_path = Path(__file__).resolve().parent.parent / 'models'
        file_dir_path.mkdir(parents=True, exist_ok=True)
        file_path = file_dir_path / 'cnn.pth'
        torch.save(self.cnn.state_dict(), file_path)
        print(f'模型已保存至:{file_path}')

    def test(self):
        """Evaluate on the test set and print top-1 accuracy."""
        self.cnn.eval()
        correct_sum = 0
        total_sum = 0
        with torch.no_grad():
            for batch_images, batch_labels in self.data_source.test_data:
                gpu_images = batch_images.to(self.device)
                gpu_labels = batch_labels.to(self.device)
                possibilities: torch.Tensor = self.cnn(gpu_images)
                # The model outputs one score per digit class; the predicted
                # digit is the argmax over dim=1 (dim=0 is the batch axis).
                _, prediction = possibilities.max(1)
                # Count this batch's samples toward the running total.
                total_sum += gpu_labels.size(0)
                # FIX: .item() converts the 0-dim result tensor to a plain
                # int so correct_sum stays a Python number instead of
                # accumulating device tensors across batches.
                correct_sum += prediction.eq(gpu_labels).sum().item()
        test_acc = 100. * correct_sum / total_sum
        print(f"准确率: {test_acc:.4f}%,共测试了{total_sum}张图片")
def main():
    """End-to-end pipeline: train the CNN, persist it, then evaluate."""
    runner = Trainer()
    runner.train()
    runner.save()
    runner.test()
if __name__ == "__main__":
    # Report GPU availability up front, then run the full pipeline.
    gpu_utils.print_gpu_availability()
    main()