use ignite for exp2
This commit is contained in:
@@ -87,7 +87,7 @@ class PoetryDataset:
|
||||
tokenizer: Tokenizer
|
||||
"""分词器"""
|
||||
poetry: list[str]
|
||||
"""古诗词数据集"""
|
||||
"""古诗词数据集,每一项是一首诗"""
|
||||
|
||||
def __init__(self, force_reclean: bool = False):
|
||||
# 加载古诗,然后统计词频构建分词器
|
||||
@@ -121,7 +121,7 @@ class PoetryDataset:
|
||||
line = line.strip()
|
||||
# 有且只能有一个冒号用来分割标题
|
||||
if line.count(':') != 1: continue
|
||||
# 获取后半部分
|
||||
# 获取后半部分(删除标题)
|
||||
_, last_part = line.split(':')
|
||||
# 长度不能超过最大长度
|
||||
if len(last_part) > PoetryDataset.MAX_SEG_LEN - 2:
|
||||
|
||||
17
exp3/modified/model.py
Normal file
17
exp3/modified/model.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import torch
|
||||
|
||||
class Rnn(torch.nn.Module):
    """Sequence model: embedding -> LSTM -> dropout -> LSTM -> linear projection.

    The final linear layer maps every time step back to ``vocab_size`` scores,
    so the output has one score vector per input position.

    Args:
        vocab_size: Number of rows in the embedding table and size of the
            output projection.
        embedding_dim: Width of the embedding vectors (default 128).
        hidden_size: Hidden width of both LSTM layers (default 128).
        dropout: Probability used by the dropout layer placed between the two
            LSTMs (default 0.5). Only active in training mode.
    """

    def __init__(self, vocab_size, embedding_dim=128, hidden_size=128, dropout=0.5):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)
        # NOTE: ``torch.nn.LSTM(dropout=...)`` only applies dropout *between*
        # stacked layers, so with the default ``num_layers=1`` it is a no-op
        # that merely emits a UserWarning. The LSTMs are therefore built
        # without it and an explicit Dropout module is used instead.
        self.lstm1 = torch.nn.LSTM(embedding_dim, hidden_size, batch_first=True)
        # Dropout has no parameters, so state_dict keys are unaffected.
        # NOTE(review): placed between the two LSTMs, mirroring what a single
        # 2-layer LSTM with dropout would do -- confirm this is the intent.
        self.dropout = torch.nn.Dropout(dropout)
        self.lstm2 = torch.nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.fc = torch.nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        """Run the network on a tensor of token indices.

        Args:
            x: Integer index tensor fed to the embedding layer; presumably
                shaped (batch, seq_len) since the LSTMs use
                ``batch_first=True`` -- verify against the caller.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a trailing
            dimension of ``vocab_size`` (unnormalised scores).
        """
        x = self.embedding(x)
        # The LSTM hidden/cell states are discarded; only per-step outputs
        # are propagated.
        x, _ = self.lstm1(x)
        x = self.dropout(x)
        x, _ = self.lstm2(x)
        x = self.fc(x)
        return x
|
||||
|
||||
Reference in New Issue
Block a user