# ai-school/exp3/modified/dataset.py
from pathlib import Path
import typing
import pickle
from collections import Counter
import numpy


class Tokenizer:
    """Tokenizer: maps characters to integer ids and back."""

    token_dict: dict[str, int]
    """token -> id mapping"""
    token_dict_rev: dict[int, str]
    """id -> token mapping"""
    vocab_size: int
    """vocabulary size"""

    def __init__(self, token_dict: dict[str, int]):
        self.token_dict = token_dict
        self.token_dict_rev = {value: key for key, value in self.token_dict.items()}
        self.vocab_size = len(self.token_dict)

    def id_to_token(self, token_id: int) -> str:
        """
        Look up the token for a given id in the vocabulary.
        :param token_id: id of the token to look up
        :return: the token for that id
        """
        return self.token_dict_rev[token_id]

    def token_to_id(self, token: str) -> int:
        """
        Look up the id of a given token in the vocabulary.
        Unknown tokens fall back to the id of the rare-word marker [UNK].
        :param token: token to look up
        :return: id of the token
        """
        return self.token_dict.get(token, self.token_dict['[UNK]'])

    def encode(self, tokens: str) -> list[int]:
        """
        Wrap a string with the special start/end markers and convert it to a sequence of ids.
        :param tokens: string to encode
        :return: sequence of token ids
        """
        # start marker
        token_ids: list[int] = [self.token_to_id('[CLS]')]
        # ids of the characters in the string
        for token in tokens:
            token_ids.append(self.token_to_id(token))
        # end marker
        token_ids.append(self.token_to_id('[SEP]'))
        return token_ids

    def decode(self, token_ids: typing.Iterable[int]) -> str:
        """
        Decode a sequence of ids back into a string.
        :param token_ids: sequence of ids to decode
        :return: decoded string
        """
        # start/end markers are skipped during decoding
        spec_tokens = {'[CLS]', '[SEP]'}
        # decoded characters
        tokens: list[str] = []
        for token_id in token_ids:
            token = self.id_to_token(token_id)
            if token in spec_tokens:
                continue
            tokens.append(token)
        # join the characters into a single string
        return ''.join(tokens)
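
# Usage sketch (illustrative only): the toy vocabulary below is an assumption,
# not part of the project; it merely shows how Tokenizer behaves.
#   tok = Tokenizer({'[PAD]': 0, '[UNK]': 1, '[CLS]': 2, '[SEP]': 3, '山': 4, '水': 5})
#   tok.encode('山水')        # -> [2, 4, 5, 3]
#   tok.decode([2, 4, 5, 3])  # -> '山水'
#   tok.token_to_id('海')     # -> 1, the [UNK] id, since '海' is not in the vocabulary
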
class PoetryDataset:
    """Loader for the classical-poetry dataset."""

    BAD_WORDS: typing.ClassVar[list[str]] = ['(', ')', '(', ')', '__', '《', '》', '【', '】', '[', ']']
    """Disallowed tokens: poems containing any of these characters are ignored."""
    MAX_SEG_LEN: typing.ClassVar[int] = 64
    """Maximum sequence length"""
    MIN_WORD_FREQ: typing.ClassVar[int] = 8
    """Minimum word frequency"""
    tokenizer: Tokenizer
    """Tokenizer built from the corpus"""
    poetry: list[str]
    """Cleaned poem corpus"""

    def __init__(self, force_reclean: bool = False):
        # Load the poems, then build the tokenizer from their character frequencies.
        self.poetry = self.load_poetry(force_reclean)
        self.tokenizer = self.build_tokenizer(self.poetry)

    def load_poetry(self, force_reclean: bool = False) -> list[str]:
        """Load the poem corpus."""
        if force_reclean or (not self.get_clean_dataset_path().is_file()):
            return self.load_poetry_from_dirty()
        else:
            return self.load_poetry_from_clean()

    def load_poetry_from_clean(self) -> list[str]:
        """Read the already-cleaned data directly."""
        with open(self.get_clean_dataset_path(), 'rb') as f:
            return pickle.load(f)

    def load_poetry_from_dirty(self) -> list[str]:
        """Load the raw data, clean it, write the cache file, and return the cleaned data."""
        with open(self.get_dirty_dataset_path(), 'r', encoding='utf-8') as f:
            lines = f.readlines()
        # Normalize full-width colons to ASCII colons
        lines = [line.replace(':', ':') for line in lines]
        # cleaned corpus
        poetry: list[str] = []
        # process the lines one by one
        for line in lines:
            # strip surrounding whitespace
            line = line.strip()
            # exactly one colon must separate the title from the body
            if line.count(':') != 1:
                continue
            # keep only the body (the part after the colon)
            _, last_part = line.split(':')
            # skip bodies longer than the maximum length (minus the two special markers)
            if len(last_part) > PoetryDataset.MAX_SEG_LEN - 2:
                continue
            # skip bodies containing disallowed tokens
            for bad_word in PoetryDataset.BAD_WORDS:
                if bad_word in last_part:
                    break
            else:
                # the loop finished without a break, so no bad words: keep the poem
                poetry.append(last_part)
        # cleaning finished: cache the cleaned data
        with open(self.get_clean_dataset_path(), 'wb') as f:
            pickle.dump(poetry, f)
        return poetry

    def get_clean_dataset_path(self) -> Path:
        return Path(__file__).resolve().parent.parent / 'datasets' / 'poetry.pickle'

    def get_dirty_dataset_path(self) -> Path:
        return Path(__file__).resolve().parent.parent / 'datasets' / 'poetry.txt'

    def build_tokenizer(self, poetry: list[str]) -> Tokenizer:
        """Count character frequencies and build the tokenizer."""
        # count character frequencies
        counter: Counter[str] = Counter()
        for line in poetry:
            counter.update(line)
        # drop rare characters
        tokens = ((token, count) for token, count in counter.items() if count >= PoetryDataset.MIN_WORD_FREQ)
        # sort by frequency, most frequent first
        tokens = sorted(tokens, key=lambda x: -x[1])
        # drop the counts, keep only the tokens
        tokens = [token for token, _ in tokens]
        # prepend the special tokens to the corpus vocabulary
        tokens = ['[PAD]', '[UNK]', '[CLS]', '[SEP]'] + tokens
        # build the token -> id mapping
        token_id_dict = dict(zip(tokens, range(len(tokens))))
        # build a tokenizer from the new vocabulary
        tokenizer = Tokenizer(token_id_dict)
        # no shuffling is needed here, return directly
        return tokenizer
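
# Usage sketch (illustrative only; assumes ../datasets/poetry.txt exists relative to this file,
# as implied by get_dirty_dataset_path):
#   dataset = PoetryDataset()
#   print(dataset.tokenizer.vocab_size)                 # vocabulary size after frequency filtering
#   print(dataset.tokenizer.encode(dataset.poetry[0]))  # id sequence for the first cleaned poem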