我在使用pytorch实现GRU网络时遇到问题:
我的代码如下:
import torch
class GRU_model(torch.nn.Module):
def __init__(self, device):
super(GRU_model, self).__init__()
self.h = torch.randn((1,1,5), device=device, dtype=torch.float)
self.GRU_1 = torch.nn.GRU(input_size=5, hidden_size=5)
def forward(self, a):
output, self.h = self.GRU_1(a, self.h)
return output
if __name__ == '__main__':
learn_rate=1e-4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GRU_model(device).to(device=device)
optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)
for i in range(10):
a = torch.randn((1, 1, 5), device=device, dtype=torch.float)
output = model(a)
loss = (a - output).mean()
optimizer.zero_grad()
loss.backward(retain_graph=True)
optimizer.step()
我犯了这样一个错误:
Traceback (most recent call last):
File "C:/Users/Administrator_/Desktop/Graduation_Project/MIDI_Music_style_transfer/GRU_toy_in-place_hidden_states_change/main.py", line 40, in <module>
loss.backward(retain_graph=True)
File "C:\Users\Administrator_\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\tensor.py", line 245, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File "C:\Users\Administrator_\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\autograd\__init__.py", line 147, in backward
allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [15, 5]] is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
我只想在一个纪元后更新GRU中的隐藏状态,但它就是不起作用
如果你能帮我,我将不胜感激
在PyTorch中,为一个历元中的每次迭代创建计算图。在每次迭代中,我们执行前向传递,计算输出w.r.t对网络参数的导数,并更新参数以适合给定示例。执行向后传递后,将释放图形以节省内存。在下一次迭代中,将创建一个新的图形,并准备进行反向传播
由于默认情况下,第一次向后传递后将释放计算图,因此如果第二次尝试在同一个图上向后传递,将遇到错误。这就是为什么会弹出以下错误消息:
source
在您的情况下,在指定
retain_graph=True
之后,您会看到:当您尝试在正向传递中更新
self.h
时,会出现此问题。您没有修改它inplace
,因为它是梯度计算所需要的source 这一条应该有效:真正的问题是隐藏态不应该参与梯度反向传播的计算。 因此,只需添加一行
self.h = self.h.detach()
,如下所示,这肯定会解决问题:相关问题 更多 >
编程相关推荐