1
0

update for change of exp2 and add exp3

This commit is contained in:
2025-11-30 16:24:32 +08:00
parent af890d899e
commit 48fcdfcc80
17 changed files with 859 additions and 124 deletions

114
exp2/modified/mnist.py Normal file
View File

@@ -0,0 +1,114 @@
from pathlib import Path
import numpy
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import v2 as tvtrans
from torchvision import datasets
import torch.nn.functional as F
class CNN(torch.nn.Module):
"""卷积神经网络模型"""
def __init__(self):
super(CNN, self).__init__()
self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=(3, 3))
self.pool1 = torch.nn.MaxPool2d(kernel_size=(2, 2))
self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=(3, 3))
self.pool2 = torch.nn.MaxPool2d(kernel_size=(2, 2))
self.conv3 = torch.nn.Conv2d(64, 64, kernel_size=(3, 3))
self.flatten = torch.nn.Flatten()
# 28x28过第一轮卷积后变为26x26过第一轮池化后变为13x13
# 过第二轮卷积后变为11x11过第二轮池化后变为5x5
# 过第三轮卷积后变为3x3。
# 最后一轮卷积核个数为64。
self.fc1 = torch.nn.Linear(64 * 3 * 3, 64)
torch.nn.init.xavier_normal_(self.fc1.weight)
torch.nn.init.zeros_(self.fc1.bias)
self.fc2 = torch.nn.Linear(64, 10)
torch.nn.init.xavier_normal_(self.fc2.weight)
torch.nn.init.zeros_(self.fc2.bias)
def forward(self, x):
x = F.relu(self.conv1(x))
x = self.pool1(x)
x = F.relu(self.conv2(x))
x = self.pool2(x)
x = F.relu(self.conv3(x))
x = self.flatten(x)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.softmax(x, dim=1)
class MnistDataset(Dataset):
"""用于加载Mnist的自定义数据集"""
shape: int
transform: tvtrans.Transform
images_data: numpy.ndarray
labels_data: torch.Tensor
def __init__(self, images: numpy.ndarray, labels: numpy.ndarray, transform: tvtrans.Transform):
images_len: int = images.shape[0]
labels_len: int = labels.shape[0]
assert (images_len == labels_len)
self.shape = images_len
self.images_data = images
self.labels_data = torch.from_numpy(labels)
self.transform = transform
def __getitem__(self, index):
return self.transform(self.images_data[index]), self.labels_data[index]
def __len__(self):
return self.shape
class MnistDataSource:
"""用于读取MNIST数据的数据读取器"""
train_loader: DataLoader
test_loader: DataLoader
def __init__(self, batch_size: int):
dataset_path = Path(__file__).resolve().parent.parent / 'datasets' / 'mnist.npz'
dataset = numpy.load(dataset_path)
# 所有图片均为黑底白字
# 6万张训练图片60000x28x28。标签只有第一维。
train_images: numpy.ndarray = dataset['x_train']
train_labels: numpy.ndarray = dataset['y_train']
# 1万张测试图片10000x28x28。标签只有第一维。
test_images: numpy.ndarray = dataset['x_test']
test_labels: numpy.ndarray = dataset['y_test']
# 定义数据转换器
trans = tvtrans.Compose([
# 从uint8转换为float32并自动归一化到0-1区间
# tvtrans.ToTensor(),
tvtrans.ToImage(),
tvtrans.ToDtype(torch.float32, scale=True),
# 为了符合后面图像的输入颜色通道条件,要在最后挤出一个新的维度
#tvtrans.Lambda(lambda x: x.unsqueeze(-1))
# 这个特定的标准化参数 (0.1307, 0.3081) 是 MNIST 数据集的标准化参数这些数值是MNIST训练集的全局均值和标准差。
# 这种标准化有助于模型训练时的数值稳定性和收敛速度。
#tvtrans.Normalize((0.1307,), (0.3081,)),
])
# 创建数据集
train_dataset = MnistDataset(train_images, train_labels, transform=trans)
test_dataset = MnistDataset(test_images, test_labels, transform=trans)
# 赋值到自身
self.train_loader = DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=False)
self.test_loader = DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle=False)

View File

@@ -1,7 +1,10 @@
from pathlib import Path from pathlib import Path
import sys import sys
import torch import torch
from train import CNN import numpy
from PIL import Image, ImageFile
import matplotlib.pyplot as plt
from mnist import CNN
sys.path.append(str(Path(__file__).resolve().parent.parent.parent)) sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
import gpu_utils import gpu_utils
@@ -36,7 +39,7 @@ class Predictor:
file_path = Path(__file__).resolve().parent.parent / 'models' / 'cnn.pth' file_path = Path(__file__).resolve().parent.parent / 'models' / 'cnn.pth'
self.cnn.load_state_dict(torch.load(file_path)) self.cnn.load_state_dict(torch.load(file_path))
def predict(self, image: list[list[bool]]) -> PredictResult: def predict_sketchpad(self, image: list[list[bool]]) -> PredictResult:
input = torch.Tensor(image).float().to(self.device) input = torch.Tensor(image).float().to(self.device)
assert(input.dim() == 2) assert(input.dim() == 2)
assert(input.size(0) == 28) assert(input.size(0) == 28)
@@ -51,4 +54,42 @@ class Predictor:
with torch.no_grad(): with torch.no_grad():
output = self.cnn(input) output = self.cnn(input)
return PredictResult(output) return PredictResult(output)
def predict_image(self, image: ImageFile.ImageFile) -> PredictResult:
# 确保图像为灰度图像然后转换为numpy数组。
# 注意这里的numpy数组是只读的所以要先拷贝一份
grayscale_image = image.convert('L')
numpy_data = numpy.reshape(grayscale_image, (28, 28), copy=True)
# 转换到Tensor设置dtype并传到GPU上
data = torch.from_numpy(numpy_data).float().to(self.device)
# 归一化到255又因为图像输入是白底黑字需要做转换。
data.div_(255.0).sub_(1).mul_(-1)
# 同理,挤出维度并预测
input = data.unsqueeze(0).unsqueeze(0)
with torch.no_grad():
output = self.cnn(input)
return PredictResult(output)
def main():
predictor = Predictor()
# 遍历测试目录中的所有图片,并处理。
test_dir = Path(__file__).resolve().parent.parent / 'test_images'
for image_path in test_dir.glob('*.png'):
if image_path.is_file():
print(f'Predicting {image_path} ...')
image = Image.open(image_path)
rv = predictor.predict_image(image)
print(f'Predict digit: {rv.chosen_number()}')
plt.figure(f'Image - {image_path}')
plt.imshow(image)
plt.axis('on')
plt.title(f'Predict digit: {rv.chosen_number()}')
plt.show()
if __name__ == "__main__":
main()

