update for change of exp2 and add exp3
exp3/modified/dataset.py (new file, 171 lines)
@@ -0,0 +1,171 @@
from pathlib import Path
import typing
import pickle
from collections import Counter

import numpy


class Tokenizer:
    """Tokenizer."""

    token_dict: dict[str, int]
    """Mapping from token to id."""
    token_dict_rev: dict[int, str]
    """Mapping from id to token."""
    vocab_size: int
    """Vocabulary size."""

    def __init__(self, token_dict: dict[str, int]):
        self.token_dict = token_dict
        self.token_dict_rev = {value: key for key, value in self.token_dict.items()}
        self.vocab_size = len(self.token_dict)

    def id_to_token(self, token_id: int) -> str:
        """
        Look up the token in the vocabulary for a given id.

        :param token_id: id of the token to look up
        :return: the token with that id
        """
        return self.token_dict_rev[token_id]

    def token_to_id(self, token: str) -> int:
        """
        Look up the id of a token in the vocabulary.

        If the token is not found, return the id of [UNK], the marker for
        unknown / low-frequency tokens.

        :param token: token whose id to look up
        :return: id of the token
        """
        return self.token_dict.get(token, self.token_dict['[UNK]'])

    def encode(self, tokens: str) -> list[int]:
        """
        Encode a string into an id sequence, adding the special start and end
        markers at the beginning and end.

        :param tokens: string to encode
        :return: id sequence
        """
        # add the start marker
        token_ids: list[int] = [self.token_to_id('[CLS]')]
        # append the ids of the string's characters
        for token in tokens:
            token_ids.append(self.token_to_id(token))
        # add the end marker
        token_ids.append(self.token_to_id('[SEP]'))
        return token_ids

    def decode(self, token_ids: typing.Iterable[int]) -> str:
        """
        Decode an id sequence back into a string.

        :param token_ids: id sequence to decode
        :return: decoded string
        """
        # the start/end markers are dropped rather than decoded
        spec_tokens = {'[CLS]', '[SEP]'}
        # list collecting the decoded tokens
        tokens: list[str] = []
        for token_id in token_ids:
            token = self.id_to_token(token_id)
            if token in spec_tokens:
                continue
            tokens.append(token)
        # join into a single string
        return ''.join(tokens)
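
# Illustrative usage sketch with a hypothetical minimal vocabulary (the real
# vocabulary is built by PoetryDataset below; '[UNK]', '[CLS]' and '[SEP]'
# must be present for encode() to work):
#
#   tok = Tokenizer({'[PAD]': 0, '[UNK]': 1, '[CLS]': 2, '[SEP]': 3, '春': 4, '风': 5})
#   tok.encode('春风')        # -> [2, 4, 5, 3]
#   tok.encode('春雨')        # -> [2, 4, 1, 3]   '雨' is unknown, so it maps to [UNK]
#   tok.decode([2, 4, 5, 3])  # -> '春风'          start/end markers are dropped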


class PoetryDataset:
    """Classical Chinese poetry dataset loader."""

    BAD_WORDS: typing.ClassVar[list[str]] = ['（', '）', '(', ')', '__', '《', '》', '【', '】', '[', ']']
    """Forbidden tokens; poems containing any of these characters are skipped."""
    MAX_SEG_LEN: typing.ClassVar[int] = 64
    """Maximum sentence length."""
    MIN_WORD_FREQ: typing.ClassVar[int] = 8
    """Minimum token frequency."""

    tokenizer: Tokenizer
    """Tokenizer."""
    poetry: list[str]
    """The poetry dataset."""

    def __init__(self, force_reclean: bool = False):
        # load the poems, then count token frequencies and build the tokenizer
        self.poetry = self.load_poetry(force_reclean)
        self.tokenizer = self.build_tokenizer(self.poetry)

    def load_poetry(self, force_reclean: bool = False) -> list[str]:
        """Load the poetry dataset."""
        if force_reclean or (not self.get_clean_dataset_path().is_file()):
            return self.load_poetry_from_dirty()
        else:
            return self.load_poetry_from_clean()

    def load_poetry_from_clean(self) -> list[str]:
        """Read the already-cleaned data directly."""
        with open(self.get_clean_dataset_path(), 'rb') as f:
            return pickle.load(f)

    def load_poetry_from_dirty(self) -> list[str]:
        """Load the raw data, clean it, write it to the cache file, and return the cleaned data."""
        with open(self.get_dirty_dataset_path(), 'r', encoding='utf-8') as f:
            lines = f.readlines()
        # normalise full-width colons to the same format as the ASCII ones
        lines = [line.replace('：', ':') for line in lines]

        # the cleaned dataset
        poetry: list[str] = []
        # process the lines one by one
        for line in lines:
            # strip surrounding whitespace
            line = line.strip()
            # there must be exactly one colon, separating the title from the body
            if line.count(':') != 1:
                continue
            # keep only the part after the title
            _, last_part = line.split(':')
            # the length must not exceed the maximum (leaving room for [CLS]/[SEP])
            if len(last_part) > PoetryDataset.MAX_SEG_LEN - 2:
                continue
            # skip lines containing forbidden tokens
            for bad_word in PoetryDataset.BAD_WORDS:
                if bad_word in last_part:
                    break
            else:
                # the loop finished without a break, so there are no bad words: keep the line
                poetry.append(last_part)

        # cleaning finished: write the clean data to the cache file
        with open(self.get_clean_dataset_path(), 'wb') as f:
            pickle.dump(poetry, f)
        # return the result
        return poetry

    def get_clean_dataset_path(self) -> Path:
        return Path(__file__).resolve().parent.parent / 'datasets' / 'poetry.pickle'

    def get_dirty_dataset_path(self) -> Path:
        return Path(__file__).resolve().parent.parent / 'datasets' / 'poetry.txt'

    def build_tokenizer(self, poetry: list[str]) -> Tokenizer:
        """Count token frequencies and build the tokenizer."""
        # count token frequencies
        counter: Counter[str] = Counter()
        for line in poetry:
            counter.update(line)
        # drop low-frequency tokens
        tokens = ((token, count) for token, count in counter.items() if count >= PoetryDataset.MIN_WORD_FREQ)
        # sort by frequency, descending
        tokens = sorted(tokens, key=lambda x: -x[1])
        # drop the counts, keeping only the tokens
        tokens = list(token for token, _ in tokens)

        # put the special tokens in front of the tokens from the dataset
        tokens = ['[PAD]', '[UNK]', '[CLS]', '[SEP]'] + tokens
        # build the token -> id mapping
        token_id_dict = dict(zip(tokens, range(len(tokens))))
        # build a tokenizer over the new vocabulary
        tokenizer = Tokenizer(token_id_dict)
        # return it directly; no shuffling is needed here
        return tokenizer
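
A minimal sketch of driving the new loader (assuming it is run from exp3/modified so dataset.py is importable directly, with the raw data at exp3/datasets/poetry.txt as get_dirty_dataset_path() expects):

    from dataset import PoetryDataset

    ds = PoetryDataset()                        # cleans poetry.txt on first run, then reuses poetry.pickle
    print(len(ds.poetry), ds.tokenizer.vocab_size)
    ids = ds.tokenizer.encode(ds.poetry[0])     # [CLS] ... [SEP] id sequence
    print(ds.tokenizer.decode(ids))             # round-trips, except rare characters come back as '[UNK]'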