from pathlib import Path import typing import pickle from collections import Counter import numpy import torch from torch.utils.data import DataLoader, Dataset import torch.nn.functional as F import settings TOKEN_PAD: str = '[PAD]' """使用古诗词数据时的特殊字符,RNN填充时使用的填充字符。""" TOKEN_UNK: str = '[UNK]' """使用古诗词数据时的特殊字符,词频不足的生僻字。""" TOKEN_CLS: str = '[CLS]' """使用古诗词数据时的特殊字符,标记古诗词开始。""" TOKEN_SEP: str = '[SEP]' """使用古诗词数据时的特殊字符,标记古诗词结束。""" class Tokenizer: """分词器""" token_dict: dict[str, int] """词->编号的映射""" token_dict_rev: dict[int, str] """编号->词的映射""" vocab_size: int """词汇表大小""" def __init__(self, token_dict: dict[str, int]): self.token_dict = token_dict self.token_dict_rev = {value: key for key, value in self.token_dict.items()} self.vocab_size = len(self.token_dict) def id_to_token(self, token_id: int) -> str: """ 给定一个编号,查找词汇表中对应的词。 :param token_id: 带查找词的编号 :return: 编号对应的词 """ return self.token_dict_rev[token_id] def token_to_id(self, token: str): """ 给定一个词,查找它在词汇表中的编号。 未找到则返回低频词[UNK]的编号。 :param token: 带查找编号的词 :return: 词的编号 """ return self.token_dict.get(token, self.token_dict['[UNK]']) def encode(self, tokens: str) -> list[int]: """ 给定一个字符串s,在头尾分别加上标记开始和结束的特殊字符,并将它转成对应的编号序列 :param tokens: 待编码字符串 :return: 编号序列 """ # 加上开始标记 token_ids: list[int] = [self.token_to_id(TOKEN_CLS), ] # 加入字符串编号序列 for token in tokens: token_ids.append(self.token_to_id(token)) # 加上结束标记 token_ids.append(self.token_to_id(TOKEN_SEP)) return token_ids def decode(self, token_ids: typing.Iterable[int]) -> str: """ 给定一个编号序列,将它解码成字符串 :param token_ids: 待解码的编号序列 :return: 解码出的字符串 """ # 起止标记字符特殊处理 spec_tokens = {TOKEN_CLS, TOKEN_SEP} # 保存解码出的字符的list tokens: list[str] = [] for token_id in token_ids: token = self.id_to_token(token_id) if token in spec_tokens: continue tokens.append(token) # 拼接字符串 return ''.join(tokens) class PoetryPreprocessor: """ 古诗词数据集的预处理器。 该类负责古诗词数据的读取,清洗和数据持久化。 """ tokenizer: Tokenizer """分词器""" poetry: list[str] """古诗词数据集,每一项是一首诗""" def __init__(self, force_reclean: bool=False): # 加载古诗词数据集 if force_reclean or (not settings.CLEAN_DATASET_PATH.is_file()): (self.poetry, self.tokenizer) = self.__load_from_dirty() else: (self.poetry, self.tokenizer) = self.__load_from_clean() def __load_from_clean(self) -> tuple[list[str], Tokenizer]: """直接读取清洗后的数据""" with open(settings.CLEAN_DATASET_PATH, 'rb') as f: return pickle.load(f) def __load_from_dirty(self) -> tuple[list[str], Tokenizer]: """从原始数据加载,清洗数据后,写入缓存文件,并返回清洗后的数据""" # 加载脏的古诗数据 with open(settings.DIRTY_DATASET_PATH, 'r', encoding='utf-8') as f: lines = f.readlines() # 清洗古诗数据 poetry = self.__wash_dirty_poetry(lines) # 构建分词器 tokenizer = self.__build_tokenizer(poetry) # 数据清理完毕 # 写入干净数据 with open(settings.CLEAN_DATASET_PATH, 'wb') as f: pickle.dump((poetry, tokenizer), f) # 返回结果 return poetry, tokenizer def __wash_dirty_poetry(self, poetry: list[str]) -> list[str]: """ 清洗给定的古诗数据。 :param poetry: 要清洗的古诗数据,每一行是一首古诗。 古诗开头是标题,然后是一个冒号(全角或半角),然后是古诗主体。 :return: 清洗完毕的古诗。 """ # 禁用词列表,包含如下字符的诗歌将被忽略 BAD_WORDS = ['(', ')', '(', ')', '__', '《', '》', '【', '】', '[', ']'] # 数据集列表 clean_poetry: list[str] = [] # 逐行处理读取到的数据 for line in poetry: # 删除空白字符 line = line.strip() # 将全角冒号替换为半角的 line = line.replace(':', ':') # 有且只能有一个冒号用来分割标题 if line.count(':') != 1: continue # 获取后半部分(删除标题) _, last_part = line.split(':') # 长度不能超过最大长度(减去2是因为古诗首尾要加特殊符号) if len(last_part) > settings.POETRY_MAX_LEN - 2: continue # 不能包含禁止词 for bad_word in BAD_WORDS: if bad_word in last_part: break else: # 如果循环正常结束,就表明没有bad words,推入数据列表 clean_poetry.append(last_part) # 返回清洗完毕的结果 return clean_poetry def __build_tokenizer(self, poetry: list[str]) -> Tokenizer: """ 根据给定古诗统计词频,并构建分词器。 :param poetry: 清洗完毕后的古诗,每一行是一句诗。 :return: 构建完毕的分词器。 """ # 统计词频 counter: Counter[str] = Counter() for line in poetry: counter.update(line) # 过滤掉低频词 tokens = ((token, count) for token, count in counter.items() if count >= settings.POETRY_MIN_WORD_FREQ) # 按词频排序 tokens = sorted(tokens, key=lambda x: -x[1]) # 去掉词频,只保留词列表 tokens = list(token for token, _ in tokens) # 将特殊词和数据集中的词拼接起来 tokens = ['[PAD]', '[UNK]', '[CLS]', '[SEP]'] + tokens # 创建词典 token->id映射关系 token_id_dict = dict(zip(tokens, range(len(tokens)))) # 使用新词典重新建立分词器 tokenizer = Tokenizer(token_id_dict) # 直接返回,此处无需混洗数据 return tokenizer class PoetryDataset(Dataset): """适配PyTorch的古诗词Dataset""" preprocessor: PoetryPreprocessor def __init__(self, poetry: PoetryPreprocessor): self.preprocessor = poetry def __getitem__(self, index): # 获取古诗词并编码 poetry = self.preprocessor.poetry[index] encoded_poetry = self.preprocessor.tokenizer.encode(poetry) # 直接返回编码后的古诗词数据,数据的padding和输入输出构成由DataLoader来做。 return encoded_poetry def __len__(self): return len(self.preprocessor.poetry) class PoetryDataLoader: """适配PyTorch的古诗词数据Loader""" preprocessor: PoetryPreprocessor dataset: PoetryDataset loader: DataLoader def __init__(self, batch_size: int, force_reclean: bool=False): self.preprocessor = PoetryPreprocessor(force_reclean) self.dataset = PoetryDataset(self.preprocessor) self.loader = DataLoader(dataset=self.dataset, batch_size=batch_size, # 对古诗词做padding后返回 collate_fn=lambda batch: self.__collect_fn(batch), # 混洗数据以防止过拟合 shuffle=True) def get_vocab_size(self) -> int: """一个便捷的获取vocab_size的函数,避免层层调用""" return self.preprocessor.tokenizer.vocab_size def get_tokenizer(self) -> Tokenizer: """一个便捷的获取Tokenizer的函数,避免层层调用""" return self.preprocessor.tokenizer def __collect_fn(self, batch: list[list[int]]) -> tuple[torch.Tensor, torch.Tensor]: """ 适用于DataLoader的样本收集器。 用于将上传的古诗词样本做padding后打包返回。 """ # 计算填充长度 length = max(map(len, batch)) # 获取填充数据 padding = self.preprocessor.tokenizer.token_to_id(TOKEN_PAD) # 开始填充 padded_batch: list[list[int]] = [] for entry in batch: padding_length = length - len(entry) if padding_length > 0: # 不足就进行填充 padded_batch.append(numpy.concatenate([entry, [padding] * padding_length])) else: # 超过就进行截断 padded_batch.append(entry[:length]) numpy_batch = numpy.array(padded_batch) # 生成输入和输出。 # 输入是去除最后一个字符的部分,输出是去除第一个字符的部分。 # 这么做是为了让RNN从输入推到输出(下一个字符)。 # 此外,输出要做onehot编码 input = torch.tensor(numpy_batch[:, :-1], dtype=torch.long) output = F.one_hot(torch.tensor(numpy_batch[:, 1:], dtype=torch.long), num_classes=self.preprocessor.tokenizer.vocab_size).float() # 返回结果 return input, output