View File

@@ -169,7 +169,7 @@ class SketchpadApp:
def execute(self): def execute(self):
"""执行按钮功能 - 将画板数据传递给后端""" """执行按钮功能 - 将画板数据传递给后端"""
prediction = self.predictor.predict(self.canvas_data) prediction = self.predictor.predict_sketchpad(self.canvas_data)
self.show_in_table(prediction) self.show_in_table(prediction)
def reset(self): def reset(self):

View File

@@ -1,145 +1,57 @@
from pathlib import Path from pathlib import Path
import sys import sys
import typing import typing
import numpy
import torch import torch
from torch.utils.data import DataLoader, Dataset import torchinfo
from torchvision.transforms import v2 as tvtrans import ignite.engine
import matplotlib.pyplot as plt import ignite.metrics
import torch.nn.functional as F from ignite.engine import Engine, Events
from ignite.handlers.tqdm_logger import ProgressBar
from mnist import CNN, MnistDataSource
sys.path.append(str(Path(__file__).resolve().parent.parent.parent)) sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
import gpu_utils import gpu_utils
class CNN(torch.nn.Module):
"""卷积神经网络模型"""
def __init__(self):
super(CNN, self).__init__()
self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=(3, 3))
self.pool1 = torch.nn.MaxPool2d(kernel_size=(2, 2))
self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=(3, 3))
self.pool2 = torch.nn.MaxPool2d(kernel_size=(2, 2))
self.conv3 = torch.nn.Conv2d(64, 64, kernel_size=(3, 3))
self.flatten = torch.nn.Flatten()
# 28x28过第一轮卷积后变为26x26过第一轮池化后变为13x13
# 过第二轮卷积后变为11x11过第二轮池化后变为5x5
# 过第三轮卷积后变为3x3。
# 最后一轮卷积核个数为64。
self.fc1 = torch.nn.Linear(64 * 3 * 3, 64)
self.fc2 = torch.nn.Linear(64, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = self.pool1(x)
x = F.relu(self.conv2(x))
x = self.pool2(x)
x = F.relu(self.conv3(x))
x = self.flatten(x)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.softmax(x, dim=1)
class MnistDataset(Dataset):
"""用于加载Mnist的自定义数据集"""
shape: int
transform: tvtrans.Transform
images_data: numpy.ndarray
labels_data: torch.Tensor
def __init__(self, images: numpy.ndarray, labels: numpy.ndarray, transform: tvtrans.Transform):
images_len: int = images.size(0)
labels_len: int = labels.size(0)
assert (images_len == labels_len)
self.shape = images_len
self.images_data = images
self.labels_data = torch.from_numpy(labels)
self.transform = transform
def __getitem__(self, index):
return self.transform(self.images_data[index]), self.labels_data[index]
def __len__(self):
return self.shape
class DataSource:
"""用于读取MNIST数据的数据读取器"""
train_data: DataLoader
test_data: DataLoader
def __init__(self, batch_size: int):
datasets_path = Path(__file__).resolve().parent.parent / 'datasets' / 'mnist.npz'
datasets = numpy.load(datasets_path)
# 所有图片均为黑底白字
# 6万张训练图片60000x28x28。标签只有第一维。
train_images = datasets['x_train']
train_labels = datasets['y_train']
# 1万张测试图片10000x28x28。标签只有第一维。
test_images = datasets['x_test']
test_labels = datasets['y_test']
# 定义数据转换器
trans = tvtrans.Compose([
# 从uint8转换为float32并自动归一化到0-1区间
# tvtrans.ToTensor(),
tvtrans.ToImage(),
tvtrans.ToDtype(torch.float32, scale=True),
# 为了符合后面图像的输入颜色通道条件,要在最后挤出一个新的维度
#tvtrans.Lambda(lambda x: x.unsqueeze(-1))
])
# 创建数据集
train_dataset = MnistDataset(train_images,
train_labels,
transform=trans)
test_dataset = MnistDataset(test_images, test_labels, transform=trans)
# 赋值到自身
self.train_data = DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=False)
self.test_data = DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle=False)
class Trainer: class Trainer:
N_EPOCH: typing.ClassVar[int] = 5 N_EPOCH: typing.ClassVar[int] = 5
N_BATCH_SIZE: typing.ClassVar[int] = 1000 N_BATCH_SIZE: typing.ClassVar[int] = 1000
device: torch.device device: torch.device
data_source: DataSource data_source: MnistDataSource
cnn: CNN model: CNN
def __init__(self): def __init__(self):
self.device = gpu_utils.get_gpu_device() self.device = gpu_utils.get_gpu_device()
self.data_source = DataSource(Trainer.N_BATCH_SIZE) self.data_source = MnistDataSource(Trainer.N_BATCH_SIZE)
self.cnn = CNN().to(self.device) self.model = CNN().to(self.device)
# 展示模型结构。批次为指定批次数量通道只有一个灰度通道大小28x28。
torchinfo.summary(self.model, (Trainer.N_BATCH_SIZE, 1, 28, 28))
def train(self): def train(self):
optimizer = torch.optim.Adam(self.cnn.parameters()) optimizer = torch.optim.Adam(self.model.parameters(), eps=1e-7)
# optimizer = torch.optim.AdamW(
# self.model.parameters(),
# lr=0.001, # 两者默认学习率都是 0.001
# betas=(0.9, 0.999), # 两者默认值相同
# eps=1e-07, # 【关键】匹配 TensorFlow 的默认 epsilon
# weight_decay=0.0, # 两者默认都是 0
# amsgrad=False # 两者默认都是 False
# )
loss_func = torch.nn.CrossEntropyLoss() loss_func = torch.nn.CrossEntropyLoss()
for epoch in range(Trainer.N_EPOCH): for epoch in range(Trainer.N_EPOCH):
self.cnn.train() self.model.train()
batch_images: torch.Tensor batch_images: torch.Tensor
batch_labels: torch.Tensor batch_labels: torch.Tensor
for batch_index, (batch_images, batch_labels) in enumerate(self.data_source.train_data): for batch_index, (batch_images, batch_labels) in enumerate(self.data_source.train_loader):
gpu_images = batch_images.to(self.device) gpu_images = batch_images.to(self.device)
gpu_labels = batch_labels.to(self.device) gpu_labels = batch_labels.to(self.device)
optimizer.zero_grad() prediction: torch.Tensor = self.model(gpu_images)
prediction: torch.Tensor = self.cnn(gpu_images)
loss: torch.Tensor = loss_func(prediction, gpu_labels) loss: torch.Tensor = loss_func(prediction, gpu_labels)
optimizer.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
@@ -151,20 +63,20 @@ class Trainer:
file_dir_path = Path(__file__).resolve().parent.parent / 'models' file_dir_path = Path(__file__).resolve().parent.parent / 'models'
file_dir_path.mkdir(parents=True, exist_ok=True) file_dir_path.mkdir(parents=True, exist_ok=True)
file_path = file_dir_path / 'cnn.pth' file_path = file_dir_path / 'cnn.pth'
torch.save(self.cnn.state_dict(), file_path) torch.save(self.model.state_dict(), file_path)
print(f'模型已保存至:{file_path}') print(f'模型已保存至:{file_path}')
def test(self): def test(self):
self.cnn.eval() self.model.eval()
correct_sum = 0 correct_sum = 0
total_sum = 0 total_sum = 0
with torch.no_grad(): with torch.no_grad():
for batch_images, batch_labels in self.data_source.test_data: for batch_images, batch_labels in self.data_source.test_loader:
gpu_images = batch_images.to(self.device) gpu_images = batch_images.to(self.device)
gpu_labels = batch_labels.to(self.device) gpu_labels = batch_labels.to(self.device)
possibilities: torch.Tensor = self.cnn(gpu_images) possibilities: torch.Tensor = self.model(gpu_images)
# 输出出来是10个数字各自的可能性所以要选取最高可能性的那个对比 # 输出出来是10个数字各自的可能性所以要选取最高可能性的那个对比
# 在dim=1上找最大的那个就选那个。dim=0是批次所以不管他。 # 在dim=1上找最大的那个就选那个。dim=0是批次所以不管他。
_, prediction = possibilities.max(1) _, prediction = possibilities.max(1)
@@ -175,13 +87,105 @@ class Trainer:
test_acc = 100. * correct_sum / total_sum test_acc = 100. * correct_sum / total_sum
print(f"准确率: {test_acc:.4f}%,共测试了{total_sum}张图片") print(f"准确率: {test_acc:.4f}%,共测试了{total_sum}张图片")
def main(): def main():
trainer = Trainer() trainer = Trainer()
trainer.train() trainer.train()
trainer.save() trainer.save()
trainer.test() trainer.test()
# N_EPOCH: int = 5
# N_BATCH_SIZE: int = 1000
# N_LOG_INTERVAL: int = 10
# class Trainer:
# device: torch.device
# data_source: MnistDataSource
# model: CNN
# trainer: Engine
# train_evaluator: Engine
# test_evaluator: Engine
# def __init__(self):
# self.device = gpu_utils.get_gpu_device()
# self.model = CNN().to(self.device)
# self.data_source = MnistDataSource(batch_size=N_BATCH_SIZE)
# # 展示模型结构。批次为指定批次数量通道只有一个灰度通道大小28x28。
# torchinfo.summary(self.model, (N_BATCH_SIZE, 1, 28, 28))
# #optimizer = torch.optim.Adam(self.model.parameters(), eps=1e-7)
# optimizer = torch.optim.AdamW(
# self.model.parameters(),
# lr=0.001, # 两者默认学习率都是 0.001
# betas=(0.9, 0.999), # 两者默认值相同
# eps=1e-07, # 【关键】匹配 TensorFlow 的默认 epsilon
# weight_decay=0.0, # 两者默认都是 0
# amsgrad=False # 两者默认都是 False
# )
# criterion = torch.nn.CrossEntropyLoss()
# self.trainer = ignite.engine.create_supervised_trainer(
# self.model, optimizer, criterion, self.device
# )
# eval_metrics = {
# "accuracy": ignite.metrics.Accuracy(device=self.device),
# "loss": ignite.metrics.Loss(criterion, device=self.device)
# }
# self.train_evaluator = ignite.engine.create_supervised_evaluator(
# self.model, metrics=eval_metrics, device=self.device)
# self.test_evaluator = ignite.engine.create_supervised_evaluator(
# self.model, metrics=eval_metrics, device=self.device)
# self.trainer.add_event_handler(
# Events.ITERATION_COMPLETED(every=N_LOG_INTERVAL),
# lambda engine: self.log_intrain_loss(engine)
# )
# self.trainer.add_event_handler(
# Events.EPOCH_COMPLETED,
# lambda trainer: self.log_train_results(trainer)
# )
# self.trainer.add_event_handler(
# Events.COMPLETED,
# lambda _: self.log_test_results()
# )
# self.trainer.add_event_handler(
# Events.COMPLETED,
# lambda _: self.save_model()
# )
# progressbar = ProgressBar()
# progressbar.attach(self.trainer)
# def run(self):
# self.trainer.run(self.data_source.train_loader, max_epochs=N_EPOCH)
# def log_intrain_loss(self, engine: Engine):
# print(f"Epoch: {engine.state.epoch}, Loss: {engine.state.output:.4f}\r", end="")
# def log_train_results(self, trainer: Engine):
# self.train_evaluator.run(self.data_source.train_loader)
# metrics = self.train_evaluator.state.metrics
# print()
# print(f"Training - Epoch: {trainer.state.epoch}, Avg Accuracy: {metrics['accuracy']:.4f}, Avg Loss: {metrics['loss']:.4f}")
# def log_test_results(self):
# self.test_evaluator.run(self.data_source.test_loader)
# metrics = self.test_evaluator.state.metrics
# print(f"Test - Avg Accuracy: {metrics['accuracy']:.4f} Avg Loss: {metrics['loss']:.4f}")
# def save_model(self):
# file_dir_path = Path(__file__).resolve().parent.parent / 'models'
# file_dir_path.mkdir(parents=True, exist_ok=True)
# file_path = file_dir_path / 'cnn.pth'
# torch.save(self.model.state_dict(), file_path)
# print(f'Model was saved into: {file_path}')
# def main():
# trainer = Trainer()
# trainer.run()
if __name__ == "__main__": if __name__ == "__main__":
gpu_utils.print_gpu_availability() gpu_utils.print_gpu_availability()

