finish exp3 predict code
This commit is contained in:
@@ -6,6 +6,7 @@ import torch.nn.functional as F
|
|||||||
from PIL import Image, ImageFile
|
from PIL import Image, ImageFile
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from model import Cnn
|
from model import Cnn
|
||||||
|
import settings
|
||||||
|
|
||||||
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
|
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
|
||||||
import gpu_utils
|
import gpu_utils
|
||||||
@@ -53,8 +54,7 @@ class Predictor:
|
|||||||
self.model = Cnn().to(self.device)
|
self.model = Cnn().to(self.device)
|
||||||
|
|
||||||
# 加载保存好的模型参数
|
# 加载保存好的模型参数
|
||||||
file_path = Path(__file__).resolve().parent.parent / 'models' / 'cnn.pth'
|
self.model.load_state_dict(torch.load(settings.SAVED_MODEL_PATH))
|
||||||
self.model.load_state_dict(torch.load(file_path))
|
|
||||||
|
|
||||||
def __predict_tensor(self, in_data: torch.Tensor) -> PredictResult:
|
def __predict_tensor(self, in_data: torch.Tensor) -> PredictResult:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -233,6 +233,10 @@ class PoetryDataLoader:
|
|||||||
def get_vocab_size(self) -> int:
|
def get_vocab_size(self) -> int:
|
||||||
"""一个便捷的获取vocab_size的函数,避免层层调用"""
|
"""一个便捷的获取vocab_size的函数,避免层层调用"""
|
||||||
return self.preprocessor.tokenizer.vocab_size
|
return self.preprocessor.tokenizer.vocab_size
|
||||||
|
|
||||||
|
def get_tokenizer(self) -> Tokenizer:
|
||||||
|
"""一个便捷的获取Tokenizer的函数,避免层层调用"""
|
||||||
|
return self.preprocessor.tokenizer
|
||||||
|
|
||||||
def __collect_fn(self, batch: list[list[int]]) -> tuple[torch.Tensor, torch.Tensor]:
|
def __collect_fn(self, batch: list[list[int]]) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -0,0 +1,126 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
import sys
|
||||||
|
import numpy
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import settings
|
||||||
|
from dataset import Tokenizer, PoetryDataLoader
|
||||||
|
from model import Rnn
|
||||||
|
|
||||||
|
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
|
||||||
|
import gpu_utils
|
||||||
|
|
||||||
|
|
||||||
|
def generate_random_poetry(tokenizer: Tokenizer, model: Rnn, device: torch.device, s: str=''):
|
||||||
|
"""
|
||||||
|
随机生成一首诗
|
||||||
|
|
||||||
|
:param tokenizer: 分词器
|
||||||
|
:param model: 用于生成古诗的模型
|
||||||
|
:param s: 用于生成古诗的起始字符串,默认为空串
|
||||||
|
:return: 一个字符串,表示一首古诗
|
||||||
|
"""
|
||||||
|
# 将初始字符串转成token
|
||||||
|
token_ids = tokenizer.encode(s)
|
||||||
|
|
||||||
|
# 去掉结束标记[SEP]
|
||||||
|
token_ids = token_ids[:-1]
|
||||||
|
while len(token_ids) < settings.POETRY_MAX_LEN:
|
||||||
|
# 进行预测,其中batch_size=1
|
||||||
|
input = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)
|
||||||
|
output: torch.Tensor = model(input.to(device))
|
||||||
|
# 计算最后一个字符的概率分布。
|
||||||
|
# 由于后续预测概率时,需要批次维度,所以方括号里第一项写:保留批次维度。
|
||||||
|
# 然后因为只有最后一个字符是预测的,其他字符都是辅助推断的,所以方括号第二项-1表示取这个最后一个字符。
|
||||||
|
# 最后,它的概率分布中不包含[PAD][UNK][CLS]的概率分布,所以方括号第三项3:把这些东西删掉(这些编号是Tokenizer在编译时写死的,详细查看对应模块)。
|
||||||
|
possibilities = F.softmax(output[:, -1, 3:])
|
||||||
|
# 按照预测出的概率,随机选择一个词作为预测结果。
|
||||||
|
# 如果需要贪心,则用argmax替代。
|
||||||
|
target_index = torch.multinomial(possibilities, num_samples=1)
|
||||||
|
# 记得把之前删除的维度加回来才是token id
|
||||||
|
target_id = target_index + 3
|
||||||
|
|
||||||
|
# 把target_id加入序列
|
||||||
|
token_ids.append(target_id)
|
||||||
|
# 如果target_id是[SEP],表示输出结束,需要退出
|
||||||
|
if target_id == 3: break
|
||||||
|
|
||||||
|
# 解码并返回结果
|
||||||
|
return tokenizer.decode(token_ids)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_acrostic(tokenizer: Tokenizer, model: Rnn, device: torch.device, head: str):
|
||||||
|
"""
|
||||||
|
随机生成一首藏头诗
|
||||||
|
|
||||||
|
:param tokenizer: 分词器
|
||||||
|
:param model: 用于生成古诗的模型
|
||||||
|
:param head: 藏头诗的头
|
||||||
|
:return: 一个字符串,表示一首古诗
|
||||||
|
"""
|
||||||
|
# 使用空串初始化token_ids
|
||||||
|
token_ids = tokenizer.encode('')
|
||||||
|
# 去掉结束标记[SEP],只保留[CLS]
|
||||||
|
token_ids = token_ids[:-1]
|
||||||
|
|
||||||
|
# 标点符号,这里简单的只把逗号和句号作为标点
|
||||||
|
punctuations = [',', '。']
|
||||||
|
punctuation_ids = {tokenizer.token_to_id(token) for token in punctuations}
|
||||||
|
|
||||||
|
# 缓存生成的诗的list
|
||||||
|
poetry: list[str] = []
|
||||||
|
# 对于藏头诗中的每一个字,都生成一个短句
|
||||||
|
for ch in head:
|
||||||
|
# 先记录下这个字
|
||||||
|
poetry.append(ch)
|
||||||
|
# 将藏头诗的字符转成token id
|
||||||
|
token_id = tokenizer.token_to_id(ch)
|
||||||
|
# 加入到列表中去
|
||||||
|
token_ids.append(token_id)
|
||||||
|
|
||||||
|
# 开始生成一个短句
|
||||||
|
while True:
|
||||||
|
# 与generate_random_poetry函数相同的方式,不断地生成诗句的下一个字。
|
||||||
|
input = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)
|
||||||
|
output: torch.Tensor = model(input.to(device))
|
||||||
|
possibilities = F.softmax(output[:, -1, 3:])
|
||||||
|
target_index = torch.multinomial(possibilities, num_samples=1)
|
||||||
|
target_id = target_index + 3
|
||||||
|
|
||||||
|
# 把target_id加入序列
|
||||||
|
token_ids.append(target_id)
|
||||||
|
# 只有对应ID不是特殊符号的ID,我们才把这个字符推入诗句中
|
||||||
|
if target_id > 3: poetry.append(tokenizer.id_to_token(target_id))
|
||||||
|
# 此外,与上面不同的是,当输出为标点符号时,我们退出当前循环,进而生成藏头诗的下一句。
|
||||||
|
if target_id in punctuation_ids: break
|
||||||
|
|
||||||
|
# 解码并返回结果
|
||||||
|
return ''.join(poetry)
|
||||||
|
|
||||||
|
|
||||||
|
class Predictor:
|
||||||
|
device: torch.device
|
||||||
|
data_loader: PoetryDataLoader
|
||||||
|
model: Rnn
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.device = gpu_utils.get_gpu_device()
|
||||||
|
self.data_loader = PoetryDataLoader(batch_size=settings.N_BATCH_SIZE)
|
||||||
|
self.model = Rnn(self.data_loader.get_vocab_size()).to(self.device)
|
||||||
|
|
||||||
|
# 加载保存好的模型参数
|
||||||
|
self.model.load_state_dict(torch.load(settings.SAVED_MODEL_PATH))
|
||||||
|
|
||||||
|
def generate_random_poetry(self):
|
||||||
|
"""随机生成一首诗"""
|
||||||
|
with torch.no_grad():
|
||||||
|
generate_random_poetry(self.data_loader.get_tokenizer(),
|
||||||
|
self.model,
|
||||||
|
self.device)
|
||||||
|
|
||||||
|
def generate_acrostic(self):
|
||||||
|
"""随机生成一首藏头诗"""
|
||||||
|
with torch.no_grad():
|
||||||
|
generate_acrostic(self.data_loader.get_tokenizer(),
|
||||||
|
self.model,
|
||||||
|
self.device)
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from ignite.engine import Engine, Events
|
|||||||
from ignite.handlers.tqdm_logger import ProgressBar
|
from ignite.handlers.tqdm_logger import ProgressBar
|
||||||
from dataset import PoetryDataLoader
|
from dataset import PoetryDataLoader
|
||||||
from model import Rnn
|
from model import Rnn
|
||||||
|
from predict import generate_random_poetry
|
||||||
import settings
|
import settings
|
||||||
|
|
||||||
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
|
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
|
||||||
@@ -44,6 +45,11 @@ class Trainer:
|
|||||||
# 将训练器关联到进度条
|
# 将训练器关联到进度条
|
||||||
self.pbar = ProgressBar(persist=True)
|
self.pbar = ProgressBar(persist=True)
|
||||||
self.pbar.attach(self.trainer, output_transform=lambda loss: {"loss": loss})
|
self.pbar.attach(self.trainer, output_transform=lambda loss: {"loss": loss})
|
||||||
|
# 每次epoch后,作诗一首看看结果
|
||||||
|
self.trainer.add_event_handler(
|
||||||
|
Events.EPOCH_COMPLETED,
|
||||||
|
lambda: generate_random_poetry(self.data_loader.get_tokenizer(), self.model, )
|
||||||
|
)
|
||||||
|
|
||||||
def train_model(self):
|
def train_model(self):
|
||||||
# 训练模型
|
# 训练模型
|
||||||
|
|||||||
Reference in New Issue
Block a user