<p><code>nn.Module</code> API 似乎工作正常，但 <code>L1Penalty</code> 的 <code>forward</code> 方法不应返回 <code>None</code>，而应原样返回输入张量。</p>
<pre><code>import torch, torch.nn as nn
class L1Penalty(torch.autograd.Function):
    """Identity in the forward pass; adds an L1 sub-gradient in the backward pass.

    Inserting ``L1Penalty.apply(x)`` into a network leaves the activations
    unchanged. During backpropagation it adds ``l1weight * sign(x)`` to the
    incoming gradient. That term is the (sub-)gradient of
    ``l1weight * sum(abs(x))``, so ``x`` gets an L1 penalty without an
    explicit extra loss term.
    """

    @staticmethod
    def forward(ctx, input, l1weight=0.1):
        """Save ``input`` for the backward pass and return it unchanged.

        Args:
            ctx: Autograd context used to stash state for ``backward``.
            input: Tensor the L1 penalty is applied to.
            l1weight: Scale of the L1 penalty (default 0.1).

        Returns:
            ``input`` itself; the forward pass is the identity.
        """
        ctx.save_for_backward(input)
        ctx.l1weight = l1weight
        # Must return the tensor, never None: apply() wraps this as the output.
        return input

    @staticmethod
    def backward(ctx, grad_output):
        """Return the incoming gradient plus ``l1weight * sign(input)``.

        Args:
            ctx: Autograd context populated by ``forward``.
            grad_output: Gradient of the loss w.r.t. this Function's output.

        Returns:
            A pair ``(grad_input, None)``, one entry per ``forward`` argument.
            ``l1weight`` is a plain number and receives no gradient.
        """
        # saved_tensors replaces the deprecated saved_variables accessor.
        input, = ctx.saved_tensors
        # sign() already allocates a new tensor, so no clone() is needed.
        grad_input = grad_output + input.sign() * ctx.l1weight
        # backward must yield a gradient for every forward input. When apply()
        # is called with only ``input``, autograd drops the trailing None.
        return grad_input, None
class Model(nn.Module):
    """Small fully connected network with an L1 penalty on its 6-unit layer.

    Layer widths are 10, 10, 6, 10, 10, and every linear layer is followed by
    a ReLU. ``L1Penalty`` is applied right after the 6-unit activation. Its
    forward pass returns the input unchanged, so it does not alter the
    forward output.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 10)
        self.fc2 = nn.Linear(10, 6)
        self.fc3 = nn.Linear(6, 10)
        self.fc4 = nn.Linear(10, 10)
        # The same parameter-free ReLU module is reused after every layer.
        self.relu = nn.ReLU(inplace=True)
        # Store the Function class itself. Autograd Functions are used through
        # the static ``apply`` and must not be instantiated (PyTorch emits a
        # deprecation warning for that).
        self.penalty = L1Penalty

    def forward(self, x):
        """Map ``x`` (last dimension 10) to an output whose last dimension is 10."""
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        # Identity in the forward pass; contributes the L1 term to gradients.
        x = self.penalty.apply(x)
        x = self.fc3(x)
        x = self.relu(x)
        x = self.fc4(x)
        x = self.relu(x)
        return x
# Smoke test: push one random batch of 50 rows with 10 features each through
# the network and report the resulting shape.
model, a = Model(), torch.rand(50, 10)
b = model(a)
print(b.shape)
</code></pre>