2
exp2/test_images/.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
# Ignore all test images
*.png

2
exp3/datasets/.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
# Ignore datasets and processed datasets
*.txt

2
exp3/models/.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
# Ignore every saved model files
*.pth

171
exp3/modified/dataset.py Normal file
View File

@@ -0,0 +1,171 @@
from pathlib import Path
import typing
import pickle
from collections import Counter
import numpy
class Tokenizer:
"""分词器"""
token_dict: dict[str, int]
"""词->编号的映射"""
token_dict_rev: dict[int, str]
"""编号->词的映射"""
vocab_size: int
"""词汇表大小"""
def __init__(self, token_dict: dict[str, int]):
self.token_dict = token_dict
self.token_dict_rev = {value: key for key, value in self.token_dict.items()}
self.vocab_size = len(self.token_dict)
def id_to_token(self, token_id: int) -> str:
"""
给定一个编号,查找词汇表中对应的词。
:param token_id: 带查找词的编号
:return: 编号对应的词
"""
return self.token_dict_rev[token_id]
def token_to_id(self, token: str):
"""
给定一个词,查找它在词汇表中的编号。
未找到则返回低频词[UNK]的编号。
:param token: 带查找编号的词
:return: 词的编号
"""
return self.token_dict.get(token, self.token_dict['[UNK]'])
def encode(self, tokens: str) -> list[int]:
"""
给定一个字符串s在头尾分别加上标记开始和结束的特殊字符并将它转成对应的编号序列
:param tokens: 待编码字符串
:return: 编号序列
"""
# 加上开始标记
token_ids: list[int] = [self.token_to_id('[CLS]'), ]
# 加入字符串编号序列
for token in tokens:
token_ids.append(self.token_to_id(token))
# 加上结束标记
token_ids.append(self.token_to_id('[SEP]'))
return token_ids
def decode(self, token_ids: typing.Iterable[int]) -> str:
"""
给定一个编号序列,将它解码成字符串
:param token_ids: 待解码的编号序列
:return: 解码出的字符串
"""
# 起止标记字符特殊处理
spec_tokens = {'[CLS]', '[SEP]'}
# 保存解码出的字符的list
tokens: list[str] = []
for token_id in token_ids:
token = self.id_to_token(token_id)
if token in spec_tokens:
continue
tokens.append(token)
# 拼接字符串
return ''.join(tokens)
class PoetryDataset:
"""古诗词数据集加载器"""
BAD_WORDS: typing.ClassVar[list[str]] = ['', '', '(', ')', '__', '', '', '', '', '[', ']']
"""禁用词,包含如下字符的唐诗将被忽略"""
MAX_SEG_LEN: typing.ClassVar[int] = 64
"""句子最大长度"""
MIN_WORD_FREQ: typing.ClassVar[int] = 8
"""最小词频"""
tokenizer: Tokenizer
"""分词器"""
poetry: list[str]
"""古诗词数据集"""
def __init__(self, force_reclean: bool = False):
# 加载古诗,然后统计词频构建分词器
self.poetry = self.load_poetry(force_reclean)
self.tokenizer = self.build_tokenizer(self.poetry)
def load_poetry(self, force_reclean: bool = False) -> list[str]:
"""加载古诗词数据集"""
if force_reclean or (not self.get_clean_dataset_path().is_file()):
return self.load_poetry_from_dirty()
else:
return self.load_poetry_from_clean()
def load_poetry_from_clean(self) -> list[str]:
"""直接读取清洗后的数据"""
with open(self.get_clean_dataset_path(), 'rb') as f:
return pickle.load(f)
def load_poetry_from_dirty(self) -> list[str]:
"""从原始数据加载,清洗数据后,写入缓存文件,并返回清洗后的数据"""
with open(self.get_dirty_dataset_path(), 'r', encoding='utf-8') as f:
lines = f.readlines()
# 将冒号统一成相同格式
lines = [line.replace('', ':') for line in lines]
# 数据集列表
poetry: list[str] = []
# 逐行处理读取到的数据
for line in lines:
# 删除空白字符
line = line.strip()
# 有且只能有一个冒号用来分割标题
if line.count(':') != 1: continue
# 获取后半部分
_, last_part = line.split(':')
# 长度不能超过最大长度
if len(last_part) > PoetryDataset.MAX_SEG_LEN - 2:
continue
# 不能包含禁止词
for bad_word in PoetryDataset.BAD_WORDS:
if bad_word in last_part:
break
else:
# 如果循环正常结束就表明没有bad words推入数据列表
poetry.append(last_part)
# 数据清理完毕
# 写入干净数据
with open(self.get_clean_dataset_path(), 'wb') as f:
pickle.dump(poetry, f)
# 返回结果
return poetry
def get_clean_dataset_path(self) -> Path:
return Path(__file__).resolve().parent.parent / 'datasets' / 'poetry.pickle'
def get_dirty_dataset_path(self) -> Path:
return Path(__file__).resolve().parent.parent / 'datasets' / 'poetry.txt'
def build_tokenizer(self, poetry: list[str]) -> Tokenizer:
"""统计词频,并构建分词器"""
# 统计词频
counter: Counter[str] = Counter()
for line in poetry:
counter.update(line)
# 过滤掉低频词
tokens = ((token, count) for token, count in counter.items() if count >= PoetryDataset.MIN_WORD_FREQ)
# 按词频排序
tokens = sorted(tokens, key=lambda x: -x[1])
# 去掉词频,只保留词列表
tokens = list(token for token, _ in tokens)
# 将特殊词和数据集中的词拼接起来
tokens = ['[PAD]', '[UNK]', '[CLS]', '[SEP]'] + tokens
# 创建词典 token->id映射关系
token_id_dict = dict(zip(tokens, range(len(tokens))))
# 使用新词典重新建立分词器
tokenizer = Tokenizer(token_id_dict)
# 直接返回,此处无需混洗数据
return tokenizer

