是否有任何方法可以在nn.Sequential
对象中使用自定义torch.autograd.Function
,或者应该显式地使用带有forward函数的nn.Module
对象。具体来说,我正在尝试实现一个稀疏自动编码器,我需要将代码的L1距离(隐藏表示)添加到丢失中。
我在下面定义了自定义torch.autograd.Function
对象,然后尝试在nn.Sequential
对象中使用它,如下所示。然而,当我运行时,我得到了错误TypeError: __main__.L1Penalty is not a Module subclass
我如何解决这个问题
class L1Penalty(torch.autograd.Function):
@staticmethod
def forward(ctx, input, l1weight = 0.1):
ctx.save_for_backward(input)
ctx.l1weight = l1weight
return input, None
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_variables
grad_input = input.clone().sign().mul(ctx.l1weight)
grad_input+=grad_output
return grad_input
model = nn.Sequential(
nn.Linear(10, 10),
nn.ReLU(),
nn.Linear(10, 6),
nn.ReLU(),
# sparsity
L1Penalty(),
nn.Linear(6, 10),
nn.ReLU(),
nn.Linear(10, 10),
nn.ReLU()
).to(device)
正确的方法是这样做
创建充当包装器的Lambda类
塔达
nn.Module
API似乎工作正常,但在L1Penalty
{相关问题 更多 >
编程相关推荐