import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from datasets import load_dataset
from torch.utils.data import DataLoader

# Pick the GPU when one is available; every training/evaluation step below
# moves tensors to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


class SimpleMLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        input_size: Number of features per sample after flattening
            (784 by default, i.e. a 28x28 image).
        hidden_size: Width of the hidden layer.
        num_classes: Number of output logits.
    """

    def __init__(self, input_size=784, hidden_size=128, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Flatten every sample to one vector and return raw class logits."""
        # reshape (rather than view) also accepts non-contiguous inputs such
        # as transposed or sliced tensors; for contiguous inputs the result
        # is identical.
        x = x.reshape(x.size(0), -1)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


def preprocess_data(dataset):
    """Convert a dataset with 'image' and 'label' columns to torch format.

    Each image is converted to a float32 NumPy array and divided by 255
    (presumably mapping 8-bit pixel values into [0, 1] -- confirm against
    the actual data), then the dataset is switched to the "torch" output
    format so that indexing yields tensors.

    Args:
        dataset: A Hugging Face dataset split exposing ``map``,
            ``column_names`` and ``with_format``.

    Returns:
        The transformed dataset, restricted to the 'image' and 'label'
        columns and formatted as torch tensors.
    """

    def transform_sample(example):
        # Normalise the image and cast it to float32. A NumPy array is
        # returned here; the conversion to tensors happens via with_format.
        image = np.array(example['image']).astype(np.float32) / 255.0
        return {
            'image': image,
            'label': example['label'],
        }

    # Apply the per-example transform first. The original columns are listed
    # in remove_columns, and the function re-emits 'image' and 'label' under
    # the same names.
    dataset = dataset.map(transform_sample, remove_columns=dataset.column_names)

    # Key step: emit torch tensors and expose only the two needed columns.
    dataset = dataset.with_format(
        "torch",
        columns=["image", "label"],
        output_all_columns=False,
    )
    return dataset


# No custom collate_fn is needed any more.
# With with_format("torch") applied, the DataLoader batches the tensors
# automatically.


def _prepare_batch(batch, target_device):
    """Move one {'image', 'label'} batch to ``target_device``.

    A 3-D image tensor (batch, height, width) gets a channel axis inserted
    at position 1 so the model always receives (batch, 1, height, width).
    """
    images = batch['image'].to(target_device)
    labels = batch['label'].to(target_device)
    if images.dim() == 3:
        images = images.unsqueeze(1)
    return images, labels


def _percentage(correct, total):
    """Return ``100 * correct / total``, or 0.0 when ``total`` is zero."""
    return 100 * correct / total if total else 0.0


def train_model(model, train_loader, test_loader, num_epochs=10, learning_rate=0.001):
    """Train ``model`` with Adam and cross-entropy loss.

    The model is moved to the module-level ``device`` and evaluated on
    ``test_loader`` after every epoch.

    Args:
        model: The ``nn.Module`` to optimise (modified in place).
        train_loader: Iterable of dict batches with 'image' and 'label'
            tensors; must support ``len()``.
        test_loader: Iterable of dict batches used for per-epoch evaluation.
        num_epochs: Number of passes over ``train_loader``.
        learning_rate: Learning rate handed to Adam.

    Returns:
        A tuple ``(train_losses, train_accuracies, test_accuracies)`` of
        lists holding one value per epoch. Loss is the mean over batches;
        accuracies are percentages.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.to(device)
    model.train()

    train_losses = []
    train_accuracies = []
    test_accuracies = []

    for epoch in range(num_epochs):
        epoch_start_time = time.time()
        running_loss = 0.0
        correct_train = 0
        total_train = 0

        for batch_idx, batch in enumerate(train_loader):
            # batch is a dict: {'image': tensor, 'label': tensor}
            images, labels = _prepare_batch(batch, device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # detach() keeps the accuracy bookkeeping out of the autograd graph.
            predicted = outputs.detach().argmax(dim=1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()

            if batch_idx % 100 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx}/{len(train_loader)}], '
                      f'Loss: {loss.item():.4f}')

        num_batches = len(train_loader)
        train_accuracy = _percentage(correct_train, total_train)
        # Guard against an empty loader instead of dividing by zero.
        avg_loss = running_loss / num_batches if num_batches else 0.0
        test_accuracy = evaluate_model(model, test_loader)
        epoch_time = time.time() - epoch_start_time

        train_losses.append(avg_loss)
        train_accuracies.append(train_accuracy)
        test_accuracies.append(test_accuracy)

        print(f'Epoch [{epoch+1}/{num_epochs}] completed in {epoch_time:.2f}s')
        print(f'Train Loss: {avg_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, '
              f'Test Accuracy: {test_accuracy:.2f}%')
        print('-' * 60)

    return train_losses, train_accuracies, test_accuracies


def evaluate_model(model, test_loader):
    """Return the classification accuracy of ``model`` in percent.

    The model is put in eval mode for the duration of the call and its
    previous train/eval mode is restored afterwards, even if evaluation
    raises.

    Args:
        model: The ``nn.Module`` to evaluate.
        test_loader: Iterable of dict batches with 'image' and 'label'
            tensors.

    Returns:
        Accuracy as a float in [0, 100]; 0.0 when the loader yields no
        samples.
    """
    # Follow the model's own device so inputs always match its weights; only
    # parameter-less modules fall back to the module-level device.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for batch in test_loader:
                images, labels = _prepare_batch(batch, target_device)
                predicted = model(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        # Restore whichever mode the caller had, rather than forcing train().
        model.train(was_training)
    return _percentage(correct, total)


def main(
    train_path=r"D:\AiData\Dataset\mnist\mnist\train-00000-of-00001.parquet",
    test_path=r"D:\AiData\Dataset\mnist\mnist\test-00000-of-00001.parquet",
):
    """Load the parquet splits, train a SimpleMLP and plot the learning curves.

    Args:
        train_path: Parquet file holding the training split.
        test_path: Parquet file holding the test split.
    """
    dataset = load_dataset('parquet', data_files={
        'train': train_path,
        'test': test_path,
    })

    print("Preprocessing training data...")
    train_dataset = preprocess_data(dataset['train'])
    print("Preprocessing test data...")
    test_dataset = preprocess_data(dataset['test'])

    batch_size = 64
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
    )

    model = SimpleMLP()
    print(f"Model: {model}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

    print("\nStarting training...")
    train_losses, train_accuracies, test_accuracies = train_model(
        model, train_loader, test_loader, num_epochs=10, learning_rate=0.001
    )

    # Plot the training curves.
    plt.figure(figsize=(15, 5))

    plt.subplot(1, 3, 1)
    plt.plot(train_losses)
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')

    plt.subplot(1, 3, 2)
    plt.plot(train_accuracies, label='Train')
    plt.plot(test_accuracies, label='Test')
    plt.title('Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()

    plt.subplot(1, 3, 3)
    plt.plot(train_accuracies, 'b-', label='Train Accuracy')
    plt.plot(test_accuracies, 'r--', label='Test Accuracy')
    plt.title('Train vs Test Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()

    plt.tight_layout()
    plt.show()

    final_test_acc = evaluate_model(model, test_loader)
    print(f"\nFinal Test Accuracy: {final_test_acc:.2f}%")


if __name__ == "__main__":
    main()