1
0
Files
ai-school/mnist/main.py
2026-04-15 12:26:41 +08:00

196 lines
6.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from datasets import load_dataset
import numpy as np
import matplotlib.pyplot as plt
import time
# Pick the compute device once at import time: CUDA when PyTorch reports it
# as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")
class SimpleMLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        input_size: Number of features per sample after flattening
            (the default 784 equals 28 * 28).
        hidden_size: Width of the hidden layer.
        num_classes: Number of output logits per sample.
    """

    def __init__(self, input_size=784, hidden_size=128, num_classes=10):
        super().__init__()
        # The attribute names become state_dict keys, so they stay unchanged.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return raw logits of shape ``(x.size(0), num_classes)``.

        All dimensions after the first are flattened into one feature axis
        before the linear layers are applied.
        """
        # reshape() matches view() for contiguous input but copies when the
        # tensor is non-contiguous (e.g. after transpose/permute), where
        # view() would raise a RuntimeError.
        x = x.reshape(x.size(0), -1)
        x = self.fc1(x)
        x = self.relu(x)
        return self.fc2(x)
def preprocess_data(dataset):
    """Convert a dataset's samples to float32 images and torch-formatted output.

    Each example's ``'image'`` is turned into a float32 NumPy array divided by
    255.0, and its ``'label'`` is passed through unchanged. All original
    columns are replaced by these two, and the returned dataset is set to the
    ``"torch"`` format for the ``image`` and ``label`` columns.

    Args:
        dataset: Object providing ``column_names``, ``map`` and
            ``with_format`` (in this file, a split returned by
            ``datasets.load_dataset``).

    Returns:
        The mapped dataset after ``with_format("torch", ...)``.
    """

    def _scale_example(example):
        # Cast to float32 first, then divide by 255.0.
        pixels = np.array(example['image'])
        scaled = pixels.astype(np.float32) / 255.0
        # Still a NumPy array at this point; the tensor conversion is left to
        # the dataset format configured below.
        return {'image': scaled, 'label': example['label']}

    # Apply the per-example conversion, dropping every original column so only
    # the keys returned above remain.
    original_columns = dataset.column_names
    converted = dataset.map(_scale_example, remove_columns=original_columns)

    # Key step: request "torch" output restricted to the two columns we use.
    return converted.with_format(
        "torch",
        columns=["image", "label"],
        output_all_columns=False,
    )
