import torch


class Rnn(torch.nn.Module):
    """Two-layer LSTM sequence model mapping token ids to per-step vocabulary logits.

    Pipeline: embedding -> LSTM -> dropout -> LSTM -> dropout -> linear projection.

    Dropout is applied with explicit ``torch.nn.Dropout`` modules.  Passing
    ``dropout=`` to a single-layer ``torch.nn.LSTM`` has no effect (PyTorch
    only inserts it *between* stacked layers of one LSTM module) and merely
    emits a ``UserWarning``, so that argument is intentionally not used.
    The dropout modules carry no parameters, so ``state_dict`` keys are
    ``embedding.*``, ``lstm1.*``, ``lstm2.*`` and ``fc.*`` only.

    Args:
        vocab_size: Number of distinct token ids; also the size of the
            output logit dimension.
        embedding_dim: Size of each token embedding vector.
        hidden_size: Hidden state size of both LSTM layers.
        dropout: Probability of zeroing an element of each LSTM's output
            during training.  Inactive in ``eval()`` mode.
    """

    def __init__(
        self,
        vocab_size: int,
        *,
        embedding_dim: int = 128,
        hidden_size: int = 128,
        dropout: float = 0.5,
    ) -> None:
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)
        # Single-layer LSTMs: no ``dropout=`` here, it would be a no-op.
        self.lstm1 = torch.nn.LSTM(embedding_dim, hidden_size, batch_first=True)
        self.lstm2 = torch.nn.LSTM(hidden_size, hidden_size, batch_first=True)
        # Parameter-free regularisation applied to each LSTM's output.
        self.dropout = torch.nn.Dropout(dropout)
        self.fc = torch.nn.Linear(hidden_size, vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute vocabulary logits for every position of the input.

        Args:
            x: Integer tensor of token ids, batch-first (typically shaped
                ``(batch, seq_len)``).

        Returns:
            Tensor of unnormalised logits with a trailing dimension of
            ``vocab_size`` (``(batch, seq_len, vocab_size)`` for batched
            input).
        """
        x = self.embedding(x)
        # The final (h, c) states are not needed; only per-step outputs are.
        x, _ = self.lstm1(x)
        x = self.dropout(x)
        x, _ = self.lstm2(x)
        x = self.dropout(x)
        return self.fc(x)