如何替换infs避免PyTorch中的nan梯度

2024-10-04 01:29:00 发布

您现在位置:Python中文网/ 问答频道 /正文

我需要计算log(1 + exp(x)),然后对其使用自动微分。但是对于过大的x,它会输出inf,这是因为求幂:

>>> x = torch.tensor([0., 1., 100.], requires_grad=True)
>>> x.exp().log1p()
tensor([0.6931, 1.3133,    inf], grad_fn=<Log1PBackward>)

由于log(1 + exp(x)) ≈ x对于大的x,我想我可以用torch.where将{}替换为{}。但是当这样做的时候,我仍然得到nan,因为梯度值太大了。你知道为什么会发生这种情况吗?有没有其他方法可以让它起作用?在

^{pr2}$

Tags: logtrue情况torchnanwhereinffn
3条回答

我发现的一个解决方法是用反向对应的方法手动实现Log1PlusExp函数。但这并不能解释torch.where在问题中的不良行为。在

>>> class Log1PlusExp(torch.autograd.Function):
...     """Implementation of x ↦ log(1 + exp(x))."""
...     @staticmethod
...     def forward(ctx, x):
...         exp = x.exp()
...         ctx.save_for_backward(x)
...         return x.where(torch.isinf(exp), exp.log1p())
...     @staticmethod
...     def backward(ctx, grad_output):
...         x, = ctx.saved_tensors
...         return grad_output / (1 + (-x).exp())
... 
>>> log_1_plus_exp = Log1PlusExp.apply
>>> x = torch.tensor([0., 1., 100.], requires_grad=True)
>>> log_1_plus_exp(x)  # No infs
tensor([  0.6931,   1.3133, 100.0000], grad_fn=<Log1PlusExpBackward>)
>>> log_1_plus_exp(x).sum().backward()
>>> x.grad  # And no nans!
tensor([0.5000, 0.7311, 1.0000])

如果x>;=20,则函数输出约为x。 使用Pythorch方法torch.softplus公司. 这有助于解决问题。在

But for too large x, it outputs inf because of the exponentiation

这就是为什么x永远不要太大。理想情况下应在[-1,1]范围内。 如果不是这样,您应该规范化您的输入。在

相关问题 更多 >