1
0
Files
ai-school/dl-exp/exp2/modified/dataset.py

81 lines
3.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from pathlib import Path
import numpy
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import v2 as tvtrans
import settings
class MnistDataset(Dataset):
"""适配PyTorch的自定义Dataset用于加载MNIST数据。"""
shape: int
transform: tvtrans.Transform
images_data: numpy.ndarray
labels_data: torch.Tensor
def __init__(self, images: numpy.ndarray, labels: numpy.ndarray, transform: tvtrans.Transform):
images_len: int = images.shape[0]
labels_len: int = labels.shape[0]
assert (images_len == labels_len)
self.shape = images_len
self.images_data = images
self.labels_data = torch.from_numpy(labels)
self.transform = transform
def __getitem__(self, index):
return self.transform(self.images_data[index]), self.labels_data[index]
def __len__(self):
return self.shape
class MnistDataLoaders:
"""包含适配PyTorch的训练数据Loader和测试数据Loader的类。"""
train_loader: DataLoader
test_loader: DataLoader
def __init__(self, batch_size: int):
dataset = numpy.load(settings.MNIST_DATASET_PATH)
# 所有图片均为黑底白字
# 6万张训练图片60000x28x28。标签只有第一维。
train_images: numpy.ndarray = dataset['x_train']
train_labels: numpy.ndarray = dataset['y_train']
# 1万张测试图片10000x28x28。标签只有第一维。
test_images: numpy.ndarray = dataset['x_test']
test_labels: numpy.ndarray = dataset['y_test']
# 定义数据转换器
trans = tvtrans.Compose([
# 从uint8转换为float32并自动归一化到0-1区间
# YYC MARK: 下面这个被标outdated了换下面两个替代。
# tvtrans.ToTensor(),
tvtrans.ToImage(),
tvtrans.ToDtype(torch.float32, scale=True),
# 为了符合后面图像的输入颜色通道条件,要在最后挤出一个新的维度
# YYC MARK: 上面这两步已经帮我们自动挤出那个灰度通道了。
# tvtrans.Lambda(lambda x: x.unsqueeze(-1))
# 这个特定的标准化参数 (0.1307, 0.3081) 是 MNIST 数据集的标准化参数这些数值是MNIST训练集的全局均值和标准差。
# 这种标准化有助于模型训练时的数值稳定性和收敛速度。
# YYC MARK: 但我不想用,反正最后训练的也收敛。
# tvtrans.Normalize((0.1307,), (0.3081,)),
])
# 创建数据集
train_dataset = MnistDataset(train_images, train_labels,
transform=trans)
test_dataset = MnistDataset(test_images, test_labels,
transform=trans)
# 赋值到自身
self.train_loader = DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=False)
self.test_loader = DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle=False)