1
0
Files
ai-school/exp3/modified/dataset.py
2025-12-06 13:10:02 +08:00

269 lines
9.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from pathlib import Path
import typing
import pickle
from collections import Counter
import numpy
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import settings
TOKEN_PAD: str = '[PAD]'
"""使用古诗词数据时的特殊字符RNN填充时使用的填充字符。"""
TOKEN_UNK: str = '[UNK]'
"""使用古诗词数据时的特殊字符,词频不足的生僻字。"""
TOKEN_CLS: str = '[CLS]'
"""使用古诗词数据时的特殊字符,标记古诗词开始。"""
TOKEN_SEP: str = '[SEP]'
"""使用古诗词数据时的特殊字符,标记古诗词结束。"""
class Tokenizer:
"""分词器"""
token_dict: dict[str, int]
"""词->编号的映射"""
token_dict_rev: dict[int, str]
"""编号->词的映射"""
vocab_size: int
"""词汇表大小"""
def __init__(self, token_dict: dict[str, int]):
self.token_dict = token_dict
self.token_dict_rev = {value: key for key, value in self.token_dict.items()}
self.vocab_size = len(self.token_dict)
def id_to_token(self, token_id: int) -> str:
"""
给定一个编号,查找词汇表中对应的词。
:param token_id: 带查找词的编号
:return: 编号对应的词
"""
return self.token_dict_rev[token_id]
def token_to_id(self, token: str):
"""
给定一个词,查找它在词汇表中的编号。
未找到则返回低频词[UNK]的编号。
:param token: 带查找编号的词
:return: 词的编号
"""
return self.token_dict.get(token, self.token_dict['[UNK]'])
def encode(self, tokens: str) -> list[int]:
"""
给定一个字符串s在头尾分别加上标记开始和结束的特殊字符并将它转成对应的编号序列
:param tokens: 待编码字符串
:return: 编号序列
"""
# 加上开始标记
token_ids: list[int] = [self.token_to_id(TOKEN_CLS), ]
# 加入字符串编号序列
for token in tokens:
token_ids.append(self.token_to_id(token))
# 加上结束标记
token_ids.append(self.token_to_id(TOKEN_SEP))
return token_ids
def decode(self, token_ids: typing.Iterable[int]) -> str:
"""
给定一个编号序列,将它解码成字符串
:param token_ids: 待解码的编号序列
:return: 解码出的字符串
"""
# 起止标记字符特殊处理
spec_tokens = {TOKEN_CLS, TOKEN_SEP}
# 保存解码出的字符的list
tokens: list[str] = []
for token_id in token_ids:
token = self.id_to_token(token_id)
if token in spec_tokens:
continue
tokens.append(token)
# 拼接字符串
return ''.join(tokens)
class PoetryPreprocessor:
"""
古诗词数据集的预处理器。
该类负责古诗词数据的读取,清洗和数据持久化。
"""
tokenizer: Tokenizer
"""分词器"""
poetry: list[str]
"""古诗词数据集,每一项是一首诗"""
def __init__(self, force_reclean: bool=False):
# 加载古诗词数据集
if force_reclean or (not settings.CLEAN_DATASET_PATH.is_file()):
(self.poetry, self.tokenizer) = self.__load_from_dirty()
else:
(self.poetry, self.tokenizer) = self.__load_from_clean()
def __load_from_clean(self) -> tuple[list[str], Tokenizer]:
"""直接读取清洗后的数据"""
with open(settings.CLEAN_DATASET_PATH, 'rb') as f:
return pickle.load(f)
def __load_from_dirty(self) -> tuple[list[str], Tokenizer]:
"""从原始数据加载,清洗数据后,写入缓存文件,并返回清洗后的数据"""
# 加载脏的古诗数据
with open(settings.DIRTY_DATASET_PATH, 'r', encoding='utf-8') as f:
lines = f.readlines()
# 清洗古诗数据
poetry = self.__wash_dirty_poetry(lines)
# 构建分词器
tokenizer = self.__build_tokenizer(poetry)
# 数据清理完毕
# 写入干净数据
with open(settings.CLEAN_DATASET_PATH, 'wb') as f:
pickle.dump((poetry, tokenizer), f)
# 返回结果
return poetry, tokenizer
def __wash_dirty_poetry(self, poetry: list[str]) -> list[str]:
"""
清洗给定的古诗数据。
:param poetry: 要清洗的古诗数据,每一行是一首古诗。
古诗开头是标题,然后是一个冒号(全角或半角),然后是古诗主体。
:return: 清洗完毕的古诗。
"""
# 禁用词列表,包含如下字符的诗歌将被忽略
BAD_WORDS = ['', '', '(', ')', '__', '', '', '', '', '[', ']']
# 数据集列表
clean_poetry: list[str] = []
# 逐行处理读取到的数据
for line in poetry:
# 删除空白字符
line = line.strip()
# 将全角冒号替换为半角的
line = line.replace('', ':')
# 有且只能有一个冒号用来分割标题
if line.count(':') != 1: continue
# 获取后半部分(删除标题)
_, last_part = line.split(':')
# 长度不能超过最大长度减去2是因为古诗首尾要加特殊符号
if len(last_part) > settings.POETRY_MAX_LEN - 2:
continue
# 不能包含禁止词
for bad_word in BAD_WORDS:
if bad_word in last_part:
break
else:
# 如果循环正常结束就表明没有bad words推入数据列表
clean_poetry.append(last_part)
# 返回清洗完毕的结果
return clean_poetry
def __build_tokenizer(self, poetry: list[str]) -> Tokenizer:
"""
根据给定古诗统计词频,并构建分词器。
:param poetry: 清洗完毕后的古诗,每一行是一句诗。
:return: 构建完毕的分词器。
"""
# 统计词频
counter: Counter[str] = Counter()
for line in poetry:
counter.update(line)
# 过滤掉低频词
tokens = ((token, count) for token, count in counter.items() if count >= settings.POETRY_MIN_WORD_FREQ)
# 按词频排序
tokens = sorted(tokens, key=lambda x: -x[1])
# 去掉词频,只保留词列表
tokens = list(token for token, _ in tokens)
# 将特殊词和数据集中的词拼接起来
tokens = ['[PAD]', '[UNK]', '[CLS]', '[SEP]'] + tokens
# 创建词典 token->id映射关系
token_id_dict = dict(zip(tokens, range(len(tokens))))
# 使用新词典重新建立分词器
tokenizer = Tokenizer(token_id_dict)
# 直接返回,此处无需混洗数据
return tokenizer
class PoetryDataset(Dataset):
"""适配PyTorch的古诗词Dataset"""
preprocessor: PoetryPreprocessor
def __init__(self, poetry: PoetryPreprocessor):
self.preprocessor = poetry
def __getitem__(self, index):
# 获取古诗词并编码
poetry = self.preprocessor.poetry[index]
encoded_poetry = self.preprocessor.tokenizer.encode(poetry)
# 直接返回编码后的古诗词数据数据的padding和输入输出构成由DataLoader来做。
return encoded_poetry
def __len__(self):
return len(self.preprocessor.poetry)
class PoetryDataLoader:
"""适配PyTorch的古诗词数据Loader"""
preprocessor: PoetryPreprocessor
dataset: PoetryDataset
loader: DataLoader
def __init__(self, batch_size: int, force_reclean: bool=False):
self.preprocessor = PoetryPreprocessor(force_reclean)
self.dataset = PoetryDataset(self.preprocessor)
self.loader = DataLoader(dataset=self.dataset,
batch_size=batch_size,
# 对古诗词做padding后返回
collate_fn=lambda batch: self.__collect_fn(batch),
# 混洗数据以防止过拟合
shuffle=True)
def get_vocab_size(self) -> int:
"""一个便捷的获取vocab_size的函数避免层层调用"""
return self.preprocessor.tokenizer.vocab_size
def __collect_fn(self, batch: list[list[int]]) -> tuple[torch.Tensor, torch.Tensor]:
"""
适用于DataLoader的样本收集器。
用于将上传的古诗词样本做padding后打包返回。
"""
# 计算填充长度
length = max(map(len, batch))
# 获取填充数据
padding = self.preprocessor.tokenizer.token_to_id(TOKEN_PAD)
# 开始填充
padded_batch: list[list[int]] = []
for entry in batch:
padding_length = length - len(entry)
if padding_length > 0:
# 不足就进行填充
padded_batch.append(numpy.concatenate([entry, [padding] * padding_length]))
else:
# 超过就进行截断
padded_batch.append(entry[:length])
numpy_batch = numpy.array(padded_batch)
# 生成输入和输出。
# 输入是去除最后一个字符的部分,输出是去除第一个字符的部分。
# 这么做是为了让RNN从输入推到输出下一个字符
# 此外输出要做onehot编码
input = torch.tensor(numpy_batch[:, :-1], dtype=torch.long)
output = F.one_hot(torch.tensor(numpy_batch[:, 1:], dtype=torch.long),
num_classes=self.preprocessor.tokenizer.vocab_size).float()
# 返回结果
return input, output