12
exp3/modified/settings.py Normal file
View File

@@ -0,0 +1,12 @@
from pathlib import Path
BATCH_SIZE: int = 16
"""训练的batch size"""
def get_saved_model_path() -> Path:
"""
获取训练完毕的模型进行保存的路径。
:return: 模型参数保存的路径。
"""
return Path(__file__).resolve().parent.parent / 'models' / 'rnn.pth'

0
exp3/modified/utils.py Normal file
View File

199
exp3/source/dataset.py Normal file
View File

@@ -0,0 +1,199 @@
#ANLI College of Artificial Intelligence
from collections import Counter
import math
import numpy as np
import tensorflow as tf
import settings
class Tokenizer:
"""
分词器
"""
def __init__(self, token_dict):
# 词->编号的映射
self.token_dict = token_dict
# 编号->词的映射
self.token_dict_rev = {value: key for key, value in self.token_dict.items()}
# 词汇表大小
self.vocab_size = len(self.token_dict)
def id_to_token(self, token_id):
"""
给定一个编号,查找词汇表中对应的词
:param token_id: 带查找词的编号
:return: 编号对应的词
"""
return self.token_dict_rev[token_id]
def token_to_id(self, token):
"""
给定一个词,查找它在词汇表中的编号
未找到则返回低频词[UNK]的编号
:param token: 带查找编号的词
:return: 词的编号
"""
return self.token_dict.get(token, self.token_dict['[UNK]'])
def encode(self, tokens):
"""
给定一个字符串s在头尾分别加上标记开始和结束的特殊字符并将它转成对应的编号序列
:param tokens: 待编码字符串
:return: 编号序列
"""
# 加上开始标记
token_ids = [self.token_to_id('[CLS]'), ]
# 加入字符串编号序列
for token in tokens:
token_ids.append(self.token_to_id(token))
# 加上结束标记
token_ids.append(self.token_to_id('[SEP]'))
return token_ids
def decode(self, token_ids):
"""
给定一个编号序列,将它解码成字符串
:param token_ids: 待解码的编号序列
:return: 解码出的字符串
"""
# 起止标记字符特殊处理
spec_tokens = {'[CLS]', '[SEP]'}
# 保存解码出的字符的list
tokens = []
for token_id in token_ids:
token = self.id_to_token(token_id)
if token in spec_tokens:
continue
tokens.append(token)
# 拼接字符串
return ''.join(tokens)
# 禁用词
disallowed_words = settings.DISALLOWED_WORDS
# 句子最大长度
max_len = settings.MAX_LEN
# 最小词频
min_word_frequency = settings.MIN_WORD_FREQUENCY
# mini batch 大小
batch_size = settings.BATCH_SIZE
# 加载数据集
with open(settings.DATASET_PATH, 'r', encoding='utf-8') as f:
lines = f.readlines()
# 将冒号统一成相同格式
lines = [line.replace('', ':') for line in lines]
# 数据集列表
poetry = []
# 逐行处理读取到的数据
for line in lines:
# 有且只能有一个冒号用来分割标题
if line.count(':') != 1:
continue
# 后半部分不能包含禁止词
__, last_part = line.split(':')
ignore_flag = False
for dis_word in disallowed_words:
if dis_word in last_part:
ignore_flag = True
break
if ignore_flag:
continue
# 长度不能超过最大长度
if len(last_part) > max_len - 2:
continue
poetry.append(last_part.replace('\n', ''))
# 统计词频
counter = Counter()
for line in poetry:
counter.update(line)
# 过滤掉低频词
_tokens = [(token, count) for token, count in counter.items() if count >= min_word_frequency]
# 按词频排序
_tokens = sorted(_tokens, key=lambda x: -x[1])
# 去掉词频,只保留词列表
_tokens = [token for token, count in _tokens]
# 将特殊词和数据集中的词拼接起来
_tokens = ['[PAD]', '[UNK]', '[CLS]', '[SEP]'] + _tokens
# 创建词典 token->id映射关系
token_id_dict = dict(zip(_tokens, range(len(_tokens))))
# 使用新词典重新建立分词器
tokenizer = Tokenizer(token_id_dict)
# 混洗数据
np.random.shuffle(poetry)
class PoetryDataGenerator:
"""
古诗数据集生成器
"""
def __init__(self, data, random=False):
# 数据集
self.data = data
# batch size
self.batch_size = batch_size
# 每个epoch迭代的步数
self.steps = int(math.floor(len(self.data) / self.batch_size))
# 每个epoch开始时是否随机混洗
self.random = random
def sequence_padding(self, data, length=None, padding=None):
"""
将给定数据填充到相同长度
:param data: 待填充数据
:param length: 填充后的长度不传递此参数则使用data中的最大长度
:param padding: 用于填充的数据,不传递此参数则使用[PAD]的对应编号
:return: 填充后的数据
"""
# 计算填充长度
if length is None:
length = max(map(len, data))
# 计算填充数据
if padding is None:
padding = tokenizer.token_to_id('[PAD]')
# 开始填充
outputs = []
for line in data:
padding_length = length - len(line)
# 不足就进行填充
if padding_length > 0:
outputs.append(np.concatenate([line, [padding] * padding_length]))
# 超过就进行截断
else:
outputs.append(line[:length])
return np.array(outputs)
def __len__(self):
return self.steps
def __iter__(self):
total = len(self.data)
# 是否随机混洗
if self.random:
np.random.shuffle(self.data)
# 迭代一个epoch每次yield一个batch
for start in range(0, total, self.batch_size):
end = min(start + self.batch_size, total)
batch_data = []
# 逐一对古诗进行编码
for single_data in self.data[start:end]:
batch_data.append(tokenizer.encode(single_data))
# 填充为相同长度
batch_data = self.sequence_padding(batch_data)
# yield x,y
yield batch_data[:, :-1], tf.one_hot(batch_data[:, 1:], tokenizer.vocab_size)
del batch_data
def for_fit(self):
"""
创建一个生成器,用于训练
"""
# 死循环当数据训练一个epoch之后重新迭代数据
while True:
# 委托生成器
yield from self.__iter__()

