1
0

use ignite for exp2

This commit is contained in:
2025-12-02 23:07:27 +08:00
parent 43b807679f
commit 65c56e938c
15 changed files with 246 additions and 794 deletions

17
exp3/modified/model.py Normal file
View File

@@ -0,0 +1,17 @@
import torch
class Rnn(torch.nn.Module):
def __init__(self, vocab_size):
super(Rnn, self).__init__()
self.embedding = torch.nn.Embedding(vocab_size, 128)
self.lstm1 = torch.nn.LSTM(128, 128, batch_first=True, dropout=0.5)
self.lstm2 = torch.nn.LSTM(128, 128, batch_first=True, dropout=0.5)
self.fc = torch.nn.Linear(128, vocab_size)
def forward(self, x):
x = self.embedding(x)
x, _ = self.lstm1(x)
x, _ = self.lstm2(x)
x = self.fc(x)
return x