Compare commits
3 Commits
75f1d58161
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
| 2a57820d4b | |||
| 1a43580add | |||
| eb43b3df31 |
10
mnist/.gitignore
vendored
Normal file
10
mnist/.gitignore
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
# Python-generated files
|
||||
__pycache__/
|
||||
*.py[oc]
|
||||
build/
|
||||
dist/
|
||||
wheels/
|
||||
*.egg-info
|
||||
|
||||
# Virtual environments
|
||||
.venv
|
||||
1
mnist/.python-version
Normal file
1
mnist/.python-version
Normal file
@@ -0,0 +1 @@
|
||||
3.11
|
||||
0
mnist/README.md
Normal file
0
mnist/README.md
Normal file
27
mnist/example.py
Normal file
27
mnist/example.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from datasets import load_dataset
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def main():
|
||||
dataset = load_dataset('parquet', data_files={
|
||||
'train': r"D:\AiData\Dataset\mnist\mnist\train-00000-of-00001.parquet",
|
||||
'test': r"D:\AiData\Dataset\mnist\mnist\test-00000-of-00001.parquet",
|
||||
})
|
||||
train_dataset = dataset['train']
|
||||
first_sample = train_dataset[0]
|
||||
|
||||
print("Label:", first_sample['label'])
|
||||
|
||||
image = first_sample['image']
|
||||
image_array = np.array(image)
|
||||
print("Image shape:", image_array.shape)
|
||||
|
||||
plt.imshow(image_array, cmap='gray')
|
||||
plt.show()
|
||||
|
||||
print("First few rows of pixel values:")
|
||||
print(image_array[:5, :5])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
196
mnist/main.py
Normal file
196
mnist/main.py
Normal file
@@ -0,0 +1,196 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
from datasets import load_dataset
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import time
|
||||
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
print(f"Using device: {device}")
|
||||
|
||||
class SimpleMLP(nn.Module):
|
||||
def __init__(self, input_size=784, hidden_size=128, num_classes=10):
|
||||
super(SimpleMLP, self).__init__()
|
||||
self.fc1 = nn.Linear(input_size, hidden_size)
|
||||
self.relu = nn.ReLU()
|
||||
self.fc2 = nn.Linear(hidden_size, num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.fc1(x)
|
||||
x = self.relu(x)
|
||||
x = self.fc2(x)
|
||||
return x
|
||||
|
||||
def preprocess_data(dataset):
|
||||
"""将数据集转换为PyTorch张量格式"""
|
||||
def transform_sample(example):
|
||||
# 转换图像:归一化并转为 float32
|
||||
image = np.array(example['image']).astype(np.float32) / 255.0
|
||||
# 注意:这里返回 numpy array,稍后统一转为 tensor
|
||||
return {
|
||||
'image': image, # 保持为 numpy array
|
||||
'label': example['label']
|
||||
}
|
||||
|
||||
# 先应用转换
|
||||
dataset = dataset.map(transform_sample, remove_columns=dataset.column_names)
|
||||
|
||||
# 关键:设置格式为 "torch",并指定列类型
|
||||
dataset = dataset.with_format(
|
||||
"torch",
|
||||
columns=["image", "label"],
|
||||
output_all_columns=False
|
||||
)
|
||||
return dataset
|
||||
|
||||
# 不再需要自定义 collate_fn!
|
||||
# 因为 with_format("torch") 后,DataLoader 会自动处理张量批处理
|
||||
|
||||
def train_model(model, train_loader, test_loader, num_epochs=10, learning_rate=0.001):
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
|
||||
|
||||
model.to(device)
|
||||
model.train()
|
||||
|
||||
train_losses = []
|
||||
train_accuracies = []
|
||||
test_accuracies = []
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
epoch_start_time = time.time()
|
||||
running_loss = 0.0
|
||||
correct_train = 0
|
||||
total_train = 0
|
||||
|
||||
for batch_idx, batch in enumerate(train_loader):
|
||||
# batch 是一个字典:{'image': tensor, 'label': tensor}
|
||||
images = batch['image'].to(device)
|
||||
labels = batch['label'].to(device)
|
||||
|
||||
# 确保图像是正确的形状 (B, 1, 28, 28)
|
||||
if images.dim() == 3:
|
||||
images = images.unsqueeze(1) # 添加通道维度
|
||||
|
||||
outputs = model(images)
|
||||
loss = criterion(outputs, labels)
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
running_loss += loss.item()
|
||||
_, predicted = torch.max(outputs.data, 1)
|
||||
total_train += labels.size(0)
|
||||
correct_train += (predicted == labels).sum().item()
|
||||
|
||||
if batch_idx % 100 == 0:
|
||||
print(f'Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx}/{len(train_loader)}], '
|
||||
f'Loss: {loss.item():.4f}')
|
||||
|
||||
train_accuracy = 100 * correct_train / total_train
|
||||
avg_loss = running_loss / len(train_loader)
|
||||
test_accuracy = evaluate_model(model, test_loader)
|
||||
epoch_time = time.time() - epoch_start_time
|
||||
|
||||
train_losses.append(avg_loss)
|
||||
train_accuracies.append(train_accuracy)
|
||||
test_accuracies.append(test_accuracy)
|
||||
|
||||
print(f'Epoch [{epoch+1}/{num_epochs}] completed in {epoch_time:.2f}s')
|
||||
print(f'Train Loss: {avg_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, '
|
||||
f'Test Accuracy: {test_accuracy:.2f}%')
|
||||
print('-' * 60)
|
||||
|
||||
return train_losses, train_accuracies, test_accuracies
|
||||
|
||||
def evaluate_model(model, test_loader):
|
||||
model.eval()
|
||||
correct = 0
|
||||
total = 0
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in test_loader:
|
||||
images = batch['image'].to(device)
|
||||
labels = batch['label'].to(device)
|
||||
|
||||
if images.dim() == 3:
|
||||
images = images.unsqueeze(1)
|
||||
|
||||
outputs = model(images)
|
||||
_, predicted = torch.max(outputs.data, 1)
|
||||
total += labels.size(0)
|
||||
correct += (predicted == labels).sum().item()
|
||||
|
||||
model.train()
|
||||
return 100 * correct / total
|
||||
|
||||
def main():
|
||||
dataset = load_dataset('parquet', data_files={
|
||||
'train': r"D:\AiData\Dataset\mnist\mnist\train-00000-of-00001.parquet",
|
||||
'test': r"D:\AiData\Dataset\mnist\mnist\test-00000-of-00001.parquet",
|
||||
})
|
||||
|
||||
print("Preprocessing training data...")
|
||||
train_dataset = preprocess_data(dataset['train'])
|
||||
print("Preprocessing test data...")
|
||||
test_dataset = preprocess_data(dataset['test'])
|
||||
|
||||
batch_size = 64
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
num_workers=0
|
||||
)
|
||||
test_loader = DataLoader(
|
||||
test_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
num_workers=0
|
||||
)
|
||||
|
||||
model = SimpleMLP()
|
||||
print(f"Model: {model}")
|
||||
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
|
||||
|
||||
print("\nStarting training...")
|
||||
train_losses, train_accuracies, test_accuracies = train_model(
|
||||
model, train_loader, test_loader, num_epochs=10, learning_rate=0.001
|
||||
)
|
||||
|
||||
# 绘制训练曲线
|
||||
plt.figure(figsize=(15, 5))
|
||||
plt.subplot(1, 3, 1)
|
||||
plt.plot(train_losses)
|
||||
plt.title('Training Loss')
|
||||
plt.xlabel('Epoch')
|
||||
plt.ylabel('Loss')
|
||||
|
||||
plt.subplot(1, 3, 2)
|
||||
plt.plot(train_accuracies, label='Train')
|
||||
plt.plot(test_accuracies, label='Test')
|
||||
plt.title('Accuracy')
|
||||
plt.xlabel('Epoch')
|
||||
plt.ylabel('Accuracy (%)')
|
||||
plt.legend()
|
||||
|
||||
plt.subplot(1, 3, 3)
|
||||
plt.plot(train_accuracies, 'b-', label='Train Accuracy')
|
||||
plt.plot(test_accuracies, 'r--', label='Test Accuracy')
|
||||
plt.title('Train vs Test Accuracy')
|
||||
plt.xlabel('Epoch')
|
||||
plt.ylabel('Accuracy (%)')
|
||||
plt.legend()
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
final_test_acc = evaluate_model(model, test_loader)
|
||||
print(f"\nFinal Test Accuracy: {final_test_acc:.2f}%")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
11
mnist/pyproject.toml
Normal file
11
mnist/pyproject.toml
Normal file
@@ -0,0 +1,11 @@
|
||||
[project]
|
||||
name = "mnist"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.11"
|
||||
dependencies = [
|
||||
"datasets>=4.3.0",
|
||||
"matplotlib>=3.10.7",
|
||||
"numpy>=2.3.4",
|
||||
]
|
||||
1731
mnist/uv.lock
generated
Normal file
1731
mnist/uv.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
1
mv-and-ip/.gitignore
vendored
1
mv-and-ip/.gitignore
vendored
@@ -4,6 +4,7 @@
|
||||
|
||||
# All image files
|
||||
*.jpg
|
||||
*.jpeg
|
||||
*.png
|
||||
*.webp
|
||||
|
||||
|
||||
@@ -47,24 +47,46 @@ def _uniform_car_plate(img: cv.typing.MatLike) -> cv.typing.MatLike:
|
||||
|
||||
@dataclass
|
||||
class CarPlateHsvBoundary:
|
||||
"""HSV boundary for car plate color detection."""
|
||||
|
||||
lower_bound: cv.typing.MatLike
|
||||
"""Lower bound of HSV range for car plate color detection."""
|
||||
upper_bound: cv.typing.MatLike
|
||||
"""Upper bound of HSV range for car plate color detection."""
|
||||
need_revert: bool
|
||||
"""是否取反黑白颜色,因为蓝牌和黄牌的操作正好是反的"""
|
||||
|
||||
|
||||
CAR_PLATE_HSV_BOUNDARIES: tuple[CarPlateHsvBoundary, ...] = (
|
||||
# 中国蓝牌 HSV 范围
|
||||
CarPlateHsvBoundary(np.array([100, 80, 60]), np.array([130, 255, 255])),
|
||||
CarPlateHsvBoundary(np.array([100, 80, 60]), np.array([130, 255, 255]), True),
|
||||
# 中国绿牌 HSV 范围
|
||||
CarPlateHsvBoundary(np.array([35, 43, 46]), np.array([99, 255, 255])),
|
||||
CarPlateHsvBoundary(np.array([35, 43, 46]), np.array([99, 255, 255]), False),
|
||||
# 中国黄牌 HSV 范围
|
||||
CarPlateHsvBoundary(np.array([32, 43, 46]), np.array([68, 255, 255])),
|
||||
CarPlateHsvBoundary(np.array([16, 43, 46]), np.array([34, 255, 255]), False),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CarPlateMask:
|
||||
"""Car plate mask result."""
|
||||
|
||||
mask: cv.typing.MatLike
|
||||
"""The masked image in U8 format."""
|
||||
need_revert: bool
|
||||
"""是否对颜色取反,与CarPlateHsvBoundary中的同名字段含义一致"""
|
||||
|
||||
|
||||
def _batchly_mask_car_plate(
|
||||
hsv: cv.typing.MatLike,
|
||||
) -> typing.Iterator[cv.typing.MatLike]:
|
||||
""" """
|
||||
) -> typing.Iterator[CarPlateMask]:
|
||||
"""
|
||||
Iterate over each car plate HSV boundary and apply mask to the given HSV image.
|
||||
|
||||
:param hsv: The HSV image to apply mask.
|
||||
:return: An iterator of CarPlateMask.
|
||||
"""
|
||||
|
||||
for boundary in CAR_PLATE_HSV_BOUNDARIES:
|
||||
# 以给定HSV范围检测符合该颜色的位置
|
||||
mask = cv.inRange(hsv, boundary.lower_bound, boundary.upper_bound)
|
||||
@@ -76,114 +98,66 @@ def _batchly_mask_car_plate(
|
||||
mask = cv.morphologyEx(mask, cv.MORPH_OPEN, kernel_open)
|
||||
|
||||
# Return value
|
||||
yield mask
|
||||
yield CarPlateMask(mask, boundary.need_revert)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CarPlateRegion:
|
||||
"""Car plate region result."""
|
||||
|
||||
x: int
|
||||
y: int
|
||||
w: int
|
||||
h: int
|
||||
need_revert: bool
|
||||
"""是否对颜色取反,与CarPlateHsvBoundary中的同名字段含义一致"""
|
||||
|
||||
|
||||
MIN_AREA: float = 3000
|
||||
"""Minimum area for car plate region."""
|
||||
MIN_ASPECT_RATIO: float = 1.5
|
||||
"""Minimum aspect ratio for car plate region."""
|
||||
MAX_ASPECT_RATIO: float = 6.0
|
||||
"""Maximum aspect ratio for car plate region."""
|
||||
BEST_ASPECT_RATIO: float = 3.5
|
||||
"""Best aspect ratio for car plate region."""
|
||||
|
||||
|
||||
def _analyse_car_plate_connection(
|
||||
mask: cv.typing.MatLike,
|
||||
masks: typing.Iterator[CarPlateMask],
|
||||
) -> typing.Optional[CarPlateRegion]:
|
||||
# 连通域分析,筛选最符合车牌长宽比的区域
|
||||
num_labels, labels, stats, _ = cv.connectedComponentsWithStats(mask, connectivity=8)
|
||||
"""
|
||||
Analyse car plate connection in given masks.
|
||||
|
||||
:param masks: An iterator of CarPlateMask to analyse.
|
||||
:return: The car plate region if succeed, otherwise None.
|
||||
"""
|
||||
|
||||
best: typing.Optional[CarPlateRegion] = None
|
||||
best_score = 0
|
||||
|
||||
for i in range(1, num_labels):
|
||||
x, y, w, h, area = stats[i]
|
||||
# 检查面积
|
||||
if area < MIN_AREA:
|
||||
continue
|
||||
# 标准车牌宽高比约 3:1 ~ 5:1
|
||||
ratio = w / (h + 1e-5)
|
||||
if ratio >= MIN_ASPECT_RATIO and ratio <= MAX_ASPECT_RATIO:
|
||||
score = area * (1 - abs(ratio - BEST_ASPECT_RATIO) / BEST_ASPECT_RATIO)
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best = CarPlateRegion(x, y, w, h)
|
||||
for mask in masks:
|
||||
# 连通域分析,筛选最符合车牌长宽比的区域
|
||||
num_labels, labels, stats, _ = cv.connectedComponentsWithStats(
|
||||
mask.mask, connectivity=8
|
||||
)
|
||||
|
||||
for i in range(1, num_labels):
|
||||
x, y, w, h, area = stats[i]
|
||||
# 检查面积
|
||||
if area < MIN_AREA:
|
||||
continue
|
||||
# 标准车牌宽高比约 3:1 ~ 5:1
|
||||
ratio = w / (h + 1e-5)
|
||||
if ratio >= MIN_ASPECT_RATIO and ratio <= MAX_ASPECT_RATIO:
|
||||
score = area * (1 - abs(ratio - BEST_ASPECT_RATIO) / BEST_ASPECT_RATIO)
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best = CarPlateRegion(x, y, w, h, mask.need_revert)
|
||||
|
||||
return best
|
||||
|
||||
|
||||
@dataclass
|
||||
class PerspectiveData:
|
||||
top_left: tuple[int, int]
|
||||
top_right: tuple[int, int]
|
||||
bottom_left: tuple[int, int]
|
||||
bottom_right: tuple[int, int]
|
||||
|
||||
new_width: int
|
||||
new_height: int
|
||||
|
||||
|
||||
def _extract_perspective_data(
|
||||
gray: cv.typing.MatLike,
|
||||
) -> typing.Optional[PerspectiveData]:
|
||||
""" """
|
||||
# Histogram balance to increase contrast
|
||||
hist_gray = cv.equalizeHist(gray)
|
||||
|
||||
# Apply Gaussian blur to reduce noise
|
||||
blurred = cv.GaussianBlur(hist_gray, (5, 5), 0)
|
||||
|
||||
# Edge detection using Canny
|
||||
edges = cv.Canny(blurred, 50, 150)
|
||||
|
||||
# Find contours
|
||||
contours, _ = cv.findContours(edges, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)
|
||||
if not contours:
|
||||
return None
|
||||
# Find the largest one because all image is car plate
|
||||
max_area_contour = max(contours, key=lambda contour: cv.contourArea(contour))
|
||||
|
||||
# Approximate the contour
|
||||
peri = cv.arcLength(max_area_contour, True)
|
||||
approx = cv.approxPolyDP(max_area_contour, 0.02 * peri, True)
|
||||
if len(approx) != 4:
|
||||
return None
|
||||
|
||||
# Perspective transformation to get front view
|
||||
# Order points: top-left, top-right, bottom-right, bottom-left
|
||||
pts = approx.reshape(4, 2)
|
||||
rect = np.zeros((4, 2), dtype="float32")
|
||||
|
||||
# Sum and difference of coordinates to find corners
|
||||
s = pts.sum(axis=1)
|
||||
top_left = pts[np.argmin(s)] # Top-left has smallest sum
|
||||
bottom_right = pts[np.argmax(s)] # Bottom-right has largest sum
|
||||
|
||||
diff = np.diff(pts, axis=1)
|
||||
top_right = pts[np.argmin(diff)] # Top-right has smallest difference
|
||||
bottom_left = pts[np.argmax(diff)] # Bottom-left has largest difference
|
||||
|
||||
# Calculate width and height of new image
|
||||
width_a = np.linalg.norm(rect[0] - rect[1])
|
||||
width_b = np.linalg.norm(rect[2] - rect[3])
|
||||
max_width = max(int(width_a), int(width_b))
|
||||
|
||||
height_a = np.linalg.norm(rect[0] - rect[3])
|
||||
height_b = np.linalg.norm(rect[1] - rect[2])
|
||||
max_height = max(int(height_a), int(height_b))
|
||||
|
||||
# Return value
|
||||
return PerspectiveData(
|
||||
top_left, top_right, bottom_left, bottom_right, max_width, max_height
|
||||
)
|
||||
|
||||
|
||||
def extract_car_plate(img: cv.typing.MatLike) -> typing.Optional[cv.typing.MatLike]:
|
||||
"""
|
||||
Extract the car plate part from given image.
|
||||
@@ -191,20 +165,15 @@ def extract_car_plate(img: cv.typing.MatLike) -> typing.Optional[cv.typing.MatLi
|
||||
:param img: The image containing car plate in BGR format.
|
||||
:return: The image of binary car plate in U8 format if succeed, otherwise None.
|
||||
"""
|
||||
# 统一图片大小
|
||||
img = _uniform_car_plate(img)
|
||||
|
||||
# 转换到HSV空间
|
||||
hsv = cv.cvtColor(img, cv.COLOR_BGR2HSV)
|
||||
|
||||
# 利用车牌颜色在 HSV 空间定位车牌
|
||||
candidate: typing.Optional[CarPlateRegion] = None
|
||||
for mask in _batchly_mask_car_plate(hsv):
|
||||
# 连通域分析,筛选最符合车牌长宽比的区域作为车牌
|
||||
candidate = _analyse_car_plate_connection(mask)
|
||||
# 找到任意一个就退出
|
||||
if candidate is not None:
|
||||
break
|
||||
|
||||
masks = _batchly_mask_car_plate(hsv)
|
||||
candidate = _analyse_car_plate_connection(masks)
|
||||
if candidate is None:
|
||||
logging.error("Can not find any car plate.")
|
||||
return None
|
||||
@@ -232,39 +201,20 @@ def extract_car_plate(img: cv.typing.MatLike) -> typing.Optional[cv.typing.MatLi
|
||||
# Otsu 自动阈值,得到白字黑底,再取反 → 黑字白底
|
||||
_, binary_otsu = cv.threshold(blurred, 0, 255, cv.THRESH_BINARY + cv.THRESH_OTSU)
|
||||
# 反转:字符变黑,背景变白
|
||||
binary = cv.bitwise_not(binary_otsu)
|
||||
if candidate.need_revert:
|
||||
binary = cv.bitwise_not(binary_otsu)
|
||||
else:
|
||||
binary = binary_otsu
|
||||
|
||||
# 去除小噪点(开运算)
|
||||
kernel_denoise = cv.getStructuringElement(cv.MORPH_RECT, (2, 2))
|
||||
binary = cv.morphologyEx(binary, cv.MORPH_OPEN, kernel_denoise)
|
||||
|
||||
# 尝试获取视角矫正数据
|
||||
perspective_data = _extract_perspective_data(gray)
|
||||
if perspective_data is None:
|
||||
logging.warning(f'Can not fetch perspective data. The output image has no perspective correction.')
|
||||
return binary
|
||||
|
||||
# 执行视角矫正
|
||||
perspective_src = np.array([
|
||||
list(perspective_data.top_left),
|
||||
list(perspective_data.top_right),
|
||||
list(perspective_data.bottom_right),
|
||||
list(perspective_data.bottom_left)
|
||||
], dtype="float32")
|
||||
perspective_dst = np.array([
|
||||
[0, 0],
|
||||
[perspective_data.new_width - 1, 0],
|
||||
[perspective_data.new_width - 1, perspective_data.new_height - 1],
|
||||
[0, perspective_data.new_height - 1]
|
||||
], dtype="float32")
|
||||
M = cv.getPerspectiveTransform(perspective_src, perspective_dst)
|
||||
warped = cv.warpPerspective(binary, M, (perspective_data.new_width, perspective_data.new_height))
|
||||
|
||||
return warped
|
||||
return binary
|
||||
# cv.imwrite('./plate_binary.png', binary)
|
||||
# print("二值化结果已保存: plate_binary.png")
|
||||
|
||||
# ── 4. 叠加边框轮廓(细化文字边缘,参考效果图)─────────────────────
|
||||
# 叠加边框轮廓(细化文字边缘,参考效果图)
|
||||
# Canny 边缘叠加让效果更接近参考图
|
||||
edges = cv.Canny(blurred, 40, 120)
|
||||
edges_inv = cv.bitwise_not(edges) # 边缘→黑色
|
||||
|
||||
Reference in New Issue
Block a user