1
0

Compare commits

..

4 Commits

9 changed files with 2050 additions and 34 deletions

10
mnist/.gitignore vendored Normal file
View File

@@ -0,0 +1,10 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info
# Virtual environments
.venv

1
mnist/.python-version Normal file
View File

@@ -0,0 +1 @@
3.11

0
mnist/README.md Normal file
View File

27
mnist/example.py Normal file
View File

@@ -0,0 +1,27 @@
from datasets import load_dataset
import numpy as np
import matplotlib.pyplot as plt
def main():
dataset = load_dataset('parquet', data_files={
'train': r"D:\AiData\Dataset\mnist\mnist\train-00000-of-00001.parquet",
'test': r"D:\AiData\Dataset\mnist\mnist\test-00000-of-00001.parquet",
})
train_dataset = dataset['train']
first_sample = train_dataset[0]
print("Label:", first_sample['label'])
image = first_sample['image']
image_array = np.array(image)
print("Image shape:", image_array.shape)
plt.imshow(image_array, cmap='gray')
plt.show()
print("First few rows of pixel values:")
print(image_array[:5, :5])
if __name__ == "__main__":
main()

196
mnist/main.py Normal file
View File

@@ -0,0 +1,196 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from datasets import load_dataset
import numpy as np
import matplotlib.pyplot as plt
import time
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
class SimpleMLP(nn.Module):
def __init__(self, input_size=784, hidden_size=128, num_classes=10):
super(SimpleMLP, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_size, num_classes)
def forward(self, x):
x = x.view(x.size(0), -1)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
def preprocess_data(dataset):
"""将数据集转换为PyTorch张量格式"""
def transform_sample(example):
# 转换图像:归一化并转为 float32
image = np.array(example['image']).astype(np.float32) / 255.0
# 注意:这里返回 numpy array稍后统一转为 tensor
return {
'image': image, # 保持为 numpy array
'label': example['label']
}
# 先应用转换
dataset = dataset.map(transform_sample, remove_columns=dataset.column_names)
# 关键:设置格式为 "torch",并指定列类型
dataset = dataset.with_format(
"torch",
columns=["image", "label"],
output_all_columns=False
)
return dataset
# 不再需要自定义 collate_fn
# 因为 with_format("torch") 后DataLoader 会自动处理张量批处理
def train_model(model, train_loader, test_loader, num_epochs=10, learning_rate=0.001):
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
model.to(device)
model.train()
train_losses = []
train_accuracies = []
test_accuracies = []
for epoch in range(num_epochs):
epoch_start_time = time.time()
running_loss = 0.0
correct_train = 0
total_train = 0
for batch_idx, batch in enumerate(train_loader):
# batch 是一个字典:{'image': tensor, 'label': tensor}
images = batch['image'].to(device)
labels = batch['label'].to(device)
# 确保图像是正确的形状 (B, 1, 28, 28)
if images.dim() == 3:
images = images.unsqueeze(1) # 添加通道维度
outputs = model(images)
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
running_loss += loss.item()
_, predicted = torch.max(outputs.data, 1)
total_train += labels.size(0)
correct_train += (predicted == labels).sum().item()
if batch_idx % 100 == 0:
print(f'Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx}/{len(train_loader)}], '
f'Loss: {loss.item():.4f}')
train_accuracy = 100 * correct_train / total_train
avg_loss = running_loss / len(train_loader)
test_accuracy = evaluate_model(model, test_loader)
epoch_time = time.time() - epoch_start_time
train_losses.append(avg_loss)
train_accuracies.append(train_accuracy)
test_accuracies.append(test_accuracy)
print(f'Epoch [{epoch+1}/{num_epochs}] completed in {epoch_time:.2f}s')
print(f'Train Loss: {avg_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, '
f'Test Accuracy: {test_accuracy:.2f}%')
print('-' * 60)
return train_losses, train_accuracies, test_accuracies
def evaluate_model(model, test_loader):
model.eval()
correct = 0
total = 0
with torch.no_grad():
for batch in test_loader:
images = batch['image'].to(device)
labels = batch['label'].to(device)
if images.dim() == 3:
images = images.unsqueeze(1)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
model.train()
return 100 * correct / total
def main():
dataset = load_dataset('parquet', data_files={
'train': r"D:\AiData\Dataset\mnist\mnist\train-00000-of-00001.parquet",
'test': r"D:\AiData\Dataset\mnist\mnist\test-00000-of-00001.parquet",
})
print("Preprocessing training data...")
train_dataset = preprocess_data(dataset['train'])
print("Preprocessing test data...")
test_dataset = preprocess_data(dataset['test'])
batch_size = 64
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=0
)
test_loader = DataLoader(
test_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=0
)
model = SimpleMLP()
print(f"Model: {model}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print("\nStarting training...")
train_losses, train_accuracies, test_accuracies = train_model(
model, train_loader, test_loader, num_epochs=10, learning_rate=0.001
)
# 绘制训练曲线
plt.figure(figsize=(15, 5))
plt.subplot(1, 3, 1)
plt.plot(train_losses)
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.subplot(1, 3, 2)
plt.plot(train_accuracies, label='Train')
plt.plot(test_accuracies, label='Test')
plt.title('Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.subplot(1, 3, 3)
plt.plot(train_accuracies, 'b-', label='Train Accuracy')
plt.plot(test_accuracies, 'r--', label='Test Accuracy')
plt.title('Train vs Test Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.tight_layout()
plt.show()
final_test_acc = evaluate_model(model, test_loader)
print(f"\nFinal Test Accuracy: {final_test_acc:.2f}%")
if __name__ == "__main__":
main()

11
mnist/pyproject.toml Normal file
View File

@@ -0,0 +1,11 @@
[project]
name = "mnist"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
"datasets>=4.3.0",
"matplotlib>=3.10.7",
"numpy>=2.3.4",
]

1731
mnist/uv.lock generated Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -4,6 +4,7 @@
# All image files # All image files
*.jpg *.jpg
*.jpeg
*.png *.png
*.webp *.webp

View File

@@ -47,24 +47,46 @@ def _uniform_car_plate(img: cv.typing.MatLike) -> cv.typing.MatLike:
@dataclass @dataclass
class CarPlateHsvBoundary: class CarPlateHsvBoundary:
"""HSV boundary for car plate color detection."""
lower_bound: cv.typing.MatLike lower_bound: cv.typing.MatLike
"""Lower bound of HSV range for car plate color detection."""
upper_bound: cv.typing.MatLike upper_bound: cv.typing.MatLike
"""Upper bound of HSV range for car plate color detection."""
need_revert: bool
"""是否取反黑白颜色,因为蓝牌和黄牌的操作正好是反的"""
CAR_PLATE_HSV_BOUNDARIES: tuple[CarPlateHsvBoundary, ...] = ( CAR_PLATE_HSV_BOUNDARIES: tuple[CarPlateHsvBoundary, ...] = (
# 中国蓝牌 HSV 范围 # 中国蓝牌 HSV 范围
CarPlateHsvBoundary(np.array([100, 80, 60]), np.array([130, 255, 255])), CarPlateHsvBoundary(np.array([100, 80, 60]), np.array([130, 255, 255]), True),
# 中国绿牌 HSV 范围 # 中国绿牌 HSV 范围
CarPlateHsvBoundary(np.array([35, 43, 46]), np.array([99, 255, 255])), CarPlateHsvBoundary(np.array([35, 43, 46]), np.array([99, 255, 255]), False),
# 中国黄牌 HSV 范围 # 中国黄牌 HSV 范围
CarPlateHsvBoundary(np.array([32, 43, 46]), np.array([68, 255, 255])), CarPlateHsvBoundary(np.array([16, 43, 46]), np.array([34, 255, 255]), False),
) )
@dataclass
class CarPlateMask:
"""Car plate mask result."""
mask: cv.typing.MatLike
"""The masked image in U8 format."""
need_revert: bool
"""是否对颜色取反与CarPlateHsvBoundary中的同名字段含义一致"""
def _batchly_mask_car_plate( def _batchly_mask_car_plate(
hsv: cv.typing.MatLike, hsv: cv.typing.MatLike,
) -> typing.Iterator[cv.typing.MatLike]: ) -> typing.Iterator[CarPlateMask]:
""" """ """
Iterate over each car plate HSV boundary and apply mask to the given HSV image.
:param hsv: The HSV image to apply mask.
:return: An iterator of CarPlateMask.
"""
for boundary in CAR_PLATE_HSV_BOUNDARIES: for boundary in CAR_PLATE_HSV_BOUNDARIES:
# 以给定HSV范围检测符合该颜色的位置 # 以给定HSV范围检测符合该颜色的位置
mask = cv.inRange(hsv, boundary.lower_bound, boundary.upper_bound) mask = cv.inRange(hsv, boundary.lower_bound, boundary.upper_bound)
@@ -76,32 +98,50 @@ def _batchly_mask_car_plate(
mask = cv.morphologyEx(mask, cv.MORPH_OPEN, kernel_open) mask = cv.morphologyEx(mask, cv.MORPH_OPEN, kernel_open)
# Return value # Return value
yield mask yield CarPlateMask(mask, boundary.need_revert)
@dataclass @dataclass
class CarPlateRegion: class CarPlateRegion:
"""Car plate region result."""
x: int x: int
y: int y: int
w: int w: int
h: int h: int
need_revert: bool
"""是否对颜色取反与CarPlateHsvBoundary中的同名字段含义一致"""
MIN_AREA: float = 3000 MIN_AREA: float = 3000
"""Minimum area for car plate region."""
MIN_ASPECT_RATIO: float = 1.5 MIN_ASPECT_RATIO: float = 1.5
"""Minimum aspect ratio for car plate region."""
MAX_ASPECT_RATIO: float = 6.0 MAX_ASPECT_RATIO: float = 6.0
"""Maximum aspect ratio for car plate region."""
BEST_ASPECT_RATIO: float = 3.5 BEST_ASPECT_RATIO: float = 3.5
"""Best aspect ratio for car plate region."""
def _analyse_car_plate_connection( def _analyse_car_plate_connection(
mask: cv.typing.MatLike, masks: typing.Iterator[CarPlateMask],
) -> typing.Optional[CarPlateRegion]: ) -> typing.Optional[CarPlateRegion]:
# 连通域分析,筛选最符合车牌长宽比的区域 """
num_labels, labels, stats, _ = cv.connectedComponentsWithStats(mask, connectivity=8) Analyse car plate connection in given masks.
:param masks: An iterator of CarPlateMask to analyse.
:return: The car plate region if succeed, otherwise None.
"""
best: typing.Optional[CarPlateRegion] = None best: typing.Optional[CarPlateRegion] = None
best_score = 0 best_score = 0
for mask in masks:
# 连通域分析,筛选最符合车牌长宽比的区域
num_labels, labels, stats, _ = cv.connectedComponentsWithStats(
mask.mask, connectivity=8
)
for i in range(1, num_labels): for i in range(1, num_labels):
x, y, w, h, area = stats[i] x, y, w, h, area = stats[i]
# 检查面积 # 检查面积
@@ -113,7 +153,7 @@ def _analyse_car_plate_connection(
score = area * (1 - abs(ratio - BEST_ASPECT_RATIO) / BEST_ASPECT_RATIO) score = area * (1 - abs(ratio - BEST_ASPECT_RATIO) / BEST_ASPECT_RATIO)
if score > best_score: if score > best_score:
best_score = score best_score = score
best = CarPlateRegion(x, y, w, h) best = CarPlateRegion(x, y, w, h, mask.need_revert)
return best return best
@@ -125,21 +165,17 @@ def extract_car_plate(img: cv.typing.MatLike) -> typing.Optional[cv.typing.MatLi
:param img: The image containing car plate in BGR format. :param img: The image containing car plate in BGR format.
:return: The image of binary car plate in U8 format if succeed, otherwise None. :return: The image of binary car plate in U8 format if succeed, otherwise None.
""" """
# 统一图片大小
img = _uniform_car_plate(img) img = _uniform_car_plate(img)
# 转换到HSV空间 # 转换到HSV空间
hsv = cv.cvtColor(img, cv.COLOR_BGR2HSV) hsv = cv.cvtColor(img, cv.COLOR_BGR2HSV)
# 利用车牌颜色在 HSV 空间定位车牌 # 利用车牌颜色在 HSV 空间定位车牌
candidate: typing.Optional[CarPlateRegion] = None masks = _batchly_mask_car_plate(hsv)
for mask in _batchly_mask_car_plate(hsv): candidate = _analyse_car_plate_connection(masks)
# 连通域分析,筛选最符合车牌长宽比的区域作为车牌
candidate = _analyse_car_plate_connection(mask)
# 找到任意一个就退出
if candidate is not None: break
if candidate is None: if candidate is None:
logging.error('Can not find any car plate.') logging.error("Can not find any car plate.")
return None return None
# 稍微扩边获取最终车牌区域 # 稍微扩边获取最终车牌区域
@@ -149,7 +185,7 @@ def extract_car_plate(img: cv.typing.MatLike) -> typing.Optional[cv.typing.MatLi
y1 = max(candidate.y - pad, 0) y1 = max(candidate.y - pad, 0)
x2 = min(candidate.x + candidate.w + pad, w_img) x2 = min(candidate.x + candidate.w + pad, w_img)
y2 = min(candidate.y + candidate.h + pad, h_img) y2 = min(candidate.y + candidate.h + pad, h_img)
logging.info(f'车牌区域: x={x1}, y={y1}, w={x2 - x1}, h={y2 - y1}') logging.info(f"车牌区域: x={x1}, y={y1}, w={x2 - x1}, h={y2 - y1}")
# # 在原图上标记(仅供调试) # # 在原图上标记(仅供调试)
# debug = img.copy() # debug = img.copy()
@@ -165,17 +201,20 @@ def extract_car_plate(img: cv.typing.MatLike) -> typing.Optional[cv.typing.MatLi
# Otsu 自动阈值,得到白字黑底,再取反 → 黑字白底 # Otsu 自动阈值,得到白字黑底,再取反 → 黑字白底
_, binary_otsu = cv.threshold(blurred, 0, 255, cv.THRESH_BINARY + cv.THRESH_OTSU) _, binary_otsu = cv.threshold(blurred, 0, 255, cv.THRESH_BINARY + cv.THRESH_OTSU)
# 反转:字符变黑,背景变白 # 反转:字符变黑,背景变白
if candidate.need_revert:
binary = cv.bitwise_not(binary_otsu) binary = cv.bitwise_not(binary_otsu)
else:
binary = binary_otsu
# 去除小噪点(开运算) # 去除小噪点(开运算)
kernel_denoise = cv.getStructuringElement(cv.MORPH_RECT, (2, 2)) kernel_denoise = cv.getStructuringElement(cv.MORPH_RECT, (2, 2))
binary = cv.morphologyEx(binary, cv.MORPH_OPEN, kernel_denoise) binary = cv.morphologyEx(binary, cv.MORPH_OPEN, kernel_denoise)
#return binary return binary
# cv.imwrite('./plate_binary.png', binary) # cv.imwrite('./plate_binary.png', binary)
# print("二值化结果已保存: plate_binary.png") # print("二值化结果已保存: plate_binary.png")
# ── 4. 叠加边框轮廓(细化文字边缘,参考效果图)───────────────────── # 叠加边框轮廓(细化文字边缘,参考效果图)
# Canny 边缘叠加让效果更接近参考图 # Canny 边缘叠加让效果更接近参考图
edges = cv.Canny(blurred, 40, 120) edges = cv.Canny(blurred, 40, 120)
edges_inv = cv.bitwise_not(edges) # 边缘→黑色 edges_inv = cv.bitwise_not(edges) # 边缘→黑色