From 48fcdfcc809dfd2527f14772eb4c1690371f01de Mon Sep 17 00:00:00 2001
From: yyc12345
Date: Sun, 30 Nov 2025 16:24:32 +0800
Subject: [PATCH] Update exp2 and add exp3

---
 exp2/modified/mnist.py      | 114 +++++++++++++++++
 exp2/modified/predict.py    |  47 ++++++-
 exp2/modified/sketchpad.py  |   2 +-
 exp2/modified/train.py      | 240 ++++++++++++++++++------------------
 exp2/test_images/.gitignore |   2 +
 exp3/datasets/.gitignore    |   2 +
 exp3/models/.gitignore      |   2 +
 exp3/modified/dataset.py    | 171 +++++++++++++++++++++++++
 exp3/modified/settings.py   |  12 ++
 exp3/modified/utils.py      |   0
 exp3/source/dataset.py      | 199 ++++++++++++++++++++++++++++++
 exp3/source/eval.py         |  16 +++
 exp3/source/settings.py     |  19 +++
 exp3/source/train.py        |  35 ++++++
 exp3/source/utils.py        |  86 +++++++++++++
 pyproject.toml              |   7 +-
 uv.lock                     |  29 +++++
 17 files changed, 859 insertions(+), 124 deletions(-)
 create mode 100644 exp2/modified/mnist.py
 create mode 100644 exp2/test_images/.gitignore
 create mode 100644 exp3/datasets/.gitignore
 create mode 100644 exp3/models/.gitignore
 create mode 100644 exp3/modified/dataset.py
 create mode 100644 exp3/modified/settings.py
 create mode 100644 exp3/modified/utils.py
 create mode 100644 exp3/source/dataset.py
 create mode 100644 exp3/source/eval.py
 create mode 100644 exp3/source/settings.py
 create mode 100644 exp3/source/train.py
 create mode 100644 exp3/source/utils.py

diff --git a/exp2/modified/mnist.py b/exp2/modified/mnist.py
new file mode 100644
index 0000000..b22993d
--- /dev/null
+++ b/exp2/modified/mnist.py
@@ -0,0 +1,114 @@
+from pathlib import Path
+import numpy
+import torch
+from torch.utils.data import DataLoader, Dataset
+from torchvision.transforms import v2 as tvtrans
+from torchvision import datasets
+import torch.nn.functional as F
+
+
+class CNN(torch.nn.Module):
+    """Convolutional neural network model"""
+
+    def __init__(self):
+        super(CNN, self).__init__()
+
+        self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=(3, 3))
+        self.pool1 = torch.nn.MaxPool2d(kernel_size=(2, 2))
+        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=(3, 3))
+        self.pool2 = torch.nn.MaxPool2d(kernel_size=(2, 2))
+        self.conv3 = torch.nn.Conv2d(64, 64, kernel_size=(3, 3))
+        self.flatten = torch.nn.Flatten()
+        # A 28x28 input becomes 26x26 after the first convolution and 13x13 after the first pooling,
+        # 11x11 after the second convolution and 5x5 after the second pooling,
+        # and 3x3 after the third convolution.
+        # The last convolution has 64 filters.
+        self.fc1 = torch.nn.Linear(64 * 3 * 3, 64)
+        torch.nn.init.xavier_normal_(self.fc1.weight)
+        torch.nn.init.zeros_(self.fc1.bias)
+        self.fc2 = torch.nn.Linear(64, 10)
+        torch.nn.init.xavier_normal_(self.fc2.weight)
+        torch.nn.init.zeros_(self.fc2.bias)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        x = self.pool1(x)
+        x = F.relu(self.conv2(x))
+        x = self.pool2(x)
+        x = F.relu(self.conv3(x))
+        x = self.flatten(x)
+        x = F.relu(self.fc1(x))
+        x = self.fc2(x)
+        return F.softmax(x, dim=1)
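+
+# Shape sanity check for the size arithmetic documented above (an
+# illustrative sketch, not part of the training flow). Note that train.py
+# pairs this softmax output with CrossEntropyLoss, which applies its own
+# log-softmax, so the loss effectively sees re-normalized probabilities
+# rather than raw logits:
+#   _model = CNN()
+#   _probs = _model(torch.zeros(1, 1, 28, 28))
+#   assert _probs.shape == (1, 10)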
+
+
+class MnistDataset(Dataset):
+    """Custom dataset for loading MNIST"""
+
+    shape: int
+    transform: tvtrans.Transform
+    images_data: numpy.ndarray
+    labels_data: torch.Tensor
+
+    def __init__(self, images: numpy.ndarray, labels: numpy.ndarray, transform: tvtrans.Transform):
+        images_len: int = images.shape[0]
+        labels_len: int = labels.shape[0]
+        assert (images_len == labels_len)
+        self.shape = images_len
+
+        self.images_data = images
+        self.labels_data = torch.from_numpy(labels)
+        self.transform = transform
+
+    def __getitem__(self, index):
+        return self.transform(self.images_data[index]), self.labels_data[index]
+
+    def __len__(self):
+        return self.shape
+
+
+class MnistDataSource:
+    """Data source that loads the MNIST data"""
+
+    train_loader: DataLoader
+    test_loader: DataLoader
+
+    def __init__(self, batch_size: int):
+        dataset_path = Path(__file__).resolve().parent.parent / 'datasets' / 'mnist.npz'
+        dataset = numpy.load(dataset_path)
+
+        # All images are white digits on a black background.
+        # 60,000 training images: 60000x28x28. Labels have only the first dimension.
+        train_images: numpy.ndarray = dataset['x_train']
+        train_labels: numpy.ndarray = dataset['y_train']
+        # 10,000 test images: 10000x28x28. Labels have only the first dimension.
+        test_images: numpy.ndarray = dataset['x_test']
+        test_labels: numpy.ndarray = dataset['y_test']
+
+        # Define the data transform
+        trans = tvtrans.Compose([
+            # Convert uint8 to float32 and scale to the [0, 1] range automatically
+            # tvtrans.ToTensor(),
+            tvtrans.ToImage(),
+            tvtrans.ToDtype(torch.float32, scale=True),
+
+            # To satisfy the channel layout expected by the model input,
+            # a new dimension would be added at the end.
+            #tvtrans.Lambda(lambda x: x.unsqueeze(-1))
+
+            # The normalization parameters (0.1307, 0.3081) are the standard
+            # ones for MNIST: the global mean and standard deviation of the
+            # MNIST training set. This normalization helps numerical
+            # stability and convergence during training.
+            #tvtrans.Normalize((0.1307,), (0.3081,)),
+        ])
+
+        # Create the datasets
+        train_dataset = MnistDataset(train_images, train_labels, transform=trans)
+        test_dataset = MnistDataset(test_images, test_labels, transform=trans)
+
+        # Store on self
+        self.train_loader = DataLoader(dataset=train_dataset,
+                                       batch_size=batch_size,
+                                       shuffle=False)
+        self.test_loader = DataLoader(dataset=test_dataset,
+                                      batch_size=batch_size,
+                                      shuffle=False)

diff --git a/exp2/modified/predict.py b/exp2/modified/predict.py
index 2f4ba9e..3027409 100644
--- a/exp2/modified/predict.py
+++ b/exp2/modified/predict.py
@@ -1,7 +1,10 @@
 from pathlib import Path
 import sys
 import torch
-from train import CNN
+import numpy
+from PIL import Image, ImageFile
+import matplotlib.pyplot as plt
+from mnist import CNN
 
 sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
 import gpu_utils
@@ -36,7 +39,7 @@ class Predictor:
         file_path = Path(__file__).resolve().parent.parent / 'models' / 'cnn.pth'
         self.cnn.load_state_dict(torch.load(file_path))
 
-    def predict(self, image: list[list[bool]]) -> PredictResult:
+    def predict_sketchpad(self, image: list[list[bool]]) -> PredictResult:
         input = torch.Tensor(image).float().to(self.device)
         assert(input.dim() == 2)
         assert(input.size(0) == 28)
@@ -51,4 +54,42 @@ class Predictor:
         with torch.no_grad():
             output = self.cnn(input)
         return PredictResult(output)
-    
\ No newline at end of file
+
+    def predict_image(self, image: ImageFile.ImageFile) -> PredictResult:
+        # Make sure the image is grayscale, then convert it to a numpy array.
+        # The resulting numpy array is read-only, so copy it first.
+        grayscale_image = image.convert('L')
+        numpy_data = numpy.reshape(grayscale_image, (28, 28), copy=True)
+        # Convert to a tensor, set the dtype, and move it to the GPU
+        data = torch.from_numpy(numpy_data).float().to(self.device)
+        # Scale by 255 into [0, 1]; since the input image is black-on-white,
+        # invert it to match MNIST's white-on-black convention.
+        data.div_(255.0).sub_(1).mul_(-1)
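+        # Check of the inversion arithmetic (illustrative values): a white
+        # background pixel (255) maps to -(255/255 - 1) = 0, while black
+        # ink (0) maps to -(0/255 - 1) = 1.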
+
+        # As before, unsqueeze the batch and channel dimensions, then predict
+        input = data.unsqueeze(0).unsqueeze(0)
+        with torch.no_grad():
+            output = self.cnn(input)
+        return PredictResult(output)
+
+
+def main():
+    predictor = Predictor()
+
+    # Iterate over all images in the test directory and process them
+    test_dir = Path(__file__).resolve().parent.parent / 'test_images'
+    for image_path in test_dir.glob('*.png'):
+        if image_path.is_file():
+            print(f'Predicting {image_path} ...')
+            image = Image.open(image_path)
+            rv = predictor.predict_image(image)
+
+            print(f'Predict digit: {rv.chosen_number()}')
+            plt.figure(f'Image - {image_path}')
+            plt.imshow(image)
+            plt.axis('on')
+            plt.title(f'Predict digit: {rv.chosen_number()}')
+            plt.show()
+
+
+if __name__ == "__main__":
+    main()
+

diff --git a/exp2/modified/sketchpad.py b/exp2/modified/sketchpad.py
index fff9758..5d845c8 100644
--- a/exp2/modified/sketchpad.py
+++ b/exp2/modified/sketchpad.py
@@ -169,7 +169,7 @@ class SketchpadApp:
 
     def execute(self):
         """Execute button action - pass the sketchpad data to the backend"""
-        prediction = self.predictor.predict(self.canvas_data)
+        prediction = self.predictor.predict_sketchpad(self.canvas_data)
         self.show_in_table(prediction)
 
     def reset(self):

diff --git a/exp2/modified/train.py b/exp2/modified/train.py
index 12ab63d..0113ed0 100644
--- a/exp2/modified/train.py
+++ b/exp2/modified/train.py
@@ -1,145 +1,57 @@
 from pathlib import Path
 import sys
 import typing
-import numpy
 import torch
-from torch.utils.data import DataLoader, Dataset
-from torchvision.transforms import v2 as tvtrans
-import matplotlib.pyplot as plt
-import torch.nn.functional as F
+import torchinfo
+import ignite.engine
+import ignite.metrics
+from ignite.engine import Engine, Events
+from ignite.handlers.tqdm_logger import ProgressBar
+from mnist import CNN, MnistDataSource
 
 sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
 import gpu_utils
 
 
-class CNN(torch.nn.Module):
-    """Convolutional neural network model"""
-
-    def __init__(self):
-        super(CNN, self).__init__()
-
-        self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=(3, 3))
-        self.pool1 = torch.nn.MaxPool2d(kernel_size=(2, 2))
-        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=(3, 3))
-        self.pool2 = torch.nn.MaxPool2d(kernel_size=(2, 2))
-        self.conv3 = torch.nn.Conv2d(64, 64, kernel_size=(3, 3))
-        self.flatten = torch.nn.Flatten()
-        # A 28x28 input becomes 26x26 after the first convolution and 13x13 after the first pooling,
-        # 11x11 after the second convolution and 5x5 after the second pooling,
-        # and 3x3 after the third convolution.
-        # The last convolution has 64 filters.
-        self.fc1 = torch.nn.Linear(64 * 3 * 3, 64)
-        self.fc2 = torch.nn.Linear(64, 10)
-
-    def forward(self, x):
-        x = F.relu(self.conv1(x))
-        x = self.pool1(x)
-        x = F.relu(self.conv2(x))
-        x = self.pool2(x)
-        x = F.relu(self.conv3(x))
-        x = self.flatten(x)
-        x = F.relu(self.fc1(x))
-        x = self.fc2(x)
-        return F.softmax(x, dim=1)
-
-
-class MnistDataset(Dataset):
-    """Custom dataset for loading MNIST"""
-
-    shape: int
-    transform: tvtrans.Transform
-    images_data: numpy.ndarray
-    labels_data: torch.Tensor
-
-    def __init__(self, images: numpy.ndarray, labels: numpy.ndarray, transform: tvtrans.Transform):
-        images_len: int = images.size(0)
-        labels_len: int = labels.size(0)
-        assert (images_len == labels_len)
-        self.shape = images_len
-
-        self.images_data = images
-        self.labels_data = torch.from_numpy(labels)
-        self.transform = transform
-
-    def __getitem__(self, index):
-        return self.transform(self.images_data[index]), self.labels_data[index]
-
-    def __len__(self):
-        return self.shape
-
-
-class DataSource:
-    """Data source that loads the MNIST data"""
-
-    train_data: DataLoader
-    test_data: DataLoader
-
-    def __init__(self, batch_size: int):
-        datasets_path = Path(__file__).resolve().parent.parent / 'datasets' / 'mnist.npz'
-        datasets = numpy.load(datasets_path)
-
-        # All images are white digits on a black background.
-        # 60,000 training images: 60000x28x28. Labels have only the first dimension.
-        train_images = datasets['x_train']
-        train_labels = datasets['y_train']
-        # 10,000 test images: 10000x28x28. Labels have only the first dimension.
-        test_images = datasets['x_test']
-        test_labels = datasets['y_test']
-
-        # Define the data transform
-        trans = tvtrans.Compose([
-            # Convert uint8 to float32 and scale to the [0, 1] range automatically
-            # tvtrans.ToTensor(),
-            tvtrans.ToImage(),
-            tvtrans.ToDtype(torch.float32, scale=True),
-            # To satisfy the channel layout expected by the model input,
-            # a new dimension would be added at the end.
-            #tvtrans.Lambda(lambda x: x.unsqueeze(-1))
-        ])
-
-        # Create the datasets
-        train_dataset = MnistDataset(train_images,
-                                     train_labels,
-                                     transform=trans)
-        test_dataset = MnistDataset(test_images, test_labels, transform=trans)
-
-        # Store on self
-        self.train_data = DataLoader(dataset=train_dataset,
-                                     batch_size=batch_size,
-                                     shuffle=False)
-        self.test_data = DataLoader(dataset=test_dataset,
-                                    batch_size=batch_size,
-                                    shuffle=False)
-
-
 class Trainer:
 
     N_EPOCH: typing.ClassVar[int] = 5
     N_BATCH_SIZE: typing.ClassVar[int] = 1000
 
     device: torch.device
-    data_source: DataSource
-    cnn: CNN
+    data_source: MnistDataSource
+    model: CNN
 
     def __init__(self):
         self.device = gpu_utils.get_gpu_device()
-        self.data_source = DataSource(Trainer.N_BATCH_SIZE)
-        self.cnn = CNN().to(self.device)
+        self.data_source = MnistDataSource(Trainer.N_BATCH_SIZE)
+        self.model = CNN().to(self.device)
+        # Show the model structure: the configured batch size, a single grayscale channel, 28x28 images.
+        torchinfo.summary(self.model, (Trainer.N_BATCH_SIZE, 1, 28, 28))
 
     def train(self):
-        optimizer = torch.optim.Adam(self.cnn.parameters())
+        optimizer = torch.optim.Adam(self.model.parameters(), eps=1e-7)
+        # optimizer = torch.optim.AdamW(
+        #     self.model.parameters(),
+        #     lr=0.001,            # both default to 0.001
+        #     betas=(0.9, 0.999),  # same defaults in both
+        #     eps=1e-07,           # [key] match TensorFlow's default epsilon
+        #     weight_decay=0.0,    # both default to 0
+        #     amsgrad=False        # both default to False
+        # )
         loss_func = torch.nn.CrossEntropyLoss()
 
         for epoch in range(Trainer.N_EPOCH):
-            self.cnn.train()
+            self.model.train()
 
             batch_images: torch.Tensor
             batch_labels: torch.Tensor
-            for batch_index, (batch_images, batch_labels) in enumerate(self.data_source.train_data):
+            for batch_index, (batch_images, batch_labels) in enumerate(self.data_source.train_loader):
                 gpu_images = batch_images.to(self.device)
                 gpu_labels = batch_labels.to(self.device)
 
-                optimizer.zero_grad()
-                prediction: torch.Tensor = self.cnn(gpu_images)
+                prediction: torch.Tensor = self.model(gpu_images)
                 loss: torch.Tensor = loss_func(prediction, gpu_labels)
+                optimizer.zero_grad()
                 loss.backward()
                 optimizer.step()
@@ -151,20 +63,20 @@ class Trainer:
         file_dir_path = Path(__file__).resolve().parent.parent / 'models'
         file_dir_path.mkdir(parents=True, exist_ok=True)
         file_path = file_dir_path / 'cnn.pth'
-        torch.save(self.cnn.state_dict(), file_path)
+        torch.save(self.model.state_dict(), file_path)
         print(f'Model saved to: {file_path}')
 
     def test(self):
-        self.cnn.eval()
+        self.model.eval()
 
         correct_sum = 0
         total_sum = 0
 
         with torch.no_grad():
-            for batch_images, batch_labels in self.data_source.test_data:
+            for batch_images, batch_labels in self.data_source.test_loader:
                 gpu_images = batch_images.to(self.device)
                 gpu_labels = batch_labels.to(self.device)
 
-                possibilities: torch.Tensor = self.cnn(gpu_images)
+                possibilities: torch.Tensor = self.model(gpu_images)
                 # The output is a probability for each of the 10 digits, so compare using the most likely one.
                 # Take the max along dim=1; dim=0 is the batch dimension, so leave it alone.
                 _, prediction = possibilities.max(1)
@@ -175,13 +87,105 @@
         test_acc = 100. * correct_sum / total_sum
         print(f"Accuracy: {test_acc:.4f}%, tested {total_sum} images in total")
 
-
 def main():
     trainer = Trainer()
     trainer.train()
     trainer.save()
     trainer.test()
 
+# N_EPOCH: int = 5
+# N_BATCH_SIZE: int = 1000
+# N_LOG_INTERVAL: int = 10
+
+# class Trainer:
+
+#     device: torch.device
+#     data_source: MnistDataSource
+#     model: CNN
+
+#     trainer: Engine
+#     train_evaluator: Engine
+#     test_evaluator: Engine
+
+#     def __init__(self):
+#         self.device = gpu_utils.get_gpu_device()
+#         self.model = CNN().to(self.device)
+#         self.data_source = MnistDataSource(batch_size=N_BATCH_SIZE)
+#         # Show the model structure: the configured batch size, a single grayscale channel, 28x28 images.
+#         torchinfo.summary(self.model, (N_BATCH_SIZE, 1, 28, 28))
+
+#         #optimizer = torch.optim.Adam(self.model.parameters(), eps=1e-7)
+#         optimizer = torch.optim.AdamW(
+#             self.model.parameters(),
+#             lr=0.001,            # both default to 0.001
+#             betas=(0.9, 0.999),  # same defaults in both
+#             eps=1e-07,           # [key] match TensorFlow's default epsilon
+#             weight_decay=0.0,    # both default to 0
+#             amsgrad=False        # both default to False
+#         )
+#         criterion = torch.nn.CrossEntropyLoss()
+#         self.trainer = ignite.engine.create_supervised_trainer(
+#             self.model, optimizer, criterion, self.device
+#         )
+
+#         eval_metrics = {
+#             "accuracy": ignite.metrics.Accuracy(device=self.device),
+#             "loss": ignite.metrics.Loss(criterion, device=self.device)
+#         }
+#         self.train_evaluator = ignite.engine.create_supervised_evaluator(
+#             self.model, metrics=eval_metrics, device=self.device)
+#         self.test_evaluator = ignite.engine.create_supervised_evaluator(
+#             self.model, metrics=eval_metrics, device=self.device)
+
+#         self.trainer.add_event_handler(
+#             Events.ITERATION_COMPLETED(every=N_LOG_INTERVAL),
+#             lambda engine: self.log_intrain_loss(engine)
+#         )
+#         self.trainer.add_event_handler(
+#             Events.EPOCH_COMPLETED,
+#             lambda trainer: self.log_train_results(trainer)
+#         )
+#         self.trainer.add_event_handler(
+#             Events.COMPLETED,
+#             lambda _: self.log_test_results()
+#         )
+#         self.trainer.add_event_handler(
+#             Events.COMPLETED,
+#             lambda _: self.save_model()
+#         )
+
+#         progressbar = ProgressBar()
+#         progressbar.attach(self.trainer)
+
+#     def run(self):
+#         self.trainer.run(self.data_source.train_loader, max_epochs=N_EPOCH)
+
+#     def log_intrain_loss(self, engine: Engine):
+#         print(f"Epoch: {engine.state.epoch}, Loss: {engine.state.output:.4f}\r", end="")
+
+#     def log_train_results(self, trainer: Engine):
+#         self.train_evaluator.run(self.data_source.train_loader)
+#         metrics = self.train_evaluator.state.metrics
+#         print()
+#         print(f"Training - Epoch: {trainer.state.epoch}, Avg Accuracy: {metrics['accuracy']:.4f}, Avg Loss: {metrics['loss']:.4f}")
+
+#     def log_test_results(self):
+#         self.test_evaluator.run(self.data_source.test_loader)
+#         metrics = self.test_evaluator.state.metrics
+#         print(f"Test - Avg Accuracy: {metrics['accuracy']:.4f} Avg Loss: {metrics['loss']:.4f}")
+
+#     def save_model(self):
+#         file_dir_path = Path(__file__).resolve().parent.parent / 'models'
+#         file_dir_path.mkdir(parents=True, exist_ok=True)
+#         file_path = file_dir_path / 'cnn.pth'
+#         torch.save(self.model.state_dict(), file_path)
+#         print(f'Model was saved into: {file_path}')
+
+# def main():
+#     trainer = Trainer()
+#     trainer.run()
+
+
 if __name__ == "__main__":
     gpu_utils.print_gpu_availability()

diff --git a/exp2/test_images/.gitignore b/exp2/test_images/.gitignore
new file mode 100644
index 0000000..df23673
--- /dev/null
+++ b/exp2/test_images/.gitignore
@@ -0,0 +1,2 @@
+# Ignore all test images
+*.png

diff --git a/exp3/datasets/.gitignore b/exp3/datasets/.gitignore
new file mode 100644
index 0000000..6c7c288
--- /dev/null
+++ b/exp3/datasets/.gitignore
@@ -0,0 +1,2 @@
+# Ignore datasets and processed datasets
+*.txt
\ No newline at end of file

diff --git a/exp3/models/.gitignore b/exp3/models/.gitignore
new file mode 100644
index 0000000..0c51985
--- /dev/null
+++ b/exp3/models/.gitignore
@@ -0,0 +1,2 @@
+# Ignore every saved model file
+*.pth
\ No newline at end of file

diff --git a/exp3/modified/dataset.py b/exp3/modified/dataset.py
new file mode 100644
index 0000000..8182bb0
--- /dev/null
+++ b/exp3/modified/dataset.py
@@ -0,0 +1,171 @@
+from pathlib import Path
+import typing
+import pickle
+from collections import Counter
+import numpy
+
+class Tokenizer:
+    """Tokenizer"""
+
+    token_dict: dict[str, int]
+    """Token -> id mapping"""
+    token_dict_rev: dict[int, str]
+    """Id -> token mapping"""
+    vocab_size: int
+    """Vocabulary size"""
+
+    def __init__(self, token_dict: dict[str, int]):
+        self.token_dict = token_dict
+        self.token_dict_rev = {value: key for key, value in self.token_dict.items()}
+        self.vocab_size = len(self.token_dict)
+
+    def id_to_token(self, token_id: int) -> str:
+        """
+        Given an id, look up the corresponding token in the vocabulary.
+
+        :param token_id: id of the token to look up
+        :return: the token for that id
+        """
+        return self.token_dict_rev[token_id]
+
+    def token_to_id(self, token: str):
+        """
+        Given a token, look up its id in the vocabulary.
+        Returns the id of the rare-token marker [UNK] if not found.
+
+        :param token: the token to look up
+        :return: the token's id
+        """
+        return self.token_dict.get(token, self.token_dict['[UNK]'])
+
+    def encode(self, tokens: str) -> list[int]:
+        """
+        Given a string, add the special start and end markers at the head and
+        tail, and convert it to the corresponding id sequence.
+
+        :param tokens: the string to encode
+        :return: the id sequence
+        """
+        # Add the start marker
+        token_ids: list[int] = [self.token_to_id('[CLS]'), ]
+        # Append the id of each character
+        for token in tokens:
+            token_ids.append(self.token_to_id(token))
+        # Add the end marker
+        token_ids.append(self.token_to_id('[SEP]'))
+        return token_ids
+
+    def decode(self, token_ids: typing.Iterable[int]) -> str:
+        """
+        Given an id sequence, decode it back into a string.
+
+        :param token_ids: the id sequence to decode
+        :return: the decoded string
+        """
+        # Treat the start/end markers specially
+        spec_tokens = {'[CLS]', '[SEP]'}
+        # List of decoded characters
+        tokens: list[str] = []
+        for token_id in token_ids:
+            token = self.id_to_token(token_id)
+            if token in spec_tokens:
+                continue
+            tokens.append(token)
+        # Join into a string
+        return ''.join(tokens)
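+
+# Round-trip sketch (illustrative toy vocabulary, not the real one built
+# below by PoetryDataset.build_tokenizer):
+#   _tk = Tokenizer({'[PAD]': 0, '[UNK]': 1, '[CLS]': 2, '[SEP]': 3, '月': 4})
+#   _tk.encode('月')        # [2, 4, 3] -- wrapped in [CLS] ... [SEP]
+#   _tk.decode([2, 4, 3])   # '月' -- the markers are stripped again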
+
+
+class PoetryDataset:
+    """Loader for the classical poetry dataset"""
+
+    BAD_WORDS: typing.ClassVar[list[str]] = ['(', ')', '(', ')', '__', '《', '》', '【', '】', '[', ']']
+    """Disallowed words; Tang poems containing any of these characters are skipped"""
+    MAX_SEG_LEN: typing.ClassVar[int] = 64
+    """Maximum sentence length"""
+    MIN_WORD_FREQ: typing.ClassVar[int] = 8
+    """Minimum word frequency"""
+
+    tokenizer: Tokenizer
+    """Tokenizer"""
+    poetry: list[str]
+    """The poetry dataset"""
+
+    def __init__(self, force_reclean: bool = False):
+        # Load the poems, then build the tokenizer from word frequencies
+        self.poetry = self.load_poetry(force_reclean)
+        self.tokenizer = self.build_tokenizer(self.poetry)
+
+    def load_poetry(self, force_reclean: bool = False) -> list[str]:
+        """Load the poetry dataset"""
+        if force_reclean or (not self.get_clean_dataset_path().is_file()):
+            return self.load_poetry_from_dirty()
+        else:
+            return self.load_poetry_from_clean()
+
+    def load_poetry_from_clean(self) -> list[str]:
+        """Read the already-cleaned data directly"""
+        with open(self.get_clean_dataset_path(), 'rb') as f:
+            return pickle.load(f)
+
+    def load_poetry_from_dirty(self) -> list[str]:
+        """Load the raw data, clean it, write the cache file, and return the cleaned data"""
+        with open(self.get_dirty_dataset_path(), 'r', encoding='utf-8') as f:
+            lines = f.readlines()
+            # Normalize colons to the same form
+            lines = [line.replace(':', ':') for line in lines]
+
+        # The dataset list
+        poetry: list[str] = []
+        # Process the loaded data line by line
+        for line in lines:
+            # Strip whitespace
+            line = line.strip()
+            # There must be exactly one colon, separating the title
+            if line.count(':') != 1: continue
+            # Take the second half
+            _, last_part = line.split(':')
+            # The length must not exceed the maximum
+            if len(last_part) > PoetryDataset.MAX_SEG_LEN - 2:
+                continue
+            # Must not contain any disallowed word
+            for bad_word in PoetryDataset.BAD_WORDS:
+                if bad_word in last_part:
+                    break
+            else:
+                # The loop finished normally, so there were no bad words; append to the list
+                poetry.append(last_part)
+
+        # Cleaning is done; write the clean data
+        with open(self.get_clean_dataset_path(), 'wb') as f:
+            pickle.dump(poetry, f)
+        # Return the result
+        return poetry
+
+    def get_clean_dataset_path(self) -> Path:
+        return Path(__file__).resolve().parent.parent / 'datasets' / 'poetry.pickle'
+
+    def get_dirty_dataset_path(self) -> Path:
+        return Path(__file__).resolve().parent.parent / 'datasets' / 'poetry.txt'
+
+    def build_tokenizer(self, poetry: list[str]) -> Tokenizer:
+        """Count word frequencies and build the tokenizer"""
+        # Count word frequencies
+        counter: Counter[str] = Counter()
+        for line in poetry:
+            counter.update(line)
+        # Filter out low-frequency words
+        tokens = ((token, count) for token, count in counter.items() if count >= PoetryDataset.MIN_WORD_FREQ)
+        # Sort by frequency
+        tokens = sorted(tokens, key=lambda x: -x[1])
+        # Drop the counts, keeping only the token list
+        tokens = list(token for token, _ in tokens)
+
+        # Prepend the special tokens to the dataset tokens
+        tokens = ['[PAD]', '[UNK]', '[CLS]', '[SEP]'] + tokens
+        # Build the token -> id dictionary
+        token_id_dict = dict(zip(tokens, range(len(tokens))))
+        # Rebuild the tokenizer with the new dictionary
+        tokenizer = Tokenizer(token_id_dict)
+        # Return it directly; no need to shuffle the data here
+        return tokenizer
+

diff --git a/exp3/modified/settings.py b/exp3/modified/settings.py
new file mode 100644
index 0000000..d0a692f
--- /dev/null
+++ b/exp3/modified/settings.py
@@ -0,0 +1,12 @@
+from pathlib import Path
+
+BATCH_SIZE: int = 16
+"""Training batch size"""
+
+def get_saved_model_path() -> Path:
+    """
+    Get the path where the trained model is saved.
+
+    :return: the path for the saved model parameters.
+    """
+    return Path(__file__).resolve().parent.parent / 'models' / 'rnn.pth'

diff --git a/exp3/modified/utils.py b/exp3/modified/utils.py
new file mode 100644
index 0000000..e69de29

diff --git a/exp3/source/dataset.py b/exp3/source/dataset.py
new file mode 100644
index 0000000..5b89287
--- /dev/null
+++ b/exp3/source/dataset.py
@@ -0,0 +1,199 @@
+#ANLI College of Artificial Intelligence
+
+from collections import Counter
+import math
+import numpy as np
+import tensorflow as tf
+import settings
+
+
+class Tokenizer:
+    """
+    Tokenizer
+    """
+
+    def __init__(self, token_dict):
+        # Token -> id mapping
+        self.token_dict = token_dict
+        # Id -> token mapping
+        self.token_dict_rev = {value: key for key, value in self.token_dict.items()}
+        # Vocabulary size
+        self.vocab_size = len(self.token_dict)
+
+    def id_to_token(self, token_id):
+        """
+        Given an id, look up the corresponding token in the vocabulary
+        :param token_id: id of the token to look up
+        :return: the token for that id
+        """
+        return self.token_dict_rev[token_id]
+
+    def token_to_id(self, token):
+        """
+        Given a token, look up its id in the vocabulary
+        Returns the id of the rare-token marker [UNK] if not found
+        :param token: the token to look up
+        :return: the token's id
+        """
+        return self.token_dict.get(token, self.token_dict['[UNK]'])
+
+    def encode(self, tokens):
+        """
+        Given a string, add the special start and end markers at the head and tail, and convert it to the corresponding id sequence
+        :param tokens: the string to encode
+        :return: the id sequence
+        """
+        # Add the start marker
+        token_ids = [self.token_to_id('[CLS]'), ]
+        # Append the id of each character
+        for token in tokens:
+            token_ids.append(self.token_to_id(token))
+        # Add the end marker
+        token_ids.append(self.token_to_id('[SEP]'))
+        return token_ids
+
+    def decode(self, token_ids):
+        """
+        Given an id sequence, decode it back into a string
+        :param token_ids: the id sequence to decode
+        :return: the decoded string
+        """
+        # Treat the start/end markers specially
+        spec_tokens = {'[CLS]', '[SEP]'}
+        # List of decoded characters
+        tokens = []
+        for token_id in token_ids:
+            token = self.id_to_token(token_id)
+            if token in spec_tokens:
+                continue
+            tokens.append(token)
+        # Join into a string
+        return ''.join(tokens)
+
+
+# Disallowed words
+disallowed_words = settings.DISALLOWED_WORDS
+# Maximum sentence length
+max_len = settings.MAX_LEN
+# Minimum word frequency
+min_word_frequency = settings.MIN_WORD_FREQUENCY
+# Mini-batch size
+batch_size = settings.BATCH_SIZE
+
+# Load the dataset
+with open(settings.DATASET_PATH, 'r', encoding='utf-8') as f:
+    lines = f.readlines()
+    # Normalize colons to the same form
+    lines = [line.replace(':', ':') for line in lines]
+# The dataset list
+poetry = []
+# Process the loaded data line by line
+for line in lines:
+    # There must be exactly one colon, separating the title
+    if line.count(':') != 1:
+        continue
+    # The second half must not contain any disallowed word
+    __, last_part = line.split(':')
+    ignore_flag = False
+    for dis_word in disallowed_words:
+        if dis_word in last_part:
+            ignore_flag = True
+            break
+    if ignore_flag:
+        continue
+    # The length must not exceed the maximum
+    if len(last_part) > max_len - 2:
+        continue
+    poetry.append(last_part.replace('\n', ''))
+
+# Count word frequencies
+counter = Counter()
+for line in poetry:
+    counter.update(line)
+# Filter out low-frequency words
+_tokens = [(token, count) for token, count in counter.items() if count >= min_word_frequency]
+# Sort by frequency
+_tokens = sorted(_tokens, key=lambda x: -x[1])
+# Drop the counts, keeping only the token list
+_tokens = [token for token, count in _tokens]
+
+# Prepend the special tokens to the dataset tokens
+_tokens = ['[PAD]', '[UNK]', '[CLS]', '[SEP]'] + _tokens
+# Build the token -> id dictionary
+token_id_dict = dict(zip(_tokens, range(len(_tokens))))
+# Rebuild the tokenizer with the new dictionary
+tokenizer = Tokenizer(token_id_dict)
+# Shuffle the data
+np.random.shuffle(poetry)
+
+
+class PoetryDataGenerator:
+    """
+    Poetry dataset generator
+    """
+
+    def __init__(self, data, random=False):
+        # Dataset
+        self.data = data
+        # Batch size
+        self.batch_size = batch_size
+        # Steps per epoch
+        self.steps = int(math.floor(len(self.data) / self.batch_size))
+        # Whether to shuffle at the start of each epoch
+        self.random = random
+
+    def sequence_padding(self, data, length=None, padding=None):
+        """
+        Pad the given sequences to the same length
+        :param data: the sequences to pad
+        :param length: target length; defaults to the maximum length found in data
+        :param padding: padding value; defaults to the id of [PAD]
+        :return: the padded data
+        """
+        # Compute the padding length
+        if length is None:
+            length = max(map(len, data))
+        # Determine the padding value
+        if padding is None:
+            padding = tokenizer.token_to_id('[PAD]')
+        # Do the padding
+        outputs = []
+        for line in data:
+            padding_length = length - len(line)
+            # Pad if too short
+            if padding_length > 0:
+                outputs.append(np.concatenate([line, [padding] * padding_length]))
+            # Truncate if too long
+            else:
+                outputs.append(line[:length])
+        return np.array(outputs)
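+
+    # Padding sketch (illustrative values): with [PAD] mapped to id 0,
+    # sequence_padding([[2, 4, 3], [2, 3]]) pads to the batch maximum
+    # length and yields [[2, 4, 3], [2, 3, 0]].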
+
+    def __len__(self):
+        return self.steps
+
+    def __iter__(self):
+        total = len(self.data)
+        # Shuffle if requested
+        if self.random:
+            np.random.shuffle(self.data)
+        # Iterate one epoch, yielding one batch at a time
+        for start in range(0, total, self.batch_size):
+            end = min(start + self.batch_size, total)
+            batch_data = []
+            # Encode each poem
+            for single_data in self.data[start:end]:
+                batch_data.append(tokenizer.encode(single_data))
+            # Pad to the same length
+            batch_data = self.sequence_padding(batch_data)
+            # yield x, y
+            yield batch_data[:, :-1], tf.one_hot(batch_data[:, 1:], tokenizer.vocab_size)
+            del batch_data
+
+    def for_fit(self):
+        """
+        Create a generator for training
+        """
+        # Loop forever; once an epoch of data is exhausted, iterate it again
+        while True:
+            # Delegate to the generator
+            yield from self.__iter__()
\ No newline at end of file

diff --git a/exp3/source/eval.py b/exp3/source/eval.py
new file mode 100644
index 0000000..fa71b02
--- /dev/null
+++ b/exp3/source/eval.py
@@ -0,0 +1,16 @@
+#ANLI College of Artificial Intelligence
+
+
+import tensorflow as tf
+from dataset import tokenizer
+import settings
+import utils
+
+# Load the trained model
+model = tf.keras.models.load_model(settings.BEST_MODEL_PATH)
+# Randomly generate a poem
+print(utils.generate_random_poetry(tokenizer, model))
+# Given a partial prefix, randomly generate the rest
+print(utils.generate_random_poetry(tokenizer, model, s='床前明月光,'))
+# Generate an acrostic poem
+print(utils.generate_acrostic(tokenizer, model, head='好好学习天天向上'))

diff --git a/exp3/source/settings.py b/exp3/source/settings.py
new file mode 100644
index 0000000..46ec8e9
--- /dev/null
+++ b/exp3/source/settings.py
@@ -0,0 +1,19 @@
+#ANLI College of Artificial Intelligence
+
+
+# Disallowed words; Tang poems containing any of these characters are skipped
+DISALLOWED_WORDS = ['(', ')', '(', ')', '__', '《', '》', '【', '】', '[', ']']
+# Maximum sentence length
+MAX_LEN = 64
+# Minimum word frequency
+MIN_WORD_FREQUENCY = 8
+# Training batch size
+BATCH_SIZE = 16
+# Dataset path
+DATASET_PATH = './poetry.txt'
+# After each epoch, randomly generate SHOW_NUM poems for display
+SHOW_NUM = 5
+# Total number of training epochs
+TRAIN_EPOCHS = 10
+# Path for saving the best weights
+BEST_MODEL_PATH = './best_model.h5'

diff --git a/exp3/source/train.py b/exp3/source/train.py
new file mode 100644
index 0000000..6515838
--- /dev/null
+++ b/exp3/source/train.py
@@ -0,0 +1,35 @@
+import tensorflow as tf
+from dataset import PoetryDataGenerator, tokenizer, poetry
+import settings
+import utils
+
+model = tf.keras.Sequential([
+    tf.keras.layers.Input((None,)),
+    tf.keras.layers.Embedding(input_dim=tokenizer.vocab_size, output_dim=128),
+    tf.keras.layers.LSTM(128, dropout=0.5, return_sequences=True),
+    tf.keras.layers.LSTM(128, dropout=0.5, return_sequences=True),
+    tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(tokenizer.vocab_size, activation='softmax')),
+])
+model.summary()
+model.compile(optimizer=tf.keras.optimizers.Adam(), loss=tf.keras.losses.categorical_crossentropy)
+
+
+class Evaluate(tf.keras.callbacks.Callback):
+
+    def __init__(self):
+        super().__init__()
+        self.lowest = 1e10
+
+    def on_epoch_end(self, epoch, logs=None):
+        if logs['loss'] <= self.lowest:
+            self.lowest = logs['loss']
+            model.save(settings.BEST_MODEL_PATH)
+        print()
+        for i in range(settings.SHOW_NUM):
+            print(utils.generate_random_poetry(tokenizer, model))
+
+
+data_generator = PoetryDataGenerator(poetry, random=False)
+model.fit_generator(data_generator.for_fit(),
+                    steps_per_epoch=data_generator.steps,
+                    epochs=settings.TRAIN_EPOCHS,
+                    callbacks=[Evaluate()])

diff --git a/exp3/source/utils.py b/exp3/source/utils.py
new file mode 100644
index 0000000..55f23c4
--- /dev/null
+++ b/exp3/source/utils.py
@@ -0,0 +1,86 @@
+
+import numpy as np
+import settings
+
+
+def generate_random_poetry(tokenizer, model, s=''):
+    """
+    Randomly generate a poem
+    :param tokenizer: the tokenizer
+    :param model: the model used to generate poems
+    :param s: starting string for the poem, empty by default
+    :return: a string holding one poem
+    """
+    # Convert the initial string to tokens
+    token_ids = tokenizer.encode(s)
+    # Drop the end marker [SEP]
+    token_ids = token_ids[:-1]
+    while len(token_ids) < settings.MAX_LEN:
+        # Predict; keep only the distribution of the last token of the first
+        # (and only) sample, excluding [PAD], [UNK] and [CLS]
+        output = model(np.array([token_ids, ], dtype=np.int32))
+        _probas = output.numpy()[0, -1, 3:]
+        del output
+        # print(_probas)
+        # Sort all tokens by probability, descending, keeping the top 100
+        p_args = _probas.argsort()[::-1][:100]
+        # Probabilities in sorted order
+        p = _probas[p_args]
+        # Renormalize the probabilities
+        p = p / sum(p)
+        # Then sample one token according to the predicted probabilities
+        target_index = np.random.choice(len(p), p=p)
+        target = p_args[target_index] + 3
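+        # Index note (clarifying the arithmetic above): the slice [0, -1, 3:]
+        # dropped ids 0..2 ([PAD], [UNK], [CLS]), so index i in _probas maps
+        # to vocabulary id i + 3 -- hence the "+ 3"; a target of 3 is exactly
+        # [SEP], which marks the end of the poem.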
+        # Save it
+        token_ids.append(target)
+        if target == 3:
+            break
+    return tokenizer.decode(token_ids)
+
+
+def generate_acrostic(tokenizer, model, head):
+    """
+    Randomly generate an acrostic poem
+    :param tokenizer: the tokenizer
+    :param model: the model used to generate poems
+    :param head: the head characters of the acrostic
+    :return: a string holding one poem
+    """
+    # Initialize token_ids from the empty string, which adds [CLS]
+    token_ids = tokenizer.encode('')
+    token_ids = token_ids[:-1]
+    # Punctuation; for simplicity only the comma and the full stop count
+    punctuations = [',', '。']
+    punctuation_ids = {tokenizer.token_to_id(token) for token in punctuations}
+    # List buffering the generated poem
+    poetry = []
+    # Generate one short clause for every character of the acrostic head
+    for ch in head:
+        # Record this character first
+        poetry.append(ch)
+        # Convert the head character to a token id
+        token_id = tokenizer.token_to_id(ch)
+        # Append it to the list
+        token_ids.append(token_id)
+        # Start generating a short clause
+        while True:
+            # Predict; keep only the distribution of the last token of the
+            # first (and only) sample, excluding [PAD], [UNK] and [CLS]
+            output = model(np.array([token_ids, ], dtype=np.int32))
+            _probas = output.numpy()[0, -1, 3:]
+            del output
+            # Sort all tokens by probability, descending, keeping the top 100
+            p_args = _probas.argsort()[::-1][:100]
+            # Probabilities in sorted order
+            p = _probas[p_args]
+            # Renormalize the probabilities
+            p = p / sum(p)
+            # Then sample one token according to the predicted probabilities
+            target_index = np.random.choice(len(p), p=p)
+            target = p_args[target_index] + 3
+            # Save it
+            token_ids.append(target)
+            # Only store into poetry when it is not a special token
+            if target > 3:
+                poetry.append(tokenizer.id_to_token(target))
+            if target in punctuation_ids:
+                break
+    return ''.join(poetry)

diff --git a/pyproject.toml b/pyproject.toml
index 326101a..61feb7e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,8 +8,11 @@ dependencies = [
     "datasets>=4.3.0",
     "matplotlib>=3.10.7",
     "numpy>=2.3.4",
-    "torch>=2.9.0",
-    "torchvision>=0.24.0",
+    "pillow>=12.0.0",
+    "pytorch-ignite>=0.5.3",
+    "torch>=2.9.0",
+    "torchinfo>=1.8.0",
+    "torchvision>=0.24.0",
 ]
 
 [tool.uv.sources]

diff --git a/uv.lock b/uv.lock
index dfe9dbb..0db21b0 100644
--- a/uv.lock
+++ b/uv.lock
@@ -393,8 +393,11 @@ dependencies = [
     { name = "datasets" },
     { name = "matplotlib" },
     { name = "numpy" },
+    { name = "pillow" },
+    { name = "pytorch-ignite" },
     { name = "torch", version = "2.9.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
     { name = "torch", version = "2.9.0+cu126", source = { registry = "https://download.pytorch.org/whl/cu126" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+    { name = "torchinfo" },
     { name = "torchvision", version = "0.24.0", source = { registry = "https://download.pytorch.org/whl/cu126" }, marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" },
     { name = "torchvision", version = "0.24.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
     { name = "torchvision", version = "0.24.0+cu126", source = { registry = "https://download.pytorch.org/whl/cu126" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or sys_platform == 'win32'" },
@@ -405,8 +408,11 @@ requires-dist = [
     { name = "datasets", specifier = ">=4.3.0" },
     { name = "matplotlib", specifier = ">=3.10.7" },
    { name = "numpy", specifier = ">=2.3.4" },
+    { name = "pillow", specifier = ">=12.0.0" },
+    { name = "pytorch-ignite", specifier = ">=0.5.3" },
     { name = "torch", marker = "sys_platform != 'linux' and sys_platform != 'win32'", specifier = ">=2.9.0" },
     { name = "torch", marker = "sys_platform == 'linux' or sys_platform == 'win32'", specifier = ">=2.9.0", index = "https://download.pytorch.org/whl/cu126" },
+    { name = "torchinfo", specifier = ">=1.8.0" },
     { name = "torchvision", marker = "sys_platform != 'linux' and sys_platform != 'win32'", specifier = ">=0.24.0" },
     { name = "torchvision", marker = "sys_platform == 'linux' or sys_platform == 'win32'", specifier = ">=0.24.0", index = "https://download.pytorch.org/whl/cu126" },
 ]
@@ -1639,6 +1645,20 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" },
 ]
 
+[[package]]
+name = "pytorch-ignite"
+version = "0.5.3"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "packaging" },
+    { name = "torch", version = "2.9.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
+    { name = "torch", version = "2.9.0+cu126", source = { registry = "https://download.pytorch.org/whl/cu126" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/5a/e5/7fe880b24de30b4eadc8d997ea8d3c4a8f507b1a34dcdced08d88f665ee3/pytorch_ignite-0.5.3.tar.gz", hash = "sha256:75c645f02fea66cc80c1998ade3f8402e0e6b6d73f3f4ad727c171f6e93874f4", size = 7506607, upload-time = "2025-10-16T00:42:05.142Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/d2/ea/f6d5ee7433a5a1c1e4746e2a4e9a222eab545fdbe04b66754ffdab479ee8/pytorch_ignite-0.5.3-py3-none-any.whl", hash = "sha256:4ced7539c690a3b6f3116da7878389954dff787c33669f83b38221f3746bc63e", size = 343802, upload-time = "2025-10-16T00:41:55.738Z" },
+]
+
 [[package]]
 name = "pytz"
 version = "2025.2"
@@ -1848,6 +1868,15 @@ wheels = [
     { url = "https://download.pytorch.org/whl/cu126/torch-2.9.0%2Bcu126-cp314-cp314t-win_amd64.whl", hash = "sha256:d8fdfc45ba30cb5c23971b35ae72c6fe246596022b574bd37dd0c775958f70b1" },
 ]
 
+[[package]]
+name = "torchinfo"
+version = "1.8.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/53/d9/2b811d1c0812e9ef23e6cf2dbe022becbe6c5ab065e33fd80ee05c0cd996/torchinfo-1.8.0.tar.gz", hash = "sha256:72e94b0e9a3e64dc583a8e5b7940b8938a1ac0f033f795457f27e6f4e7afa2e9", size = 25880, upload-time = "2023-05-14T19:23:26.377Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/72/25/973bd6128381951b23cdcd8a9870c6dcfc5606cb864df8eabd82e529f9c1/torchinfo-1.8.0-py3-none-any.whl", hash = "sha256:2e911c2918603f945c26ff21a3a838d12709223dc4ccf243407bce8b6e897b46", size = 23377, upload-time = "2023-05-14T19:23:24.141Z" },
+]
+
 [[package]]
 name = "torchvision"
 version = "0.24.0"