TypeError:mul():参数“other”(位置2)必须是张量,而不是元组

2024-09-28 05:16:25 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试对LSTM各个门的输出做逐元素相乘,但收到了上面的错误

出于某种原因,我似乎在用一个张量乘以一个元组,但我目前无法解决这个问题

为了更清楚地了解情况,我添加了有关此错误的所有相关信息,希望它能帮助您找到解决方案

完全错误:

TypeError                                 Traceback (most recent call last)
<ipython-input-7-93be62737b80> in <module>()
     31         #     orth_opt.zero_grad()
     32 
---> 33         logits = model(batch_x)
     34         # mse_loss = model.loss(logits, batch_y, len_batch)
     35         #

3 frames
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

<ipython-input-4-42628aa60e46> in forward(self, inputs)
    160               output, state = self.LSTMCell(input, state, self.mask)
    161             else:
--> 162               output, state = self.LSTMCell(input, state)
    163 
    164             state = (output, state)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

<ipython-input-3-fb6549fbebf9> in forward(self, input, state, mask)
     93         outgate = torch.sigmoid(outgate)
     94 
---> 95         cy = torch.mul(forgetgate, cx) + torch.mul(ingate, cellgate)
     96         hy = outgate * torch.tanh(cy)
     97 

TypeError: mul(): argument 'other' (position 2) must be Tensor, not tuple

我的训练代码:

# Training script: runs `args.epochs` epochs of Adam optimisation over
# `train_loader`, printing a per-batch accuracy figure and loss.
# `model`, `args`, `train_loader`, `device`, `torch`, `nn` and `datetime`
# are expected to be defined/imported earlier in the file.
model.train()

opt = torch.optim.Adam(model.parameters(), lr=args.lr)

# Best-so-far trackers; they are initialised here but not updated anywhere
# in this snippet.
best_test = 1e7
best_validation = 1e7

for ep in range(1, args.epochs + 1):

    # Per-epoch bookkeeping. These values are written below but never read
    # in this snippet.
    init_time = datetime.now()
    processed = 0
    step = 1

    for batch_idx, (batch_x, batch_y, len_batch) in enumerate(train_loader):
        # Move the whole batch to the target device before the forward pass.
        batch_x, batch_y, len_batch = batch_x.to(device), batch_y.to(device), len_batch.to(device)

        # Clear gradients accumulated by the previous step.
        opt.zero_grad()

        logits = model(batch_x)

        loss = model.loss(logits, batch_y, len_batch)

        # Fraction of positions (summed along the first dimension) where
        # `logits` equals `batch_y`.
        # NOTE(review): if `logits` holds raw scores rather than predicted
        # labels, this equality test is not a meaningful accuracy - confirm.
        acc = sum(logits == batch_y) * 1.0 / len(logits)
        print(acc)

        loss.backward()

        # Gradient-norm clipping is optional; it is skipped when
        # args.clip <= 0.
        if args.clip > 0:
            nn.utils.clip_grad_norm_(model.parameters(), args.clip)

        opt.step()

        processed += len(batch_x)
        step += 1
        print("   batch_idx {}\tLoss: {:.2f} ".format(batch_idx, loss))

    # The epoch summary reports the loss of the LAST batch only, not an
    # epoch average.
    print("Epoch {}, LR {:.5f} \tLoss: {:.2f} ".format(ep, opt.param_groups[0]['lr'], loss))

我的ModLSTMCell类:

class ModLSTMCell(nn.Module):
    """A single LSTM cell with an optional mask on the recurrent weights.

    The four gates are computed in one fused affine transform and split in
    the order: input gate, forget gate, cell candidate, output gate.

    Args:
        input_size: Number of features in each input row.
        hidden_size: Number of features in the hidden and cell states.
    """

    def __init__(self, input_size, hidden_size):
        super(ModLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Each parameter stacks the four gates along dim 0 (4 * hidden_size).
        self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
        self.bias_ih = Parameter(torch.randn(4 * hidden_size))
        self.bias_hh = Parameter(torch.randn(4 * hidden_size))

    def forward(self, input, state, mask=None):
        """Advance the cell by one time step.

        Args:
            input: 2-D tensor of shape (batch, input_size).
            state: Either ``None`` (start from all-zero states) or a FLAT
                pair ``(hx, cx)`` of tensors, each (batch, hidden_size).
            mask: Optional tensor multiplied element-wise into
                ``weight_hh``. The masked values are written back into the
                parameter's storage, so the effect persists across calls.

        Returns:
            ``(hy, (hy, cy))``: the new hidden state, followed by the
            ``(hidden, cell)`` pair to pass as ``state`` on the next call.

        Raises:
            TypeError: If ``state`` is not ``None`` and either element is
                not a tensor (for example a nested tuple).
        """
        if state is None:
            # new_zeros inherits dtype and device from `input`; plain
            # torch.zeros would always produce a float32 CPU tensor and fail
            # for CUDA or float64 inputs.
            batch_size = input.size(0)
            hx = input.new_zeros(batch_size, self.hidden_size)
            cx = input.new_zeros(batch_size, self.hidden_size)
        else:
            hx, cx = state
            if not (torch.is_tensor(hx) and torch.is_tensor(cx)):
                # Fail early with an actionable message instead of an opaque
                # error from torch.mm / torch.mul further down.
                raise TypeError(
                    "state must be a flat (hx, cx) pair of tensors, got "
                    "({}, {}). The second value returned by forward() is "
                    "already the (hx, cx) pair - pass it back unchanged "
                    "instead of wrapping it as (output, state).".format(
                        type(hx).__name__, type(cx).__name__
                    )
                )

        if mask is not None:
            # Overwrites the stored recurrent weights with their masked
            # version (bypassing autograd via .data).
            self.weight_hh.data = self.weight_hh * mask

        gates = (torch.mm(input, self.weight_ih.t()) + self.bias_ih +
                 torch.mm(hx, self.weight_hh.t()) + self.bias_hh)
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        # Standard LSTM update: keep part of the old cell state and add the
        # gated candidate.
        cy = torch.mul(forgetgate, cx) + torch.mul(ingate, cellgate)
        hy = outgate * torch.tanh(cy)

        return hy, (hy, cy)

任何帮助都将不胜感激,因为我已经在这个问题上卡了好几天

多谢各位


Tags: in self input size model len batch torch

热门问题