16
exp3/source/eval.py Normal file
View File

@@ -0,0 +1,16 @@
#ANLI College of Artificial Intelligence
import tensorflow as tf
from dataset import tokenizer
import settings
import utils
# 加载训练好的模型
model = tf.keras.models.load_model(settings.BEST_MODEL_PATH)
# 随机生成一首诗
print(utils.generate_random_poetry(tokenizer, model))
# 给出部分信息的情况下,随机生成剩余部分
print(utils.generate_random_poetry(tokenizer, model, s='床前明月光,'))
# 生成藏头诗
print(utils.generate_acrostic(tokenizer, model, head='好好学习天天向上'))

19
exp3/source/settings.py Normal file
View File

@@ -0,0 +1,19 @@
#ANLI College of Artificial Intelligence
# 禁用词,包含如下字符的唐诗将被忽略
DISALLOWED_WORDS = ['', '', '(', ')', '__', '', '', '', '', '[', ']']
# 句子最大长度
MAX_LEN = 64
# 最小词频
MIN_WORD_FREQUENCY = 8
# 训练的batch size
BATCH_SIZE = 16
# 数据集路径
DATASET_PATH = './poetry.txt'
# 每个epoch训练完成后随机生成SHOW_NUM首古诗作为展示
SHOW_NUM = 5
# 共训练多少个epoch
TRAIN_EPOCHS = 10
# 最佳权重保存路径
BEST_MODEL_PATH = './best_model.h5'

