172 lines
6.0 KiB
Python
172 lines
6.0 KiB
Python
from pathlib import Path
|
||
import typing
|
||
import pickle
|
||
from collections import Counter
|
||
import numpy
|
||
|
||
class Tokenizer:
|
||
"""分词器"""
|
||
|
||
token_dict: dict[str, int]
|
||
"""词->编号的映射"""
|
||
token_dict_rev: dict[int, str]
|
||
"""编号->词的映射"""
|
||
vocab_size: int
|
||
"""词汇表大小"""
|
||
|
||
def __init__(self, token_dict: dict[str, int]):
|
||
self.token_dict = token_dict
|
||
self.token_dict_rev = {value: key for key, value in self.token_dict.items()}
|
||
self.vocab_size = len(self.token_dict)
|
||
|
||
def id_to_token(self, token_id: int) -> str:
|
||
"""
|
||
给定一个编号,查找词汇表中对应的词。
|
||
|
||
:param token_id: 带查找词的编号
|
||
:return: 编号对应的词
|
||
"""
|
||
return self.token_dict_rev[token_id]
|
||
|
||
def token_to_id(self, token: str):
|
||
"""
|
||
给定一个词,查找它在词汇表中的编号。
|
||
未找到则返回低频词[UNK]的编号。
|
||
|
||
:param token: 带查找编号的词
|
||
:return: 词的编号
|
||
"""
|
||
return self.token_dict.get(token, self.token_dict['[UNK]'])
|
||
|
||
def encode(self, tokens: str) -> list[int]:
|
||
"""
|
||
给定一个字符串s,在头尾分别加上标记开始和结束的特殊字符,并将它转成对应的编号序列
|
||
|
||
:param tokens: 待编码字符串
|
||
:return: 编号序列
|
||
"""
|
||
# 加上开始标记
|
||
token_ids: list[int] = [self.token_to_id('[CLS]'), ]
|
||
# 加入字符串编号序列
|
||
for token in tokens:
|
||
token_ids.append(self.token_to_id(token))
|
||
# 加上结束标记
|
||
token_ids.append(self.token_to_id('[SEP]'))
|
||
return token_ids
|
||
|
||
def decode(self, token_ids: typing.Iterable[int]) -> str:
|
||
"""
|
||
给定一个编号序列,将它解码成字符串
|
||
|
||
:param token_ids: 待解码的编号序列
|
||
:return: 解码出的字符串
|
||
"""
|
||
# 起止标记字符特殊处理
|
||
spec_tokens = {'[CLS]', '[SEP]'}
|
||
# 保存解码出的字符的list
|
||
tokens: list[str] = []
|
||
for token_id in token_ids:
|
||
token = self.id_to_token(token_id)
|
||
if token in spec_tokens:
|
||
continue
|
||
tokens.append(token)
|
||
# 拼接字符串
|
||
return ''.join(tokens)
|
||
|
||
|
||
class PoetryDataset:
|
||
"""古诗词数据集加载器"""
|
||
|
||
BAD_WORDS: typing.ClassVar[list[str]] = ['(', ')', '(', ')', '__', '《', '》', '【', '】', '[', ']']
|
||
"""禁用词,包含如下字符的唐诗将被忽略"""
|
||
MAX_SEG_LEN: typing.ClassVar[int] = 64
|
||
"""句子最大长度"""
|
||
MIN_WORD_FREQ: typing.ClassVar[int] = 8
|
||
"""最小词频"""
|
||
|
||
tokenizer: Tokenizer
|
||
"""分词器"""
|
||
poetry: list[str]
|
||
"""古诗词数据集,每一项是一首诗"""
|
||
|
||
def __init__(self, force_reclean: bool = False):
|
||
# 加载古诗,然后统计词频构建分词器
|
||
self.poetry = self.load_poetry(force_reclean)
|
||
self.tokenizer = self.build_tokenizer(self.poetry)
|
||
|
||
def load_poetry(self, force_reclean: bool = False) -> list[str]:
|
||
"""加载古诗词数据集"""
|
||
if force_reclean or (not self.get_clean_dataset_path().is_file()):
|
||
return self.load_poetry_from_dirty()
|
||
else:
|
||
return self.load_poetry_from_clean()
|
||
|
||
def load_poetry_from_clean(self) -> list[str]:
|
||
"""直接读取清洗后的数据"""
|
||
with open(self.get_clean_dataset_path(), 'rb') as f:
|
||
return pickle.load(f)
|
||
|
||
def load_poetry_from_dirty(self) -> list[str]:
|
||
"""从原始数据加载,清洗数据后,写入缓存文件,并返回清洗后的数据"""
|
||
with open(self.get_dirty_dataset_path(), 'r', encoding='utf-8') as f:
|
||
lines = f.readlines()
|
||
# 将冒号统一成相同格式
|
||
lines = [line.replace(':', ':') for line in lines]
|
||
|
||
# 数据集列表
|
||
poetry: list[str] = []
|
||
# 逐行处理读取到的数据
|
||
for line in lines:
|
||
# 删除空白字符
|
||
line = line.strip()
|
||
# 有且只能有一个冒号用来分割标题
|
||
if line.count(':') != 1: continue
|
||
# 获取后半部分(删除标题)
|
||
_, last_part = line.split(':')
|
||
# 长度不能超过最大长度
|
||
if len(last_part) > PoetryDataset.MAX_SEG_LEN - 2:
|
||
continue
|
||
# 不能包含禁止词
|
||
for bad_word in PoetryDataset.BAD_WORDS:
|
||
if bad_word in last_part:
|
||
break
|
||
else:
|
||
# 如果循环正常结束,就表明没有bad words,推入数据列表
|
||
poetry.append(last_part)
|
||
|
||
# 数据清理完毕
|
||
# 写入干净数据
|
||
with open(self.get_clean_dataset_path(), 'wb') as f:
|
||
pickle.dump(poetry, f)
|
||
# 返回结果
|
||
return poetry
|
||
|
||
def get_clean_dataset_path(self) -> Path:
|
||
return Path(__file__).resolve().parent.parent / 'datasets' / 'poetry.pickle'
|
||
|
||
def get_dirty_dataset_path(self) -> Path:
|
||
return Path(__file__).resolve().parent.parent / 'datasets' / 'poetry.txt'
|
||
|
||
def build_tokenizer(self, poetry: list[str]) -> Tokenizer:
|
||
"""统计词频,并构建分词器"""
|
||
# 统计词频
|
||
counter: Counter[str] = Counter()
|
||
for line in poetry:
|
||
counter.update(line)
|
||
# 过滤掉低频词
|
||
tokens = ((token, count) for token, count in counter.items() if count >= PoetryDataset.MIN_WORD_FREQ)
|
||
# 按词频排序
|
||
tokens = sorted(tokens, key=lambda x: -x[1])
|
||
# 去掉词频,只保留词列表
|
||
tokens = list(token for token, _ in tokens)
|
||
|
||
# 将特殊词和数据集中的词拼接起来
|
||
tokens = ['[PAD]', '[UNK]', '[CLS]', '[SEP]'] + tokens
|
||
# 创建词典 token->id映射关系
|
||
token_id_dict = dict(zip(tokens, range(len(tokens))))
|
||
# 使用新词典重新建立分词器
|
||
tokenizer = Tokenizer(token_id_dict)
|
||
# 直接返回,此处无需混洗数据
|
||
return tokenizer
|
||
|