How do I eliminate the for loop in this PyTorch code?

I have this custom PyTorch module (below). It does exactly what I need; it's just slow. What can I do to speed it up? I know I shouldn't have a for loop in there; I'm just not sure how to do the division without it. How do I broadcast the x tensor through the division without that loop? If it helps, I can move the back weights into their own layer.

import torch
import torch.nn as nn

class StepLayer(nn.Module):
    def __init__(self):
        super(StepLayer, self).__init__()
        w = init_weights()
        self.front_weights = nn.Parameter(torch.DoubleTensor([w, w]).T, requires_grad=True)
        self.back_weights = nn.Parameter(torch.DoubleTensor([w]).T, requires_grad=True)
        

    def forward(self, x):
        # x shape is batch by feature
        results = []
        for batch in x:
            b = batch.divide(self.front_weights)
            b = torch.some_math_function(b)
            b = b.sum(dim=1)
            b = torch.some_other_math_function(b)
            b = b @ self.back_weights
            results.append(b)
        stack = torch.vstack(results)
        return stack

1 Answer

Below is the code, with the resulting shape noted after each step (please read the code comments).

I assumed a few things, e.g. F = 100, x is B x 2, front_weights is 100 x 2, back_weights is 100; you should be able to adapt it easily to your case.

import torch
import torch.nn as nn


class StepLayer(nn.Module):
    def __init__(self):
        super().__init__()
        F = 100
        # Notice I've added `1` dimension in front_weights
        self.front_weights = nn.Parameter(torch.randn(1, F, 2), requires_grad=True)
        self.back_weights = nn.Parameter(torch.randn(F), requires_grad=True)

    def forward(self, x):
        # x.shape == (B, 2)
        x = x.unsqueeze(dim=1)  # (B, 1, 2)
        x = x / self.front_weights  # (B, F, 2)
        # I just took some element-wise math function from PyTorch
        x = torch.sin(x)  # (B, F, 2)
        x = torch.sum(x, dim=-1)  # (B, F)
        x = torch.sin(x)  # (B, F)
        return x @ self.back_weights  # (B, )

        # results = []
        # for batch in x:
        #     # batch - (1, 2)
        #     b = batch.divide(self.front_weights)  # (F, 2)
        #     b = torch.some_math_function(b)  # (F, 2)
        #     b = b.sum(dim=1)  # (F, )
        #     b = torch.some_other_math_function(b)  # (F, )
        #     b = b @ self.back_weights  # (1, )
        #     results.append(b)
        # stack = torch.vstack(results)  # (B, )
        # return stack  # (B,)


layer = StepLayer()

print(layer(torch.randn(64, 2)).shape)

The main trick is to add a size-1 dimension where needed so broadcasting takes care of things (especially the division), and to initialize the weights in the right shape up front so you never have to transpose anything.
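
To make the broadcasting concrete, here is a minimal, self-contained sketch (using torch.sin as a stand-in for the unspecified math functions, and small made-up sizes B and F) that checks the vectorized computation against the original per-sample loop:

import torch

B, F = 4, 5
x = torch.randn(B, 2)
front_weights = torch.randn(1, F, 2)   # leading 1 lets the batch dimension broadcast
back_weights = torch.randn(F)

# Vectorized: (B, 1, 2) / (1, F, 2) broadcasts to (B, F, 2)
vec = torch.sin(torch.sin(x.unsqueeze(1) / front_weights).sum(dim=-1)) @ back_weights

# Loop version, one sample at a time, for comparison
loop = torch.stack([
    torch.sin(torch.sin(xi / front_weights.squeeze(0)).sum(dim=-1)) @ back_weights
    for xi in x
])

print(torch.allclose(vec, loop))  # True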

A few other things:

  • You may want to reconsider Double; float (as used above) is much faster, especially on CUDA, and uses half the memory (the network should compensate for the precision loss, if there is any) — see the sketch after this list
  • If speed is still an issue, use half precision (float16 dtype instead of float32) with mixed-precision training, but only on CUDA; for more on automatic mixed precision see here
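
A rough sketch of both suggestions, assuming the StepLayer from above; the optimizer, loss, and dummy batch are placeholders just to show where torch.autocast and GradScaler go:

import torch
import torch.nn as nn

# float32 instead of float64: torch.randn already produces float32, so just keep the
# inputs float32 as well (x = x.float()) and avoid DoubleTensor.

# Mixed-precision training sketch (requires a CUDA device)
if torch.cuda.is_available():
    model = StepLayer().cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)  # placeholder optimizer
    scaler = torch.cuda.amp.GradScaler()

    x = torch.randn(64, 2, device="cuda")     # dummy batch
    target = torch.randn(64, device="cuda")   # dummy targets

    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = nn.functional.mse_loss(model(x), target)  # placeholder loss
    scaler.scale(loss).backward()             # scale the loss to avoid fp16 underflow
    scaler.step(optimizer)
    scaler.update()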