35
exp3/source/train.py Normal file
View File

@@ -0,0 +1,35 @@
import tensorflow as tf
from dataset import PoetryDataGenerator, tokenizer, poetry
import settings
import utils
model = tf.keras.Sequential([
tf.keras.layers.Input((None,)),
tf.keras.layers.Embedding(input_dim=tokenizer.vocab_size, output_dim=128),
tf.keras.layers.LSTM(128, dropout=0.5, return_sequences=True),
tf.keras.layers.LSTM(128, dropout=0.5, return_sequences=True),
tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(tokenizer.vocab_size, activation='softmax')),
])
model.summary()
model.compile(optimizer=tf.keras.optimizers.Adam(), loss=tf.keras.losses.categorical_crossentropy)
class Evaluate(tf.keras.callbacks.Callback):
def __init__(self):
super().__init__()
self.lowest = 1e10
def on_epoch_end(self, epoch, logs=None):
if logs['loss'] <= self.lowest:
self.lowest = logs['loss']
model.save(settings.BEST_MODEL_PATH)
print()
for i in range(settings.SHOW_NUM):
print(utils.generate_random_poetry(tokenizer, model))
data_generator = PoetryDataGenerator(poetry, random=False)
model.fit_generator(data_generator.for_fit(),
steps_per_epoch=data_generator.steps,
epochs=settings.TRAIN_EPOCHS,
callbacks=[Evaluate()])

