Pytork在损耗函数中使用自动加载时不更新权重

2024-10-02 18:17:29 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图使用网络相对于其输入的梯度作为损失函数的一部分。但是,每当我尝试计算它时,训练都会继续,但权重不会更新

import torch
import torch.optim as optim
import torch.autograd as autograd


ic = torch.rand((25, 3))
ic = torch.tensor(ic, requires_grad=True)
optimizer = optim.RMSprop([ic], lr=1e-2)

for itr in range(1, 50):
    optimizer.zero_grad()
    sol = torch.tanh(.5*torch.stack(100*[ic])) # simplified for minimal working example
    
    dx = sol[-1, :, 0]
    dxdxy, = autograd.grad(dx, 
                           inputs=ic,
                           grad_outputs = torch.ones(ic.shape[0]), # batchwise
                           retain_graph=True
                          )
    dxdxy = torch.tensor(dxdxy, requires_grad=True)
    loss = torch.sum(dxdxy)
    
    loss.backward()
    optimizer.step()
    
    if itr % 5 == 0:
        print(loss)

我做错了什么


Tags: importtrueforastorchoptimoptimizertensor
1条回答
网友
1楼 · 发布于 2024-10-02 18:17:29

当您运行autograd.grad而不将标志create_graph设置为True时,您将无法获得连接到计算图的输出,这意味着您将无法进一步优化w.r.tic(并获得您希望在此处实现的高阶导数)。 从torch.autograd.grad的文档字符串:

create_graph (bool, optional): If True, graph of the derivative will be constructed, allowing to compute higher order derivative products. Default: False.

在这里尝试使用dxdxy = torch.tensor(dxdxy, requires_grad=True)不会有帮助,因为连接到ic的计算图已经丢失(因为create_graphFalse),您所做的只是创建一个新的计算图,其中dxdxy是一个叶节点

请参阅下面附带的解决方案(请注意,当您创建ic时,可以设置requires_grad=True,因此第二行是冗余的(这不是逻辑问题,只是更长的代码):

import torch
import torch.optim as optim
import torch.autograd as autograd

ic = torch.rand((25, 3),requires_grad=True) #<  requires_grad to True here
#ic = torch.tensor(ic, requires_grad=True) #<  redundant
optimizer = optim.RMSprop([ic], lr=1e-2)

for itr in range(1, 50):
    optimizer.zero_grad()
    sol = torch.tanh(.5 * torch.stack(100 * [ic]))  # simplified for minimal working example

    dx = sol[-1, :, 0]
    dxdxy, = autograd.grad(dx,
                           inputs=ic,
                           grad_outputs=torch.ones(ic.shape[0]),  # batchwise
                           retain_graph=True, create_graph=True # <  important
                           )
    #dxdxy = torch.tensor(dxdxy, requires_grad=True) #<  won't do the trick. Remove
    loss = torch.sum(dxdxy)

    loss.backward()
    optimizer.step()

    if itr % 5 == 0:
        print(loss)

相关问题 更多 >