# A custom collate_fn is no longer needed:
# after with_format("torch"), DataLoader batches the tensors automatically.
def train_model(model, train_loader, test_loader, num_epochs=10, learning_rate=0.001):
    """Train ``model`` with Adam and cross-entropy, evaluating after each epoch.

    Args:
        model: Module to optimise; it is moved to the module-level ``device``.
        train_loader: Iterable of dict batches with ``'image'`` and
            ``'label'`` tensors. Must support ``len()`` (used for logging).
        test_loader: Loader handed to ``evaluate_model`` after every epoch.
        num_epochs: Number of passes over ``train_loader``.
        learning_rate: Learning rate for the Adam optimiser.

    Returns:
        Tuple ``(train_losses, train_accuracies, test_accuracies)``, each a
        list with one value per epoch. The loss is the mean of the per-batch
        losses; accuracies are percentages. An epoch that sees no batches
        records 0.0 for loss and train accuracy instead of dividing by zero.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    model.to(device)

    train_losses = []
    train_accuracies = []
    test_accuracies = []

    for epoch in range(num_epochs):
        # Enter training mode at the start of every epoch so this loop does
        # not depend on the evaluation step restoring it.
        model.train()
        epoch_start_time = time.time()
        running_loss = 0.0
        correct_train = 0
        total_train = 0
        num_batches = 0

        for batch_idx, batch in enumerate(train_loader):
            # Each batch is a dict: {'image': tensor, 'label': tensor}.
            images = batch['image'].to(device)
            labels = batch['label'].to(device)

            # A 3-D batch has no channel axis; insert one at position 1.
            if images.dim() == 3:
                images = images.unsqueeze(1)

            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Read the scalar once; .item() forces a device synchronisation.
            loss_value = loss.item()
            running_loss += loss_value
            num_batches += 1

            # detach() replaces the discouraged ``.data`` access.
            predicted = outputs.detach().argmax(dim=1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()

            if batch_idx % 100 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx}/{len(train_loader)}], '
                      f'Loss: {loss_value:.4f}')

        # Guard against an empty loader, which would otherwise raise
        # ZeroDivisionError here.
        train_accuracy = 100 * correct_train / total_train if total_train else 0.0
        avg_loss = running_loss / num_batches if num_batches else 0.0
        test_accuracy = evaluate_model(model, test_loader)
        epoch_time = time.time() - epoch_start_time

        train_losses.append(avg_loss)
        train_accuracies.append(train_accuracy)
        test_accuracies.append(test_accuracy)

        print(f'Epoch [{epoch+1}/{num_epochs}] completed in {epoch_time:.2f}s')
        print(f'Train Loss: {avg_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, '
              f'Test Accuracy: {test_accuracy:.2f}%')
        print('-' * 60)

    return train_losses, train_accuracies, test_accuracies
def evaluate_model(model, test_loader):
    """Return the accuracy of ``model`` on ``test_loader`` as a percentage.

    The model is switched to eval mode for the duration of the pass and then
    put back into whichever mode (train or eval) it was in on entry.

    Args:
        model: Module producing per-class scores for a batch of images.
        test_loader: Iterable of dict batches with ``'image'`` and
            ``'label'`` tensors.

    Returns:
        ``100 * correct / total`` as a float, or 0.0 when the loader yields
        no samples.
    """
    # Run on the device the model's parameters live on; only a parameter-free
    # model falls back to the module-level ``device``.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for batch in test_loader:
                images = batch['image'].to(target_device)
                labels = batch['label'].to(target_device)

                # A 3-D batch has no channel axis; insert one at position 1.
                if images.dim() == 3:
                    images = images.unsqueeze(1)

                predicted = model(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        # Restore the caller's mode instead of unconditionally forcing
        # training mode, even if the loop above raised.
        model.train(was_training)

    if total == 0:
        return 0.0
    return 100 * correct / total
def _plot_training_curves(train_losses, train_accuracies, test_accuracies):
    """Show a three-panel figure of per-epoch loss and accuracy curves."""
    plt.figure(figsize=(15, 5))

    # Panel 1: training loss.
    plt.subplot(1, 3, 1)
    plt.plot(train_losses)
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')

    # Panel 2: train and test accuracy with default line styles.
    plt.subplot(1, 3, 2)
    plt.plot(train_accuracies, label='Train')
    plt.plot(test_accuracies, label='Test')
    plt.title('Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()

    # Panel 3: the same two series with explicit line styles.
    plt.subplot(1, 3, 3)
    plt.plot(train_accuracies, 'b-', label='Train Accuracy')
    plt.plot(test_accuracies, 'r--', label='Test Accuracy')
    plt.title('Train vs Test Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()

    plt.tight_layout()
    plt.show()


def main(
    train_path=r"D:\AiData\Dataset\mnist\mnist\train-00000-of-00001.parquet",
    test_path=r"D:\AiData\Dataset\mnist\mnist\test-00000-of-00001.parquet",
    batch_size=64,
    num_epochs=10,
    learning_rate=0.001,
):
    """Load the parquet splits, train a ``SimpleMLP``, plot and report results.

    All parameters default to the values the script previously hard-coded, so
    calling ``main()`` with no arguments behaves as before.

    Args:
        train_path: Parquet file used for the ``'train'`` split.
        test_path: Parquet file used for the ``'test'`` split.
        batch_size: Batch size for both data loaders.
        num_epochs: Number of epochs passed to ``train_model``.
        learning_rate: Learning rate passed to ``train_model``.
    """
    dataset = load_dataset('parquet', data_files={
        'train': train_path,
        'test': test_path,
    })

    print("Preprocessing training data...")
    train_dataset = preprocess_data(dataset['train'])
    print("Preprocessing test data...")
    test_dataset = preprocess_data(dataset['test'])

    # Only the training split is shuffled; both loaders run in the main
    # process (num_workers=0).
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
    )

    model = SimpleMLP()
    print(f"Model: {model}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

    print("\nStarting training...")
    train_losses, train_accuracies, test_accuracies = train_model(
        model, train_loader, test_loader,
        num_epochs=num_epochs, learning_rate=learning_rate,
    )

    # Plot the training curves.
    _plot_training_curves(train_losses, train_accuracies, test_accuracies)

    final_test_acc = evaluate_model(model, test_loader)
    print(f"\nFinal Test Accuracy: {final_test_acc:.2f}%")


# Run the full pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()