86
exp3/source/utils.py Normal file
View File

@@ -0,0 +1,86 @@
import numpy as np
import settings
def generate_random_poetry(tokenizer, model, s=''):
"""
随机生成一首诗
:param tokenizer: 分词器
:param model: 用于生成古诗的模型
:param s: 用于生成古诗的起始字符串,默认为空串
:return: 一个字符串,表示一首古诗
"""
# 将初始字符串转成token
token_ids = tokenizer.encode(s)
# 去掉结束标记[SEP]
token_ids = token_ids[:-1]
while len(token_ids) < settings.MAX_LEN:
# 进行预测只保留第一个样例我们输入的样例数只有1的、最后一个token的预测的、不包含[PAD][UNK][CLS]的概率分布
output = model(np.array([token_ids, ], dtype=np.int32))
_probas = output.numpy()[0, -1, 3:]
del output
# print(_probas)
# 按照出现概率对所有token倒序排列
p_args = _probas.argsort()[::-1][:100]
# 排列后的概率顺序
p = _probas[p_args]
# 先对概率归一
p = p / sum(p)
# 再按照预测出的概率,随机选择一个词作为预测结果
target_index = np.random.choice(len(p), p=p)
target = p_args[target_index] + 3
# 保存
token_ids.append(target)
if target == 3:
break
return tokenizer.decode(token_ids)
def generate_acrostic(tokenizer, model, head):
"""
随机生成一首藏头诗
:param tokenizer: 分词器
:param model: 用于生成古诗的模型
:param head: 藏头诗的头
:return: 一个字符串,表示一首古诗
"""
# 使用空串初始化token_ids加入[CLS]
token_ids = tokenizer.encode('')
token_ids = token_ids[:-1]
# 标点符号,这里简单的只把逗号和句号作为标点
punctuations = ['', '']
punctuation_ids = {tokenizer.token_to_id(token) for token in punctuations}
# 缓存生成的诗的list
poetry = []
# 对于藏头诗中的每一个字,都生成一个短句
for ch in head:
# 先记录下这个字
poetry.append(ch)
# 将藏头诗的字符转成token id
token_id = tokenizer.token_to_id(ch)
# 加入到列表中去
token_ids.append(token_id)
# 开始生成一个短句
while True:
# 进行预测只保留第一个样例我们输入的样例数只有1的、最后一个token的预测的、不包含[PAD][UNK][CLS]的概率分布
output = model(np.array([token_ids, ], dtype=np.int32))
_probas = output.numpy()[0, -1, 3:]
del output
# 按照出现概率对所有token倒序排列
p_args = _probas.argsort()[::-1][:100]
# 排列后的概率顺序
p = _probas[p_args]
# 先对概率归一
p = p / sum(p)
# 再按照预测出的概率,随机选择一个词作为预测结果
target_index = np.random.choice(len(p), p=p)
target = p_args[target_index] + 3
# 保存
token_ids.append(target)
# 只有不是特殊字符时才保存到poetry里面去
if target > 3:
poetry.append(tokenizer.id_to_token(target))
if target in punctuation_ids:
break
return ''.join(poetry)

View File

@@ -8,8 +8,11 @@ dependencies = [
"datasets>=4.3.0", "datasets>=4.3.0",
"matplotlib>=3.10.7", "matplotlib>=3.10.7",
"numpy>=2.3.4", "numpy>=2.3.4",
"torch>=2.9.0", "pillow>=12.0.0",
"torchvision>=0.24.0", "pytorch-ignite>=0.5.3",
"torch>=2.9.0",
"torchinfo>=1.8.0",
"torchvision>=0.24.0",
] ]
[tool.uv.sources] [tool.uv.sources]

29
uv.lock generated
View File

