我的Pytork forward功能是否可以执行其他操作?

2024-10-01 05:00:54 发布

您现在位置:Python中文网/ 问答频道 /正文

通常forward函数将一组层串在一起,并返回最后一个层的输出。在返回前,我能在最后一层之后做一些额外的处理吗?例如,通过.view进行一些标量乘法和整形

我知道autograd不知怎么算出了梯度。所以我不知道我的额外处理是否会把事情搞砸。谢谢


Tags: 函数view事情forward梯度标量乘法autograd
1条回答
网友
1楼 · 发布于 2024-10-01 05:00:54

通过张量computational graph跟踪梯度,而不是通过函数。只要你的张量有^{}属性,并且它们的^{}不是None,你可以(几乎)做任何你喜欢的事情,并且仍然能够支持。
只要您使用pytorch的操作(例如herehere中列出的操作),您就应该可以了

有关更多信息,请参见this

例如(取自torchvision's VGG implementation):

class VGG(nn.Module):

    def __init__(self, features, num_classes=1000, init_weights=True):
        super(VGG, self).__init__()
        #  ...

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # <  what you were asking about
        x = self.classifier(x)
        return x

torchvision's implementation of ResNet中可以看到一个更复杂的例子:

class Bottleneck(nn.Module):
    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        # ...

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:    # <  conditional execution!
            identity = self.downsample(x)

        out += identity  # <  inplace operations
        out = self.relu(out)

        return out

相关问题 更多 >