18 lines
542 B
Python
18 lines
542 B
Python
|
|
import torch
|
||
|
|
|
||
|
|
class Rnn(torch.nn.Module):
    """Embed token ids, run two LSTM layers, and project every step to vocab_size outputs.

    Pipeline: embedding -> lstm1 -> dropout -> lstm2 -> dropout -> fc.

    Args:
        vocab_size: Number of distinct token ids. It sets both the embedding
            table size and the width of the final linear layer's output.
        embedding_dim: Size of each token embedding (default 128, the
            previously hard-coded value).
        hidden_size: Hidden size of both LSTM layers (default 128, the
            previously hard-coded value).
        dropout: Drop probability applied to each LSTM's output while the
            module is in training mode (default 0.5).
    """

    def __init__(self, vocab_size, embedding_dim=128, hidden_size=128,
                 dropout=0.5):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)
        # Each LSTM here is single-layer. nn.LSTM's own `dropout` argument
        # only acts *between* stacked layers, so passing it here would be a
        # silent no-op (and PyTorch warns about it). Dropout is applied
        # explicitly in forward() instead.
        self.lstm1 = torch.nn.LSTM(embedding_dim, hidden_size, batch_first=True)
        self.lstm2 = torch.nn.LSTM(hidden_size, hidden_size, batch_first=True)
        # Dropout has no parameters, so state_dict keys are the same as before
        # and existing checkpoints still load.
        self.dropout = torch.nn.Dropout(dropout)
        self.fc = torch.nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        """Map token ids to per-step outputs of width vocab_size.

        Args:
            x: Integer token ids. With the batch_first LSTMs, a batched input
                is (batch, seq).

        Returns:
            Tensor of shape (batch, seq, vocab_size) for batched input: the
            raw output of the final linear layer (no softmax applied).
        """
        x = self.embedding(x)
        x, _ = self.lstm1(x)  # discard the (h_n, c_n) final states
        x = self.dropout(x)
        x, _ = self.lstm2(x)
        x = self.dropout(x)
        return self.fc(x)
|
||
|
|
|