1
0
Files
ai-school/dl-exp/exp3/modified/model.py

42 lines
1.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
import torch.nn.functional as F
class TimeDistributed(torch.nn.Module):
    """Apply a wrapped layer independently to every time step of a sequence.

    Emulates the ``TimeDistributed`` wrapper layer from TensorFlow/Keras.
    The time-step dimension is folded into the batch dimension, the wrapped
    layer is applied once, and the time-step dimension is restored afterwards.
    """

    # The wrapped module that is applied to each time step.
    layer: torch.nn.Module

    def __init__(self, layer: torch.nn.Module):
        """Store ``layer`` as a sub-module so its parameters are registered."""
        super().__init__()
        self.layer = layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped layer on each time step.

        Args:
            x: Tensor of shape ``(batch_size, time_steps, features)``.

        Returns:
            Tensor of shape ``(batch_size, time_steps, out)``, where ``out`` is
            the flattened per-row output size of the wrapped layer.

        Raises:
            ValueError: If ``x`` does not have exactly three dimensions
                (raised by the tuple unpacking of ``x.size()``).
        """
        batch_size, time_steps, features = x.size()
        # Merge the time-step dimension into the batch dimension; to the
        # wrapped layer every time step then looks like just another sample.
        # The row count is spelled out (instead of -1) so that the reshape is
        # unambiguous even when ``features`` is 0.
        flat = x.reshape(batch_size * time_steps, features)
        outputs: torch.Tensor = self.layer(flat)
        # Restore the time-step dimension.
        if outputs.numel() == 0:
            # With zero elements a -1 placeholder cannot be inferred and
            # reshape raises, so compute the per-row output size explicitly.
            return outputs.reshape(
                batch_size, time_steps, outputs.shape[1:].numel()
            )
        return outputs.reshape(batch_size, time_steps, -1)
class Rnn(torch.nn.Module):
    """Recurrent network: embedding -> LSTM -> LSTM -> per-time-step linear.

    Each LSTM output passes through dropout before the next stage. The final
    linear layer maps every time step to ``vocab_size`` scores.
    """

    def __init__(
        self,
        vocab_size: int,
        *,
        embedding_dim: int = 128,
        hidden_size: int = 128,
        dropout: float = 0.5,
    ):
        """Build the network.

        Args:
            vocab_size: Number of embedding rows and size of the output layer.
            embedding_dim: Size of each embedding vector.
            hidden_size: Hidden size of both LSTM layers.
            dropout: Dropout probability applied to the output of each LSTM
                (active in training mode only).
        """
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)
        # NOTE: torch.nn.LSTM's own ``dropout`` argument only acts between
        # stacked layers (num_layers > 1). On a single-layer LSTM it does
        # nothing except emit a UserWarning, so dropout is applied through an
        # explicit Dropout module instead.
        self.lstm1 = torch.nn.LSTM(embedding_dim, hidden_size, batch_first=True)
        self.lstm2 = torch.nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.dropout = torch.nn.Dropout(dropout)
        self.timedfc = TimeDistributed(torch.nn.Linear(hidden_size, vocab_size))

    def forward(self, x):
        """Compute per-time-step scores over the vocabulary.

        Args:
            x: Integer index tensor of shape ``(batch_size, time_steps)``.

        Returns:
            Tensor of shape ``(batch_size, time_steps, vocab_size)``.
        """
        x = self.embedding(x)
        # The LSTMs also return their final (hidden, cell) state; it is unused.
        x, _ = self.lstm1(x)
        x = self.dropout(x)
        x, _ = self.lstm2(x)
        x = self.dropout(x)
        x = self.timedfc(x)
        return x