Normalisation in a recurrent LSTM network


I am trying to normalise between the layers of a stacked LSTM network in PyTorch. The network looks like this:

import torch
import torch.nn as nn

class LSTMClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim
        self.lstm1 = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)
        self.lstm2 = nn.LSTM(hidden_dim, hidden_dim, layer_dim, batch_first=True)
        self.fc1 = nn.Linear(hidden_dim, 32)
        self.fc2 = nn.Linear(32, 1)
        self.dropout = nn.Dropout(p=0.2)
        self.batch_normalisation1 = nn.BatchNorm1d(hidden_dim)
        self.batch_normalisation2 = nn.BatchNorm1d(hidden_dim)

    def forward(self, x):
        h0, c0 = self.init_hidden(x)
        out, (hn1, cn1) = self.lstm1(x, (h0, c0))
        out = self.dropout(out)
        out = self.batch_normalisation1(out)    # error line

        h1, c1 = self.init_hidden(out)
        out, (hn2, cn2) = self.lstm2(out, (h1, c1))
        out = self.dropout(out)
        out = self.batch_normalisation1(out)

        h2, c2 = self.init_hidden(out)
        out, (hn3, cn3) = self.lstm2(out, (h2, c2))
        out = self.dropout(out)
        out = self.batch_normalisation1(out)

        # classify using the output of the last time step
        out = self.fc1(out[:, -1, :])
        out = self.dropout(out)
        out = self.fc2(out)
        return out

    def init_hidden(self, x):
        # zero-initialised hidden and cell states: (num_layers, batch, hidden_dim)
        h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim)
        c0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim)
        return h0, c0

The comment above marks the line where the error is raised. My first thought was that this happens because BatchNorm1d requires a 2D input.
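
For reference, this is what BatchNorm1d seems to accept when I test it in isolation (plain PyTorch, independent of my model):

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(128)             # 128 channels/features

bn(torch.randn(10, 128))             # (N, C): accepted
bn(torch.randn(10, 128, 3))          # (N, C, L): also accepted
# bn(torch.randn(10, 3, 128))        # (N, L, C): raises, because dim 1 is
#                                    # treated as channels (3 != 128)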

Specifically, I am initialising the model and passing a batch of data through the network as follows:

model = LSTMClassifier(5, 128, 3, 1)
model(X)

Error: RuntimeError: running_mean should contain 3 elements not 128

The input tensor X has shape torch.Size([10, 3, 5]): a batch size of 10, with each input of size 3 × 5, i.e. 5 features over 3 time steps.
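
For completeness, any float tensor of that shape reproduces the problem, e.g.:

import torch

X = torch.randn(10, 3, 5)        # (batch, time steps, features)
model = LSTMClassifier(5, 128, 3, 1)
model(X)                         # raises the RuntimeError above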

The error arises because BatchNorm1d is trying to normalise over the wrong dimension: inside the network, the variable out has shape torch.Size([10, 3, 128]), i.e. the 5 input features have been mapped to 128 hidden units.
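
The mismatch can be reproduced in isolation (out here is a stand-in for the first LSTM layer's output):

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(128)
out = torch.randn(10, 3, 128)    # (batch, time, features), as inside forward
bn(out)                          # RuntimeError: running_mean should contain
                                 # 3 elements not 128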

I could reshape the variable inside the forward function, but that seems like it should be unnecessary. I have also tried BatchNorm2d, but it expects a 4D tensor, which out is not. Is there a way around this?
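
For what it's worth, the two workarounds I can see are transposing around the BatchNorm1d call, or replacing it with nn.LayerNorm, which normalises over the last (feature) dimension directly. A minimal sketch of both, inside forward:

# Option 1: move the feature dimension to dim 1, where BatchNorm1d expects it
out = out.permute(0, 2, 1)               # (batch, features, time)
out = self.batch_normalisation1(out)
out = out.permute(0, 2, 1)               # back to (batch, time, features)

# Option 2: LayerNorm normalises the last dimension per sample and time step,
# so no transposing is needed:
#   self.layer_norm = nn.LayerNorm(hidden_dim)   # in __init__
#   out = self.layer_norm(out)                   # in forward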

Additionally, I am adding normalisation to the network in order to speed up training. I am not entirely sure how PyTorch's BatchNorm functions work, so an explanation would be much appreciated. Specifically: why do we end up normalising across the time dimension rather than the feature dimension?
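
To make my confusion concrete: as far as I can tell, BatchNorm1d computes one mean and variance per channel (dim 1), pooled over the batch and length dimensions. On batch_first LSTM output, dim 1 is time, so the statistics end up per time step rather than per feature:

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4, affine=False)     # 4 channels, no learnable scale/shift
x = torch.randn(100, 4, 7)               # (batch, channels, length)
y = bn(x)

# statistics are pooled over batch and length, one per channel:
print(y.mean(dim=(0, 2)))                # ~ tensor([0., 0., 0., 0.])
print(y.std(dim=(0, 2)))                 # ~ tensor([1., 1., 1., 1.])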

