from pathlib import Path import typing import pickle from collections import Counter import numpy class Tokenizer: """分词器""" token_dict: dict[str, int] """词->编号的映射""" token_dict_rev: dict[int, str] """编号->词的映射""" vocab_size: int """词汇表大小""" def __init__(self, token_dict: dict[str, int]): self.token_dict = token_dict self.token_dict_rev = {value: key for key, value in self.token_dict.items()} self.vocab_size = len(self.token_dict) def id_to_token(self, token_id: int) -> str: """ 给定一个编号,查找词汇表中对应的词。 :param token_id: 带查找词的编号 :return: 编号对应的词 """ return self.token_dict_rev[token_id] def token_to_id(self, token: str): """ 给定一个词,查找它在词汇表中的编号。 未找到则返回低频词[UNK]的编号。 :param token: 带查找编号的词 :return: 词的编号 """ return self.token_dict.get(token, self.token_dict['[UNK]']) def encode(self, tokens: str) -> list[int]: """ 给定一个字符串s,在头尾分别加上标记开始和结束的特殊字符,并将它转成对应的编号序列 :param tokens: 待编码字符串 :return: 编号序列 """ # 加上开始标记 token_ids: list[int] = [self.token_to_id('[CLS]'), ] # 加入字符串编号序列 for token in tokens: token_ids.append(self.token_to_id(token)) # 加上结束标记 token_ids.append(self.token_to_id('[SEP]')) return token_ids def decode(self, token_ids: typing.Iterable[int]) -> str: """ 给定一个编号序列,将它解码成字符串 :param token_ids: 待解码的编号序列 :return: 解码出的字符串 """ # 起止标记字符特殊处理 spec_tokens = {'[CLS]', '[SEP]'} # 保存解码出的字符的list tokens: list[str] = [] for token_id in token_ids: token = self.id_to_token(token_id) if token in spec_tokens: continue tokens.append(token) # 拼接字符串 return ''.join(tokens) class PoetryDataset: """古诗词数据集加载器""" BAD_WORDS: typing.ClassVar[list[str]] = ['(', ')', '(', ')', '__', '《', '》', '【', '】', '[', ']'] """禁用词,包含如下字符的唐诗将被忽略""" MAX_SEG_LEN: typing.ClassVar[int] = 64 """句子最大长度""" MIN_WORD_FREQ: typing.ClassVar[int] = 8 """最小词频""" tokenizer: Tokenizer """分词器""" poetry: list[str] """古诗词数据集,每一项是一首诗""" def __init__(self, force_reclean: bool = False): # 加载古诗,然后统计词频构建分词器 self.poetry = self.load_poetry(force_reclean) self.tokenizer = self.build_tokenizer(self.poetry) def load_poetry(self, force_reclean: bool = False) -> list[str]: """加载古诗词数据集""" if force_reclean or (not self.get_clean_dataset_path().is_file()): return self.load_poetry_from_dirty() else: return self.load_poetry_from_clean() def load_poetry_from_clean(self) -> list[str]: """直接读取清洗后的数据""" with open(self.get_clean_dataset_path(), 'rb') as f: return pickle.load(f) def load_poetry_from_dirty(self) -> list[str]: """从原始数据加载,清洗数据后,写入缓存文件,并返回清洗后的数据""" with open(self.get_dirty_dataset_path(), 'r', encoding='utf-8') as f: lines = f.readlines() # 将冒号统一成相同格式 lines = [line.replace(':', ':') for line in lines] # 数据集列表 poetry: list[str] = [] # 逐行处理读取到的数据 for line in lines: # 删除空白字符 line = line.strip() # 有且只能有一个冒号用来分割标题 if line.count(':') != 1: continue # 获取后半部分(删除标题) _, last_part = line.split(':') # 长度不能超过最大长度 if len(last_part) > PoetryDataset.MAX_SEG_LEN - 2: continue # 不能包含禁止词 for bad_word in PoetryDataset.BAD_WORDS: if bad_word in last_part: break else: # 如果循环正常结束,就表明没有bad words,推入数据列表 poetry.append(last_part) # 数据清理完毕 # 写入干净数据 with open(self.get_clean_dataset_path(), 'wb') as f: pickle.dump(poetry, f) # 返回结果 return poetry def get_clean_dataset_path(self) -> Path: return Path(__file__).resolve().parent.parent / 'datasets' / 'poetry.pickle' def get_dirty_dataset_path(self) -> Path: return Path(__file__).resolve().parent.parent / 'datasets' / 'poetry.txt' def build_tokenizer(self, poetry: list[str]) -> Tokenizer: """统计词频,并构建分词器""" # 统计词频 counter: Counter[str] = Counter() for line in poetry: counter.update(line) # 过滤掉低频词 tokens = ((token, count) for token, count in counter.items() if count >= PoetryDataset.MIN_WORD_FREQ) # 按词频排序 tokens = sorted(tokens, key=lambda x: -x[1]) # 去掉词频,只保留词列表 tokens = list(token for token, _ in tokens) # 将特殊词和数据集中的词拼接起来 tokens = ['[PAD]', '[UNK]', '[CLS]', '[SEP]'] + tokens # 创建词典 token->id映射关系 token_id_dict = dict(zip(tokens, range(len(tokens)))) # 使用新词典重新建立分词器 tokenizer = Tokenizer(token_id_dict) # 直接返回,此处无需混洗数据 return tokenizer