PyTorch: extremely simple custom autograd backward() triggers a CUDA OutOfMemory error?

Posted 2024-06-28 21:00:00



I have a question about a custom, very simple PyTorch autograd Function that I wrote. The Function performs a non-differentiable ray intersection in its forward pass and backpropagates gradients according to an analytic expression I derived on paper. However, when running it I get a "CUDA out of memory" error even with an extremely simple NN (no hidden layers, 2 input nodes, 1 output node), regardless of whether I run the script on my own machine (6 GB VRAM) or on Google Colab (16 GB VRAM). So something must be going wrong.
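
For context, the overall structure I am following is the standard custom Function pattern with static forward/backward methods (a minimal sketch for orientation only, not my actual code; the full reproduction is further down):

import torch

class SketchFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, inp):
        # placeholder for the non-differentiable intersection routine
        out = inp * 2.0
        ctx.save_for_backward(inp)    # recommended way to stash tensors (my repro below stores them on ctx directly)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (inp,) = ctx.saved_tensors
        # return one gradient per forward input; the analytic expression would go here
        return grad_output * 2.0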

To reproduce the problem, see the code below. The MLP predicts something, the forward function uses this prediction to perform the intersection and returns the intersection coordinates. We then do some more things with the coordinates, but that part doesn't matter here. In the backward pass, I want to compute the gradient of the intersection coordinates with respect to the MLP prediction. This is where it fails.
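
To summarize the tensor shapes involved (these match the reproduction below):

# mlp_pred:             (1, 1, 250, 250)  MLP output reshaped into an image-like grid
# found_intersections:  (3, 250, 250)     intersection coordinates returned by forward
# grad_output:          (3, 250, 250)     gradient of the loss w.r.t. the intersections
# wanted:               gradient of the loss w.r.t. mlp_pred, i.e. shape (1, 1, 250, 250)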

Thanks a lot for any help!

import torch
import torch.nn.functional as F

device = 'cuda' if torch.cuda.is_available() else 'cpu'


class Intersector(torch.autograd.Function):

    @staticmethod
    def forward(ctx, *input):

        (mlp_pred, other_stuff) = input[:2]
        found_intersections = torch.zeros((3, 250, 250), device=device)

        # non-differentiable intersection routine
        # (just a dummy, it doesn't really intersect anything; normally this would use mlp_pred)
        with torch.no_grad():
            for k in range(100):
                other_stuff += k
                if k > 20:
                    found_intersections += 15.0
                    break

        ctx.mlp_prediction = mlp_pred                   # save stuff for backward pass
        ctx.found_intersections = found_intersections

        return found_intersections

    @staticmethod
    def backward(ctx, grad_output):
        # grad_output contains grad of loss w.r.t the found intersections
        # now: compute gradient of loss w.r.t input, i.e., w.r.t. mlp_pred

        mlp_pred = ctx.mlp_prediction
        intersections = ctx.found_intersections

        with torch.enable_grad():

            # get mlp prediction values at intersection coordinates by sampling the prediction at these values
            grid_x = intersections[0, :, :].unsqueeze(dim=0)
            grid_y = intersections[1, :, :].unsqueeze(dim=0)
            sampling_grid = torch.cat((grid_x, grid_y), dim=0).permute(1, 2, 0).unsqueeze(dim=0)
            mlp_pred_at_intersections = F.grid_sample(mlp_pred, sampling_grid, mode='bilinear', padding_mode='zeros', align_corners=False)

            pred_sum = mlp_pred_at_intersections.sum()      # sum bc autograd needs scalar

            ### THIS IS WHERE THE OOM ERROR OCCURS
            gradient_wrt_mlp_pred = torch.autograd.grad(pred_sum, mlp_pred, only_inputs=True, retain_graph=True)

            grad_for_backprop = gradient_wrt_mlp_pred[0].to(device) * grad_output

        # grad. for other_stuff not needed, return None
        return grad_for_backprop, None


class SimpleModel(torch.nn.Module):

    def __init__(self):
        super(SimpleModel, self).__init__()

        # extremely simple MLP
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(in_features=2, out_features=1),
            torch.nn.LeakyReLU(),
        )

        self.calc_intersections = Intersector.apply
        self.other_stuff = torch.zeros((3, 250, 250), device=device)

    def forward(self, fw_input):

        # get MLP prediction
        mlp_prediction = self.mlp(fw_input).reshape((1, 1, 250, 250))

        # get intersection points w.r.t mlp prediction
        inputs_to_intersection = [mlp_prediction, self.other_stuff]
        intersections = self.calc_intersections(*inputs_to_intersection)

        # do some differentiable stuff w/ predictions
        intersections_altered = intersections + torch.ones_like(intersections, device=device)

        return intersections_altered


if __name__ == '__main__':

    model = SimpleModel().to(device)
    optim = torch.optim.Adam(model.mlp.parameters(), lr=1e-3)
    epochs = 200

    for i in range(epochs):

        # create dummy sample and gt
        sample = torch.zeros((2, 250, 250), device=device).reshape((250**2, 2))     # reshape array of x/y coords to format [batchsize,2]
        gt = torch.zeros((3, 250, 250), device=device)

        optim.zero_grad()

        model_out = model(sample).to(device)
        loss = torch.nn.L1Loss()(model_out, gt)
        loss.backward()

        optim.step()

        print("Epoch {}/{} - Loss: {}".format(i, epochs, loss.item()))
