理解cdist()函数

2024-09-28 20:50:37 发布

您现在位置:Python中文网/ 问答频道 /正文

这个new_cdist()函数实际上做什么?更具体地说:

  1. 当AdderNet纸张在其反向传播方程中不使用sqrt()运算时,为什么会有sqrt()运算
  2. needs_input_grad[]是如何使用的
def new_cdist(p, eta):
    class cdist(torch.autograd.Function):
        @staticmethod
        def forward(ctx, W, X):
            ctx.save_for_backward(W, X)
            out = -torch.cdist(W, X, p)
            return out

        @staticmethod
        def backward(ctx, grad_output):
            W, X = ctx.saved_tensors
            grad_W = grad_X = None
            if ctx.needs_input_grad[0]:
                _temp1 = torch.unsqueeze(X, 2).expand(X.shape[0], X.shape[1], W.shape[0]).permute(1, 0, 2)
                _temp2 = torch.unsqueeze(W.transpose(0, 1), 1)
                _temp = torch.cdist(_temp1, _temp2, p).squeeze().transpose(0, 1)
                grad_W = torch.matmul(grad_output, _temp)
                # print('before norm: ', torch.norm(grad_W))
                grad_W = eta * np.sqrt(grad_W.numel()) / torch.norm(grad_W) * grad_W
                print('after norm: ', torch.norm(grad_W))
            if ctx.needs_input_grad[1]:
                _temp1 = torch.unsqueeze(W, 2).expand(W.shape[0], W.shape[1], X.shape[0]).permute(1, 0, 2)
                _temp2 = torch.unsqueeze(X.transpose(0, 1), 1)
                _temp = torch.cdist(_temp1, _temp2, p).squeeze().transpose(0, 1)
                _temp = torch.nn.functional.hardtanh(_temp, min_val=-1., max_val=1.)
                grad_X = torch.matmul(grad_output.transpose(0, 1), _temp)
            return grad_W, grad_X
    return cdist().apply

我的意思是,它似乎与anew type of back-propagation equation and adaptive learning rate有关


Tags: norminputdeftorchsqrttempctxtranspose
1条回答
网友
1楼 · 发布于 2024-09-28 20:50:37

实际上,AdderNet论文确实使用了sqrt。它位于自适应学习率计算中(算法1,第6行)。更具体地说,你可以看到等式12:

enter image description here

这一行写的是什么:

# alpha_l = eta * np.sqrt(grad_W.numel()) / torch.norm(grad_W)
grad_W = eta * np.sqrt(grad_W.numel()) / torch.norm(grad_W) * grad_W

sqrt()来自等式13:

adaptive learning rate

where k denotes the number of elements in F_l to average the l2-norm, and η is a hyper-parameter to control the learning rate of adder filters.


关于第二个问题:needs_input_grad只是一个变量,用于检查输入是否真的需要梯度^在这种情况下,{}指的是{},而{}指的是{}。你可以阅读更多关于它的内容

相关问题 更多 >