如何在nn.顺序模型中使用自定义torch.autograd.Function

2024-09-28 22:54:21 发布

您现在位置:Python中文网/ 问答频道 /正文

是否有任何方法可以在nn.Sequential对象中使用自定义torch.autograd.Function,或者应该显式地使用带有forward函数的nn.Module对象。具体来说,我正在尝试实现一个稀疏自动编码器,我需要将代码的L1距离(隐藏表示)添加到丢失中。 我在下面定义了自定义torch.autograd.Function对象,然后尝试在nn.Sequential对象中使用它,如下所示。然而,当我运行时,我得到了错误TypeError: __main__.L1Penalty is not a Module subclass我如何解决这个问题

class L1Penalty(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, l1weight = 0.1):
        ctx.save_for_backward(input)
        ctx.l1weight = l1weight
        return input, None

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_variables
        grad_input = input.clone().sign().mul(ctx.l1weight)
        grad_input+=grad_output
        return grad_input
model = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 6),
    nn.ReLU(),
    # sparsity
    L1Penalty(),
    nn.Linear(6, 10),
    nn.ReLU(),
    nn.Linear(10, 10),
    nn.ReLU()
).to(device)

Tags: 对象inputfunctionnntorchforwardmodulelinear
2条回答

正确的方法是这样做

import torch, torch.nn as nn

class L1Penalty(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, l1weight = 0.1):
        ctx.save_for_backward(input)
        ctx.l1weight = l1weight
        return input

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_variables
        grad_input = input.clone().sign().mul(ctx.l1weight)
        grad_input+=grad_output
        return grad_input

创建充当包装器的Lambda类

class Lambda(nn.Module):
    """
    Input: A Function
    Returns : A Module that can be used
        inside nn.Sequential
    """
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x): return self.func(x)

塔达

model = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 6),
    nn.ReLU(),
    # sparsity
    Lambda(L1Penalty.apply),
    nn.Linear(6, 10),
    nn.ReLU(),
    nn.Linear(10, 10),
    nn.ReLU())

a = torch.rand(50,10)
b = model(a)
print(b.shape)

nn.ModuleAPI似乎工作正常,但在L1Penalty{}方法中不应返回None

import torch, torch.nn as nn

class L1Penalty(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, l1weight = 0.1):
        ctx.save_for_backward(input)
        ctx.l1weight = l1weight
        return input

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_variables
        grad_input = input.clone().sign().mul(ctx.l1weight)
        grad_input+=grad_output
        return grad_input


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10,10)
        self.fc2 = nn.Linear(10,6)
        self.fc3 = nn.Linear(6,10)
        self.fc4 = nn.Linear(10,10)
        self.relu = nn.ReLU(inplace=True)
        self.penalty = L1Penalty()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.penalty.apply(x)
        x = self.fc3(x)
        x = self.relu(x)
        x = self.fc4(x)
        x = self.relu(x)
        return x


model = Model()
a = torch.rand(50,10)
b = model(a)
print(b.shape)

相关问题 更多 >