Make exp3 work, but the result is not yet verified
This commit is contained in:
@@ -66,8 +66,8 @@ def main():
|
||||
|
||||
for t in range(2000):
|
||||
optimizer.zero_grad() #清空上一步的残余更新参数值
|
||||
prediction: torch.Tensor = net(test_data.x) #喂给net训练数据x,输出预测值
|
||||
loss: torch.Tensor = loss_func(prediction, test_data.y) #计算两者的误差
|
||||
prediction: torch.tensor = net(test_data.x) #喂给net训练数据x,输出预测值
|
||||
loss: torch.tensor = loss_func(prediction, test_data.y) #计算两者的误差
|
||||
loss.backward() #误差反向传播,计算参数更新值
|
||||
optimizer.step() #将参数更新值施加到net的parameters上
|
||||
|
||||
|
||||
@@ -20,19 +20,6 @@ class Cnn(torch.nn.Module):
|
||||
self.fc1 = torch.nn.Linear(64 * 3 * 3, 64)
|
||||
self.fc2 = torch.nn.Linear(64, 10)
|
||||
|
||||
# 初始化模型参数
|
||||
self.__initialize_weights()
|
||||
|
||||
def __initialize_weights(self):
|
||||
# YYC MARK:
|
||||
# 把两个全连接线性层按tensorflow默认设置初始化,即:
|
||||
# - kernel_initializer='glorot_uniform'
|
||||
# - bias_initializer='zeros'
|
||||
torch.nn.init.xavier_normal_(self.fc1.weight)
|
||||
torch.nn.init.zeros_(self.fc1.bias)
|
||||
torch.nn.init.xavier_normal_(self.fc2.weight)
|
||||
torch.nn.init.zeros_(self.fc2.bias)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.relu(self.conv1(x))
|
||||
x = self.pool1(x)
|
||||
|
||||
@@ -83,7 +83,7 @@ class Predictor:
|
||||
:param image: 该列表的shape必须为28x28。
|
||||
:return: 预测结果。
|
||||
"""
|
||||
input = torch.Tensor(image).float()
|
||||
input = torch.tensor(image, dtype=torch.float32)
|
||||
assert(input.dim() == 2)
|
||||
assert(input.size(0) == 28)
|
||||
assert(input.size(1) == 28)
|
||||
|
||||
1
exp3/datasets/.gitignore
vendored
1
exp3/datasets/.gitignore
vendored
@@ -1,2 +1,3 @@
|
||||
# Ignore datasets and processed datasets
|
||||
*.txt
|
||||
*.pickle
|
||||
@@ -3,6 +3,19 @@ import typing
|
||||
import pickle
|
||||
from collections import Counter
|
||||
import numpy
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
import torch.nn.functional as F
|
||||
import settings
|
||||
|
||||
TOKEN_PAD: str = '[PAD]'
|
||||
"""使用古诗词数据时的特殊字符,RNN填充时使用的填充字符。"""
|
||||
TOKEN_UNK: str = '[UNK]'
|
||||
"""使用古诗词数据时的特殊字符,词频不足的生僻字。"""
|
||||
TOKEN_CLS: str = '[CLS]'
|
||||
"""使用古诗词数据时的特殊字符,标记古诗词开始。"""
|
||||
TOKEN_SEP: str = '[SEP]'
|
||||
"""使用古诗词数据时的特殊字符,标记古诗词结束。"""
|
||||
|
||||
class Tokenizer:
|
||||
"""分词器"""
|
||||
@@ -46,12 +59,12 @@ class Tokenizer:
|
||||
:return: 编号序列
|
||||
"""
|
||||
# 加上开始标记
|
||||
token_ids: list[int] = [self.token_to_id('[CLS]'), ]
|
||||
token_ids: list[int] = [self.token_to_id(TOKEN_CLS), ]
|
||||
# 加入字符串编号序列
|
||||
for token in tokens:
|
||||
token_ids.append(self.token_to_id(token))
|
||||
# 加上结束标记
|
||||
token_ids.append(self.token_to_id('[SEP]'))
|
||||
token_ids.append(self.token_to_id(TOKEN_SEP))
|
||||
return token_ids
|
||||
|
||||
def decode(self, token_ids: typing.Iterable[int]) -> str:
|
||||
@@ -62,7 +75,7 @@ class Tokenizer:
|
||||
:return: 解码出的字符串
|
||||
"""
|
||||
# 起止标记字符特殊处理
|
||||
spec_tokens = {'[CLS]', '[SEP]'}
|
||||
spec_tokens = {TOKEN_CLS, TOKEN_SEP}
|
||||
# 保存解码出的字符的list
|
||||
tokens: list[str] = []
|
||||
for token_id in token_ids:
|
||||
@@ -74,87 +87,99 @@ class Tokenizer:
|
||||
return ''.join(tokens)
|
||||
|
||||
|
||||
class PoetryDataset:
|
||||
"""古诗词数据集加载器"""
|
||||
class PoetryPreprocessor:
|
||||
"""
|
||||
古诗词数据集的预处理器。
|
||||
|
||||
BAD_WORDS: typing.ClassVar[list[str]] = ['(', ')', '(', ')', '__', '《', '》', '【', '】', '[', ']']
|
||||
"""禁用词,包含如下字符的唐诗将被忽略"""
|
||||
MAX_SEG_LEN: typing.ClassVar[int] = 64
|
||||
"""句子最大长度"""
|
||||
MIN_WORD_FREQ: typing.ClassVar[int] = 8
|
||||
"""最小词频"""
|
||||
该类负责古诗词数据的读取,清洗和数据持久化。
|
||||
"""
|
||||
|
||||
tokenizer: Tokenizer
|
||||
"""分词器"""
|
||||
poetry: list[str]
|
||||
"""古诗词数据集,每一项是一首诗"""
|
||||
|
||||
def __init__(self, force_reclean: bool = False):
|
||||
# 加载古诗,然后统计词频构建分词器
|
||||
self.poetry = self.load_poetry(force_reclean)
|
||||
self.tokenizer = self.build_tokenizer(self.poetry)
|
||||
|
||||
def load_poetry(self, force_reclean: bool = False) -> list[str]:
|
||||
"""加载古诗词数据集"""
|
||||
if force_reclean or (not self.get_clean_dataset_path().is_file()):
|
||||
return self.load_poetry_from_dirty()
|
||||
def __init__(self, force_reclean: bool=False):
|
||||
# 加载古诗词数据集
|
||||
if force_reclean or (not settings.CLEAN_DATASET_PATH.is_file()):
|
||||
(self.poetry, self.tokenizer) = self.__load_from_dirty()
|
||||
else:
|
||||
return self.load_poetry_from_clean()
|
||||
(self.poetry, self.tokenizer) = self.__load_from_clean()
|
||||
|
||||
def load_poetry_from_clean(self) -> list[str]:
|
||||
def __load_from_clean(self) -> tuple[list[str], Tokenizer]:
|
||||
"""直接读取清洗后的数据"""
|
||||
with open(self.get_clean_dataset_path(), 'rb') as f:
|
||||
with open(settings.CLEAN_DATASET_PATH, 'rb') as f:
|
||||
return pickle.load(f)
|
||||
|
||||
def load_poetry_from_dirty(self) -> list[str]:
|
||||
def __load_from_dirty(self) -> tuple[list[str], Tokenizer]:
|
||||
"""从原始数据加载,清洗数据后,写入缓存文件,并返回清洗后的数据"""
|
||||
with open(self.get_dirty_dataset_path(), 'r', encoding='utf-8') as f:
|
||||
# 加载脏的古诗数据
|
||||
with open(settings.DIRTY_DATASET_PATH, 'r', encoding='utf-8') as f:
|
||||
lines = f.readlines()
|
||||
# 将冒号统一成相同格式
|
||||
lines = [line.replace(':', ':') for line in lines]
|
||||
|
||||
# 清洗古诗数据
|
||||
poetry = self.__wash_dirty_poetry(lines)
|
||||
# 构建分词器
|
||||
tokenizer = self.__build_tokenizer(poetry)
|
||||
|
||||
# 数据清理完毕
|
||||
# 写入干净数据
|
||||
with open(settings.CLEAN_DATASET_PATH, 'wb') as f:
|
||||
pickle.dump((poetry, tokenizer), f)
|
||||
|
||||
# 返回结果
|
||||
return poetry, tokenizer
|
||||
|
||||
def __wash_dirty_poetry(self, poetry: list[str]) -> list[str]:
|
||||
"""
|
||||
清洗给定的古诗数据。
|
||||
|
||||
:param poetry: 要清洗的古诗数据,每一行是一首古诗。
|
||||
古诗开头是标题,然后是一个冒号(全角或半角),然后是古诗主体。
|
||||
:return: 清洗完毕的古诗。
|
||||
"""
|
||||
# 禁用词列表,包含如下字符的诗歌将被忽略
|
||||
BAD_WORDS = ['(', ')', '(', ')', '__', '《', '》', '【', '】', '[', ']']
|
||||
# 数据集列表
|
||||
poetry: list[str] = []
|
||||
clean_poetry: list[str] = []
|
||||
|
||||
# 逐行处理读取到的数据
|
||||
for line in lines:
|
||||
for line in poetry:
|
||||
# 删除空白字符
|
||||
line = line.strip()
|
||||
# 将全角冒号替换为半角的
|
||||
line = line.replace(':', ':')
|
||||
# 有且只能有一个冒号用来分割标题
|
||||
if line.count(':') != 1: continue
|
||||
# 获取后半部分(删除标题)
|
||||
_, last_part = line.split(':')
|
||||
# 长度不能超过最大长度
|
||||
if len(last_part) > PoetryDataset.MAX_SEG_LEN - 2:
|
||||
# 长度不能超过最大长度(减去2是因为古诗首尾要加特殊符号)
|
||||
if len(last_part) > settings.POETRY_MAX_LEN - 2:
|
||||
continue
|
||||
# 不能包含禁止词
|
||||
for bad_word in PoetryDataset.BAD_WORDS:
|
||||
for bad_word in BAD_WORDS:
|
||||
if bad_word in last_part:
|
||||
break
|
||||
else:
|
||||
# 如果循环正常结束,就表明没有bad words,推入数据列表
|
||||
poetry.append(last_part)
|
||||
clean_poetry.append(last_part)
|
||||
|
||||
# 数据清理完毕
|
||||
# 写入干净数据
|
||||
with open(self.get_clean_dataset_path(), 'wb') as f:
|
||||
pickle.dump(poetry, f)
|
||||
# 返回结果
|
||||
return poetry
|
||||
# 返回清洗完毕的结果
|
||||
return clean_poetry
|
||||
|
||||
def get_clean_dataset_path(self) -> Path:
|
||||
return Path(__file__).resolve().parent.parent / 'datasets' / 'poetry.pickle'
|
||||
def __build_tokenizer(self, poetry: list[str]) -> Tokenizer:
|
||||
"""
|
||||
根据给定古诗统计词频,并构建分词器。
|
||||
|
||||
def get_dirty_dataset_path(self) -> Path:
|
||||
return Path(__file__).resolve().parent.parent / 'datasets' / 'poetry.txt'
|
||||
|
||||
def build_tokenizer(self, poetry: list[str]) -> Tokenizer:
|
||||
"""统计词频,并构建分词器"""
|
||||
:param poetry: 清洗完毕后的古诗,每一行是一句诗。
|
||||
:return: 构建完毕的分词器。
|
||||
"""
|
||||
# 统计词频
|
||||
counter: Counter[str] = Counter()
|
||||
for line in poetry:
|
||||
counter.update(line)
|
||||
# 过滤掉低频词
|
||||
tokens = ((token, count) for token, count in counter.items() if count >= PoetryDataset.MIN_WORD_FREQ)
|
||||
tokens = ((token, count) for token, count in counter.items() if count >= settings.POETRY_MIN_WORD_FREQ)
|
||||
# 按词频排序
|
||||
tokens = sorted(tokens, key=lambda x: -x[1])
|
||||
# 去掉词频,只保留词列表
|
||||
@@ -169,3 +194,75 @@ class PoetryDataset:
|
||||
# 直接返回,此处无需混洗数据
|
||||
return tokenizer
|
||||
|
||||
class PoetryDataset(Dataset):
    """PyTorch ``Dataset`` exposing the preprocessed poems as token-id lists."""

    # Supplies both the cleaned poems (``poetry``) and the ``tokenizer``
    # used to encode them.
    preprocessor: PoetryPreprocessor

    def __init__(self, poetry: PoetryPreprocessor):
        self.preprocessor = poetry

    def __getitem__(self, index):
        """Return the poem at ``index`` encoded by the preprocessor's tokenizer.

        The encoded poem is returned as-is: padding and the construction of
        the input/target pair are left to the DataLoader.
        """
        source = self.preprocessor
        raw_poem = source.poetry[index]
        return source.tokenizer.encode(raw_poem)

    def __len__(self):
        """Return the number of poems held by the preprocessor."""
        return len(self.preprocessor.poetry)
|
||||
|
||||
|
||||
class PoetryDataLoader:
    """Bundle the preprocessor, dataset and PyTorch ``DataLoader`` for poems."""

    # Loads the poems and owns the tokenizer.
    preprocessor: PoetryPreprocessor
    # Dataset yielding one encoded poem per index.
    dataset: PoetryDataset
    # Shuffling loader that pads every batch through the collate function.
    loader: DataLoader

    def __init__(self, batch_size: int, force_reclean: bool=False):
        self.preprocessor = PoetryPreprocessor(force_reclean)
        self.dataset = PoetryDataset(self.preprocessor)
        self.loader = DataLoader(
            dataset=self.dataset,
            batch_size=batch_size,
            # Pads the poems of a batch and builds the input/target pair.
            collate_fn=self.__collect_fn,
            # Shuffle the samples to reduce over-fitting.
            shuffle=True,
        )

    def get_vocab_size(self) -> int:
        """Return the tokenizer's vocabulary size without chained lookups."""
        return self.preprocessor.tokenizer.vocab_size

    def __collect_fn(self, batch: list[list[int]]) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Collate function for the DataLoader.

        Pads every encoded poem of ``batch`` to the length of the longest one
        and returns ``(input, target)``: the input is each sequence without
        its last token, the target is each sequence without its first token,
        one-hot encoded as floats, so the network learns to predict the next
        token.
        """
        tokenizer = self.preprocessor.tokenizer
        # Every sample is padded up to the longest sample of this batch, so
        # no sample can ever be longer than the target width.
        width = max(len(sample) for sample in batch)
        pad_id = tokenizer.token_to_id(TOKEN_PAD)
        rows = [list(sample) + [pad_id] * (width - len(sample)) for sample in batch]
        token_matrix = numpy.array(rows)

        # Shift by one position: input drops the last token, target drops
        # the first one.
        input_ids = torch.tensor(token_matrix[:, :-1], dtype=torch.long)
        target_ids = torch.tensor(token_matrix[:, 1:], dtype=torch.long)
        target = F.one_hot(target_ids, num_classes=tokenizer.vocab_size).float()

        return input_ids, target
|
||||
|
||||
|
||||
@@ -1,17 +1,41 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
class TimeDistributed(torch.nn.Module):
    """Apply a wrapped layer to every time step of a 3-D input.

    Stand-in for TensorFlow's ``TimeDistributed`` wrapper: the time axis is
    folded into the batch axis before the wrapped layer runs, and unfolded
    again afterwards.
    """

    # The wrapped module that is applied to each time step.
    layer: torch.nn.Module

    def __init__(self, layer: torch.nn.Module):
        super().__init__()
        self.layer = layer

    def forward(self, x: torch.Tensor):
        """Run the wrapped layer on ``x`` of shape (batch, time, features).

        Returns a tensor of shape (batch, time, -1), where the last axis is
        whatever the wrapped layer produces.
        """
        n_batch, n_steps, n_features = x.shape
        # Merge time into batch: to the wrapped layer every time step is
        # just one more sample.
        flat_in = x.reshape(-1, n_features)
        flat_out: torch.Tensor = self.layer(flat_in)
        # Restore the separate time axis.
        return flat_out.reshape(n_batch, n_steps, -1)
|
||||
|
||||
|
||||
class Rnn(torch.nn.Module):
    """Recurrent network: embedding, two LSTM layers, per-time-step linear.

    ``forward`` maps a (batch, time) tensor of token ids to unnormalised
    scores of shape (batch, time, vocab_size).

    Dropout is only active in training mode; call ``eval()`` on the model
    before using it for inference.
    """

    def __init__(self, vocab_size: int):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, 128)
        # The ``dropout`` argument of torch.nn.LSTM only acts between stacked
        # layers; on a single-layer LSTM it does nothing and merely emits a
        # warning. Explicit Dropout modules provide the intended 0.5 dropout.
        # They hold no parameters, so state_dict keys are unaffected.
        self.lstm1 = torch.nn.LSTM(128, 128, batch_first=True)
        self.dropout1 = torch.nn.Dropout(0.5)
        self.lstm2 = torch.nn.LSTM(128, 128, batch_first=True)
        self.dropout2 = torch.nn.Dropout(0.5)
        self.timedfc = TimeDistributed(torch.nn.Linear(128, vocab_size))

    def forward(self, x):
        x = self.embedding(x)
        x, _ = self.lstm1(x)
        x = self.dropout1(x)
        x, _ = self.lstm2(x)
        x = self.dropout2(x)
        x = self.timedfc(x)
        return x
|
||||
|
||||
|
||||
0
exp3/modified/predict.py
Normal file
0
exp3/modified/predict.py
Normal file
@@ -1,14 +1,19 @@
|
||||
from pathlib import Path
|
||||
|
||||
POETRY_MAX_LEN: int = 64
|
||||
"""古诗词句子最大允许长度(该长度包含首尾填充的特殊字符),超过该长度的诗句将被删除。"""
|
||||
POETRY_MIN_WORD_FREQ: int = 8
|
||||
"""古诗词最小允许词频,小于该词频的词将在编解码时被视为[UNK]生僻字。"""
|
||||
|
||||
DIRTY_DATASET_PATH: Path = Path(__file__).resolve().parent.parent / 'datasets' / 'poetry.txt'
|
||||
"""脏的(未清洗的)古诗数据的路径"""
|
||||
CLEAN_DATASET_PATH: Path = Path(__file__).resolve().parent.parent / 'datasets' / 'poetry.pickle'
|
||||
"""干净的(已经清洗过的)古诗数据的路径"""
|
||||
|
||||
SAVED_MODULE_PATH: Path = Path(__file__).resolve().parent.parent / 'models' / 'rnn.pth'
|
||||
SAVED_MODEL_PATH: Path = Path(__file__).resolve().parent.parent / 'models' / 'rnn.pth'
|
||||
"""训练完毕的模型进行保存的路径"""
|
||||
|
||||
N_EPOCH: int = 10
|
||||
"""训练时的epoch"""
|
||||
N_BATCH_SIZE: int = 16
|
||||
N_BATCH_SIZE: int = 50
|
||||
"""训练时的batch size"""
|
||||
|
||||
68
exp3/modified/train.py
Normal file
68
exp3/modified/train.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import typing
|
||||
import torch
|
||||
import torchinfo
|
||||
import ignite.engine
|
||||
import ignite.metrics
|
||||
from ignite.engine import Engine, Events
|
||||
from ignite.handlers.tqdm_logger import ProgressBar
|
||||
from dataset import PoetryDataLoader
|
||||
from model import Rnn
|
||||
import settings
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
|
||||
import gpu_utils
|
||||
|
||||
|
||||
class Trainer:
    """Core trainer: builds the data pipeline and model, trains and saves it."""

    # Device the model is placed on.
    device: torch.device
    # Loader wrapper providing the batches and the vocabulary size.
    data_loader: PoetryDataLoader
    # The network being trained.
    model: Rnn

    # Ignite engine that runs the supervised training loop.
    trainer: Engine
    # Progress bar attached to the engine.
    pbar: ProgressBar

    def __init__(self):
        # Create the training device, the data loader and the model.
        self.device = gpu_utils.get_gpu_device()
        self.data_loader = PoetryDataLoader(batch_size=settings.N_BATCH_SIZE)
        self.model = Rnn(self.data_loader.get_vocab_size()).to(self.device)
        # Show the model structure for one batch of integer token-id
        # sequences of the maximum poem length.
        torchinfo.summary(self.model,
                          (settings.N_BATCH_SIZE, settings.POETRY_MAX_LEN),
                          dtypes=[torch.int32,])

        # Optimizer and loss function.
        optimizer = torch.optim.Adam(self.model.parameters(), eps=1e-7)
        cross_entropy = torch.nn.CrossEntropyLoss()

        def criterion(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
            """Cross entropy over the vocabulary axis of sequence outputs.

            Both tensors arrive as (batch, time, vocab). CrossEntropyLoss
            treats dim 1 as the class axis, so passing them unchanged would
            normalise over time steps instead of over the vocabulary.
            Folding batch and time together makes the vocabulary dim 1.
            """
            vocab_size = prediction.size(-1)
            return cross_entropy(prediction.reshape(-1, vocab_size),
                                 target.reshape(-1, vocab_size))

        # Create the training engine.
        self.trainer = ignite.engine.create_supervised_trainer(
            self.model, optimizer, criterion, self.device)
        # Attach the progress bar to the engine and display the loss.
        self.pbar = ProgressBar(persist=True)
        self.pbar.attach(self.trainer, output_transform=lambda loss: {"loss": loss})

    def train_model(self):
        """Run training for ``settings.N_EPOCH`` epochs."""
        self.trainer.run(self.data_loader.loader, max_epochs=settings.N_EPOCH)

    def save_model(self):
        """Save the model parameters to ``settings.SAVED_MODEL_PATH``."""
        # Make sure the target directory exists.
        settings.SAVED_MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
        # Only the parameters (state dict) are saved, not the module itself.
        torch.save(self.model.state_dict(), settings.SAVED_MODEL_PATH)
        print(f'Model was saved into: {settings.SAVED_MODEL_PATH}')
|
||||
|
||||
|
||||
def main():
    """Build the trainer, train the model, then save its parameters."""
    runner = Trainer()
    runner.train_model()
    runner.save_model()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
gpu_utils.print_gpu_availability()
|
||||
main()
|
||||
Reference in New Issue
Block a user