@@ -393,8 +393,11 @@ dependencies = [
{ name = "datasets" }, { name = "datasets" },
{ name = "matplotlib" }, { name = "matplotlib" },
{ name = "numpy" }, { name = "numpy" },
{ name = "pillow" },
{ name = "pytorch-ignite" },
{ name = "torch", version = "2.9.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, { name = "torch", version = "2.9.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
{ name = "torch", version = "2.9.0+cu126", source = { registry = "https://download.pytorch.org/whl/cu126" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "torch", version = "2.9.0+cu126", source = { registry = "https://download.pytorch.org/whl/cu126" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "torchinfo" },
{ name = "torchvision", version = "0.24.0", source = { registry = "https://download.pytorch.org/whl/cu126" }, marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, { name = "torchvision", version = "0.24.0", source = { registry = "https://download.pytorch.org/whl/cu126" }, marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" },
{ name = "torchvision", version = "0.24.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, { name = "torchvision", version = "0.24.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
{ name = "torchvision", version = "0.24.0+cu126", source = { registry = "https://download.pytorch.org/whl/cu126" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or sys_platform == 'win32'" }, { name = "torchvision", version = "0.24.0+cu126", source = { registry = "https://download.pytorch.org/whl/cu126" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or sys_platform == 'win32'" },
@@ -405,8 +408,11 @@ requires-dist = [
{ name = "datasets", specifier = ">=4.3.0" }, { name = "datasets", specifier = ">=4.3.0" },
{ name = "matplotlib", specifier = ">=3.10.7" }, { name = "matplotlib", specifier = ">=3.10.7" },
{ name = "numpy", specifier = ">=2.3.4" }, { name = "numpy", specifier = ">=2.3.4" },
{ name = "pillow", specifier = ">=12.0.0" },
{ name = "pytorch-ignite", specifier = ">=0.5.3" },
{ name = "torch", marker = "sys_platform != 'linux' and sys_platform != 'win32'", specifier = ">=2.9.0" }, { name = "torch", marker = "sys_platform != 'linux' and sys_platform != 'win32'", specifier = ">=2.9.0" },
{ name = "torch", marker = "sys_platform == 'linux' or sys_platform == 'win32'", specifier = ">=2.9.0", index = "https://download.pytorch.org/whl/cu126" }, { name = "torch", marker = "sys_platform == 'linux' or sys_platform == 'win32'", specifier = ">=2.9.0", index = "https://download.pytorch.org/whl/cu126" },
{ name = "torchinfo", specifier = ">=1.8.0" },
{ name = "torchvision", marker = "sys_platform != 'linux' and sys_platform != 'win32'", specifier = ">=0.24.0" }, { name = "torchvision", marker = "sys_platform != 'linux' and sys_platform != 'win32'", specifier = ">=0.24.0" },
{ name = "torchvision", marker = "sys_platform == 'linux' or sys_platform == 'win32'", specifier = ">=0.24.0", index = "https://download.pytorch.org/whl/cu126" }, { name = "torchvision", marker = "sys_platform == 'linux' or sys_platform == 'win32'", specifier = ">=0.24.0", index = "https://download.pytorch.org/whl/cu126" },
] ]
@@ -1639,6 +1645,20 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" }, { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" },
] ]
[[package]]
name = "pytorch-ignite"
version = "0.5.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "packaging" },
{ name = "torch", version = "2.9.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
{ name = "torch", version = "2.9.0+cu126", source = { registry = "https://download.pytorch.org/whl/cu126" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/5a/e5/7fe880b24de30b4eadc8d997ea8d3c4a8f507b1a34dcdced08d88f665ee3/pytorch_ignite-0.5.3.tar.gz", hash = "sha256:75c645f02fea66cc80c1998ade3f8402e0e6b6d73f3f4ad727c171f6e93874f4", size = 7506607, upload-time = "2025-10-16T00:42:05.142Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d2/ea/f6d5ee7433a5a1c1e4746e2a4e9a222eab545fdbe04b66754ffdab479ee8/pytorch_ignite-0.5.3-py3-none-any.whl", hash = "sha256:4ced7539c690a3b6f3116da7878389954dff787c33669f83b38221f3746bc63e", size = 343802, upload-time = "2025-10-16T00:41:55.738Z" },
]
[[package]] [[package]]
name = "pytz" name = "pytz"
version = "2025.2" version = "2025.2"
@@ -1848,6 +1868,15 @@ wheels = [
{ url = "https://download.pytorch.org/whl/cu126/torch-2.9.0%2Bcu126-cp314-cp314t-win_amd64.whl", hash = "sha256:d8fdfc45ba30cb5c23971b35ae72c6fe246596022b574bd37dd0c775958f70b1" }, { url = "https://download.pytorch.org/whl/cu126/torch-2.9.0%2Bcu126-cp314-cp314t-win_amd64.whl", hash = "sha256:d8fdfc45ba30cb5c23971b35ae72c6fe246596022b574bd37dd0c775958f70b1" },
] ]
[[package]]
name = "torchinfo"
version = "1.8.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/53/d9/2b811d1c0812e9ef23e6cf2dbe022becbe6c5ab065e33fd80ee05c0cd996/torchinfo-1.8.0.tar.gz", hash = "sha256:72e94b0e9a3e64dc583a8e5b7940b8938a1ac0f033f795457f27e6f4e7afa2e9", size = 25880, upload-time = "2023-05-14T19:23:26.377Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/72/25/973bd6128381951b23cdcd8a9870c6dcfc5606cb864df8eabd82e529f9c1/torchinfo-1.8.0-py3-none-any.whl", hash = "sha256:2e911c2918603f945c26ff21a3a838d12709223dc4ccf243407bce8b6e897b46", size = 23377, upload-time = "2023-05-14T19:23:24.141Z" },
]
[[package]] [[package]]
name = "torchvision" name = "torchvision"
version = "0.24.0" version = "0.24.0"