要复制此问题,请参见下文。MLP预测某些东西,forward函数使用此预测执行交点并返回交点坐标。然后我们用坐标做更多的事情,但这并不重要。在向后传球中,我想计算交叉点坐标w.r.t.的梯度,MLP预测。这就是它失败的地方
非常感谢您的帮助
import torch
import torch.nn.functional as F
device = 'cuda' if torch.cuda.is_available() else 'cpu'
class Intersector(torch.autograd.Function):
@staticmethod
def forward(ctx, *input):
(mlp_pred, other_stuff) = input[:2]
found_intersections = torch.zeros((3, 250, 250), device=device)
# not differentiable intersection routine
# (just a dummy, doesn't really interesect anything, normally this would use mlp_pred)
with torch.no_grad():
for k in range(100):
other_stuff += i
if i > 20:
found_intersections += 15.0
break
ctx.mlp_prediction = mlp_pred # save stuff for backward pass
ctx.found_intersections = found_intersections
return found_intersections
@staticmethod
def backward(ctx, grad_output):
# grad_output contains grad of loss w.r.t the found intersections
# now: compute gradient of loss w.r.t input, i.e., w.r.t. mlp_pred
mlp_pred = ctx.mlp_prediction
intersections = ctx.found_intersections
with torch.enable_grad():
# get mlp prediction values at intersection coordinates by sampling the prediction at these values
grid_x = intersections[0, :, :].unsqueeze(dim=0)
grid_y = intersections[1, :, :].unsqueeze(dim=0)
sampling_grid = torch.cat((grid_x, grid_y), dim=0).permute(1, 2, 0).unsqueeze(dim=0)
mlp_pred_at_intersections = F.grid_sample(mlp_pred, sampling_grid, mode='bilinear', padding_mode='zeros')
pred_sum = mlp_pred_at_intersections.sum() # sum bc autograd needs scalar
### THIS IS WHERE THE OOM ERROR OCCURS
gradient_wrt_mlp_pred = torch.autograd.grad(pred_sum, mlp_pred, only_inputs=True, retain_graph=True)
grad_for_backprop = gradient_wrt_mlp_pred[0].to(device) * grad_output
# grad. for other_stuff not needed, return None
return grad_for_backprop, None
class SimpleModel(torch.nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
# extremely simple MLP
self.mlp = torch.nn.Sequential(
torch.nn.Linear(in_features=2, out_features=1),
torch.nn.LeakyReLU(),
)
self.calc_intersections = Intersector.apply
self.other_stuff = torch.zeros((3, 250, 250), device=device)
def forward(self, fw_input):
# get MLP prediction
mlp_prediction = self.mlp(fw_input).reshape((1, 1, 250, 250))
# get intersection points w.r.t mlp prediction
inputs_to_intersection = [mlp_prediction, self.other_stuff]
intersections = self.calc_intersections(*inputs_to_intersection)
# do some differentiable stuff w/ predictions
intersections_altered = intersections + torch.ones_like(intersections, device=device)
return intersections_altered
if __name__ == '__main__':
model = SimpleModel().to(device)
optim = torch.optim.Adam(model.mlp.parameters(), lr=1e-3)
epochs = 200
for i in range(epochs):
# create dummy sample and gt
sample = torch.zeros((2, 250, 250), device=device).reshape((250**2, 2)) # reshape array of x/y coords to format [batchsize,2]
gt = torch.zeros((3, 250, 250), device=device)
optim.zero_grad()
model_out = model(sample).to(device)
loss = torch.nn.L1Loss()(model_out, gt)
loss.backward()
optim.step()
print("Epoch {}/{} - Loss: {}".format(i, epochs, loss.item()))
目前没有回答
相关问题 更多 >
编程相关推荐