我正在尝试对LSTM各个门的输出做逐元素相乘,但收到了下面的错误
出于某种原因,我似乎在用一个张量乘以一个元组,但我目前无法解决这个问题
为了更清楚地了解情况,我添加了有关此错误的所有相关信息,希望它能帮助您找到解决方案
完整的错误信息:
TypeError Traceback (most recent call last)
<ipython-input-7-93be62737b80> in <module>()
31 # orth_opt.zero_grad()
32
---> 33 logits = model(batch_x)
34 # mse_loss = model.loss(logits, batch_y, len_batch)
35 #
3 frames
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),
<ipython-input-4-42628aa60e46> in forward(self, inputs)
160 output, state = self.LSTMCell(input, state, self.mask)
161 else:
--> 162 output, state = self.LSTMCell(input, state)
163
164 state = (output, state)
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),
<ipython-input-3-fb6549fbebf9> in forward(self, input, state, mask)
93 outgate = torch.sigmoid(outgate)
94
---> 95 cy = torch.mul(forgetgate, cx) + torch.mul(ingate, cellgate)
96 hy = outgate * torch.tanh(cy)
97
TypeError: mul(): argument 'other' (position 2) must be Tensor, not tuple
我的训练代码:
# Training script: optimise `model` with Adam for `args.epochs` epochs,
# printing the per-batch accuracy and loss and a per-epoch summary.
model.train()
opt = torch.optim.Adam(model.parameters(), lr=args.lr)

# Initialised to a large sentinel; not updated anywhere in this snippet.
best_test = 1e7
best_validation = 1e7

for ep in range(1, args.epochs + 1):
    init_time = datetime.now()
    processed = 0
    step = 1

    for batch_idx, (batch_x, batch_y, len_batch) in enumerate(train_loader):
        # Move every tensor of the batch onto the target device.
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        len_batch = len_batch.to(device)

        opt.zero_grad()

        # Forward pass and loss as defined by the model itself.
        logits = model(batch_x)
        loss = model.loss(logits, batch_y, len_batch)

        # Fraction of outputs that exactly equal the targets.
        acc = sum(logits == batch_y) * 1.0 / len(logits)
        print(acc)

        loss.backward()
        # Gradient clipping is optional and enabled by a positive threshold.
        if args.clip > 0:
            nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        opt.step()

        # Book-keeping counters for the current epoch.
        step += 1
        processed += len(batch_x)

        print(" batch_idx {}\tLoss: {:.2f} ".format(batch_idx, loss))

    # Epoch summary; `loss` here is the loss of the last batch of the epoch.
    print("Epoch {}, LR {:.5f} \tLoss: {:.2f} ".format(ep, opt.param_groups[0]['lr'], loss))
我的ModLSTMCell类:
class ModLSTMCell(nn.Module):
    """Single-step LSTM cell with an optional mask on the recurrent weights.

    The four gates (input, forget, cell, output) are computed with one fused
    affine map and then split, so every weight/bias has ``4 * hidden_size``
    rows.

    Args:
        input_size: Number of features in each input row.
        hidden_size: Number of features in the hidden and cell states.
    """

    def __init__(self, input_size, hidden_size):
        super(ModLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Gate parameters are stacked along dim 0 in the order
        # input, forget, cell, output (see the chunk() call in forward()).
        self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
        self.bias_ih = Parameter(torch.randn(4 * hidden_size))
        self.bias_hh = Parameter(torch.randn(4 * hidden_size))

    def forward(self, input, state, mask=None):
        """Advance the cell by one time step.

        Args:
            input: 2-D tensor whose second dimension is ``input_size``
                (it is multiplied by ``weight_ih.t()``).
            state: Either ``None`` (a zero state is created) or a flat
                ``(hx, cx)`` pair of tensors.
            mask: Optional tensor multiplied element-wise into
                ``weight_hh``; the masked values are written back to the
                parameter, so the masking persists after the call.

        Returns:
            ``(hy, (hy, cy))`` -- the new hidden state, followed by the new
            ``(hidden, cell)`` pair. Feed the second element back as
            ``state`` on the next step as-is; wrapping it again, e.g.
            ``state = (output, state)``, produces a nested tuple.

        Raises:
            TypeError: If ``state`` is not ``None`` and its two elements are
                not both tensors (typically a nested state tuple).
        """
        if state is None:
            # new_zeros() inherits the input's device and dtype, so the
            # default state also works on CUDA / non-default dtypes.
            hx = input.new_zeros(input.size(0), self.hidden_size)
            cx = input.new_zeros(input.size(0), self.hidden_size)
        else:
            hx, cx = state
            if not (torch.is_tensor(hx) and torch.is_tensor(cx)):
                raise TypeError(
                    "state must be a flat (hx, cx) pair of tensors, got "
                    "({}, {}); pass the second value returned by this cell "
                    "directly instead of wrapping it in another tuple".format(
                        type(hx).__name__, type(cx).__name__
                    )
                )

        if mask is not None:
            # Overwrite the stored recurrent weights with their masked
            # version; working on .data keeps this out of the autograd graph.
            self.weight_hh.data = self.weight_hh.data * mask

        gates = (
            torch.mm(input, self.weight_ih.t()) + self.bias_ih
            + torch.mm(hx, self.weight_hh.t()) + self.bias_hh
        )
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        # Standard LSTM update: keep part of the old cell state, add the
        # gated candidate, then expose a gated view of it as hidden state.
        cy = torch.mul(forgetgate, cx) + torch.mul(ingate, cellgate)
        hy = outgate * torch.tanh(cy)
        return hy, (hy, cy)
任何帮助都将不胜感激,因为我已经被这个问题困扰了好几天
多谢各位
目前没有回答
相关问题 更多 >
编程相关推荐