如何在由原始张量运算(如not)构建的PyTorch LSTM上进行反投影`nn.LSTM公司`)

2024-09-28 21:43:19 发布

您现在位置:Python中文网/ 问答频道 /正文

编辑:我的实现的问题是试图直接从隐藏状态提取我的output,即一个热向量。我在上面加了一个致密层,效果很好。你知道吗


我正在尝试从PyTorch中的更原始的操作生成LSTM,并使用torch.autograd特性来反传错误。我希望它是“在线的”,因为hc随着时间的推移,它们的状态会不断累积,每个时间步有1个字符输入,1个字符输出。你知道吗

这是字符级rnn,因此:

  • 我的词汇量是30个字符(小写a-z和一些标点符号)

  • inp是长度为30的onehot向量。

  • hc的长度为30+100。h的前30个是我的“输出”

  • 我的损失是将热编码的target字符与h的前30个索引进行比较。

  • 我在10个步骤中积累了loss,然后就不知道如何正确地反向支撑它。以下是一个(糟糕的)尝试。

TL;博士,我该如何正确地支持这个LSTM?

    def ff(inp, h, c):
        xh = torch.cat((inp, h), 0)
        f = (xh @ Wf + bf).sigmoid()
        i = (xh @ Wi + bi).sigmoid()
        g = (xh @ Wg + bg).tanh() # C-bar, in some literature
        c = f * c + i * g
        o = (xh @ Wo + bo).sigmoid()
        h = o * c.tanh()
        return h, c

    loss = torch.zeros(1)
    def bp(out, target, lr):
        global Wf, Wi, Wg, Wo
        global bf, bi, bg, bo
        global h, c
        global loss

        # Accumulate loss every step
        loss += (-target * out[:out_n].softmax(dim=0).log()).sum()

        # Every 10 chars, run backprop
        if i % 10 == 0:
            loss.backward()

            with torch.no_grad():
                for param in [Wf, Wi, Wg, Wo, bf, bi, bg, bo]:
                    param -= lr * param.grad
                    param.grad.zero_()

            h.detach_()
            c.detach_()
            loss.detach_()
            loss = torch.zeros(1)

        return loss

Tags: targetparamtorchglobalbginpwfbi