from pathlib import Path
import typing
import pickle
from collections import Counter

import numpy
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F

import settings


TOKEN_PAD: str = '[PAD]'
"""Special token for the poetry dataset: the character inserted when padding sequences for the RNN."""

TOKEN_UNK: str = '[UNK]'
"""Special token for the poetry dataset: stands in for rare characters whose frequency is below the threshold."""

TOKEN_CLS: str = '[CLS]'
"""Special token for the poetry dataset: marks the start of a poem."""

TOKEN_SEP: str = '[SEP]'
"""Special token for the poetry dataset: marks the end of a poem."""


class Tokenizer:
    """Tokenizer."""

    token_dict: dict[str, int]
    """Mapping from token to id."""

    token_dict_rev: dict[int, str]
    """Mapping from id to token."""

    vocab_size: int
    """Size of the vocabulary."""

    def __init__(self, token_dict: dict[str, int]):
        self.token_dict = token_dict
        self.token_dict_rev = {value: key for key, value in self.token_dict.items()}
        self.vocab_size = len(self.token_dict)

    def id_to_token(self, token_id: int) -> str:
        """
        Look up the token in the vocabulary for a given id.

        :param token_id: id of the token to look up
        :return: the token with that id
        """
        return self.token_dict_rev[token_id]

    def token_to_id(self, token: str) -> int:
        """
        Look up the vocabulary id of a given token.

        Unknown tokens map to the id of the rare-word token [UNK].

        :param token: token whose id to look up
        :return: the token's id
        """
        return self.token_dict.get(token, self.token_dict[TOKEN_UNK])

    def encode(self, tokens: str) -> list[int]:
        """
        Given a string, add the special start and end markers at its head and
        tail, and convert it into the corresponding id sequence.

        :param tokens: the string to encode
        :return: the id sequence
        """
        # Add the start marker
        token_ids: list[int] = [self.token_to_id(TOKEN_CLS)]
        # Append the ids of the string's characters
        for token in tokens:
            token_ids.append(self.token_to_id(token))
        # Add the end marker
        token_ids.append(self.token_to_id(TOKEN_SEP))
        return token_ids

    def decode(self, token_ids: typing.Iterable[int]) -> str:
        """
        Decode an id sequence back into a string.

        :param token_ids: the id sequence to decode
        :return: the decoded string
        """
        # The start/end markers get special treatment: they are skipped
        spec_tokens = {TOKEN_CLS, TOKEN_SEP}
        # Collect the decoded characters
        tokens: list[str] = []
        for token_id in token_ids:
            token = self.id_to_token(token_id)
            if token in spec_tokens:
                continue
            tokens.append(token)
        # Join them into a string
        return ''.join(tokens)
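
# A minimal usage sketch of the Tokenizer (illustrative; the toy vocabulary
# below is an assumption, not part of the real dataset):
#
#     vocab = {TOKEN_PAD: 0, TOKEN_UNK: 1, TOKEN_CLS: 2, TOKEN_SEP: 3, '春': 4}
#     tokenizer = Tokenizer(vocab)
#     tokenizer.encode('春')       # -> [2, 4, 3]
#     tokenizer.decode([2, 4, 3])  # -> '春' (the markers are stripped)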


class PoetryPreprocessor:
    """
    Preprocessor for the classical-poetry dataset.

    This class is responsible for reading, cleaning, and persisting the
    poetry data.
    """

    tokenizer: Tokenizer
    """The tokenizer."""

    poetry: list[str]
    """The poetry dataset; each item is one poem."""

    def __init__(self, force_reclean: bool = False):
        # Load the poetry dataset
        if force_reclean or (not settings.CLEAN_DATASET_PATH.is_file()):
            self.poetry, self.tokenizer = self.__load_from_dirty()
        else:
            self.poetry, self.tokenizer = self.__load_from_clean()

    def __load_from_clean(self) -> tuple[list[str], Tokenizer]:
        """Read the already-cleaned data directly from the cache file."""
        with open(settings.CLEAN_DATASET_PATH, 'rb') as f:
            return pickle.load(f)

    def __load_from_dirty(self) -> tuple[list[str], Tokenizer]:
        """Load the raw data, clean it, write the result to the cache file, and return the cleaned data."""
        # Load the dirty poetry data
        with open(settings.DIRTY_DATASET_PATH, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        # Clean the poetry data
        poetry = self.__wash_dirty_poetry(lines)
        # Build the tokenizer
        tokenizer = self.__build_tokenizer(poetry)

        # Cleaning is done; write the clean data to the cache
        with open(settings.CLEAN_DATASET_PATH, 'wb') as f:
            pickle.dump((poetry, tokenizer), f)

        return poetry, tokenizer

    def __wash_dirty_poetry(self, poetry: list[str]) -> list[str]:
        """
        Clean the given poetry data.

        :param poetry: the poetry data to clean, one poem per line.
            Each line starts with a title, then a colon (full- or half-width),
            then the body of the poem.
        :return: the cleaned poems.
        """
        # Blocklist: poems containing any of these characters are skipped
        BAD_WORDS = ['(', ')', '(', ')', '__', '《', '》', '【', '】', '[', ']']
        # The cleaned dataset
        clean_poetry: list[str] = []

        # Process the input line by line
        for line in poetry:
            # Strip surrounding whitespace
            line = line.strip()
            # Replace the full-width colon with the half-width one
            line = line.replace(':', ':')
            # There must be exactly one colon, separating the title from the body
            if line.count(':') != 1:
                continue
            # Keep only the part after the colon (drop the title)
            _, last_part = line.split(':')
            # The body must fit the maximum length (minus 2 because the special
            # start/end tokens are added later)
            if len(last_part) > settings.POETRY_MAX_LEN - 2:
                continue
            # The body must not contain any blocked word
            for bad_word in BAD_WORDS:
                if bad_word in last_part:
                    break
            else:
                # The loop finished normally, so there are no bad words:
                # push the poem onto the dataset
                clean_poetry.append(last_part)

        # Return the cleaned result
        return clean_poetry
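
    # For example (with an assumed raw line, purely illustrative): from
    # '静夜思:床前明月光,疑是地上霜。' only the text after the colon is kept,
    # while a line whose body contains '(' or '《' is dropped entirely.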

    def __build_tokenizer(self, poetry: list[str]) -> Tokenizer:
        """
        Count token frequencies over the given poems and build a tokenizer.

        :param poetry: the cleaned poems, one poem per item.
        :return: the finished tokenizer.
        """
        # Count token frequencies
        counter: Counter[str] = Counter()
        for line in poetry:
            counter.update(line)
        # Drop low-frequency tokens
        tokens = ((token, count) for token, count in counter.items() if count >= settings.POETRY_MIN_WORD_FREQ)
        # Sort by descending frequency
        tokens = sorted(tokens, key=lambda x: -x[1])
        # Drop the counts, keeping only the token list
        tokens = list(token for token, _ in tokens)

        # Put the special tokens in front of the dataset's tokens
        tokens = [TOKEN_PAD, TOKEN_UNK, TOKEN_CLS, TOKEN_SEP] + tokens
        # Build the token -> id mapping
        token_id_dict = dict(zip(tokens, range(len(tokens))))
        # Build a fresh tokenizer from the new mapping
        tokenizer = Tokenizer(token_id_dict)
        # Return it directly; no shuffling is needed here
        return tokenizer
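
    # Resulting id layout (a sketch; the ids past the special tokens depend on
    # the corpus): [PAD]=0, [UNK]=1, [CLS]=2, [SEP]=3, then the corpus tokens
    # in descending frequency order starting at id 4.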


class PoetryDataset(Dataset):
    """A poetry Dataset compatible with PyTorch."""

    preprocessor: PoetryPreprocessor

    def __init__(self, poetry: PoetryPreprocessor):
        self.preprocessor = poetry

    def __getitem__(self, index):
        # Fetch the poem and encode it
        poetry = self.preprocessor.poetry[index]
        encoded_poetry = self.preprocessor.tokenizer.encode(poetry)
        # Return the encoded poem as-is; padding and building the input/output
        # pair are left to the DataLoader.
        return encoded_poetry

    def __len__(self):
        return len(self.preprocessor.poetry)
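
# Illustrative access pattern: dataset[0] returns the variable-length id
# sequence of the first poem, with the [CLS]/[SEP] ids at both ends; padding
# happens later, in the DataLoader's collate function.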


class PoetryDataLoader:
    """A poetry data loader compatible with PyTorch."""

    preprocessor: PoetryPreprocessor
    dataset: PoetryDataset
    loader: DataLoader

    def __init__(self, batch_size: int, force_reclean: bool = False):
        self.preprocessor = PoetryPreprocessor(force_reclean)
        self.dataset = PoetryDataset(self.preprocessor)
        self.loader = DataLoader(dataset=self.dataset,
                                 batch_size=batch_size,
                                 # Pad the poems before returning each batch
                                 collate_fn=self.__collate_fn,
                                 # Shuffle the data to help prevent overfitting
                                 shuffle=True)

    def get_vocab_size(self) -> int:
        """Convenience accessor for vocab_size, avoiding a chain of attribute lookups."""
        return self.preprocessor.tokenizer.vocab_size

    def get_tokenizer(self) -> Tokenizer:
        """Convenience accessor for the Tokenizer, avoiding a chain of attribute lookups."""
        return self.preprocessor.tokenizer

    def __collate_fn(self, batch: list[list[int]]) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Collate function for the DataLoader.

        Pads the incoming poem samples to a common length and packs them into
        an (input, target) pair of tensors.
        """
        # The padded length is that of the longest sample in the batch
        length = max(map(len, batch))
        # The id used for padding
        padding = self.preprocessor.tokenizer.token_to_id(TOKEN_PAD)
        # Pad every sample
        padded_batch: list[list[int]] = []
        for entry in batch:
            padding_length = length - len(entry)
            if padding_length > 0:
                # Too short: pad up to the common length
                padded_batch.append(entry + [padding] * padding_length)
            else:
                # Long enough: truncate (a no-op here, since length is the max)
                padded_batch.append(entry[:length])
        numpy_batch = numpy.array(padded_batch)

        # Build the inputs and targets.
        # The input drops the last character, the target drops the first, so
        # the RNN learns to predict the next character at every position.
        # Targets stay as class indices (suitable for cross-entropy loss);
        # no one-hot encoding is done here.
        inputs = torch.tensor(numpy_batch[:, :-1], dtype=torch.long)
        targets = torch.tensor(numpy_batch[:, 1:], dtype=torch.long)
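
        # For example (illustrative): if a padded row is the id sequence
        # [CLS, a, b, SEP], then the input row is [CLS, a, b] and the target
        # row is [a, b, SEP], pairing each position with its next character.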

        # Return the pair
        return inputs, targets
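

# A minimal smoke test, assuming the dataset paths configured in settings
# point at real files (illustrative; the batch size of 32 is an arbitrary
# choice, not from the original module):
if __name__ == '__main__':
    loader = PoetryDataLoader(batch_size=32)
    print('vocab size:', loader.get_vocab_size())
    inputs, targets = next(iter(loader.loader))
    print('inputs shape:', inputs.shape)    # (batch, max_len - 1)
    print('targets shape:', targets.shape)  # (batch, max_len - 1)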