<p>One workaround I found is to implement the <code>Log1PlusExp</code> function manually, writing the backward pass by hand. This does not, however, explain the bad behavior of <code>torch.where</code> in the question.</p>
<pre class="lang-py prettyprint-override"><code>>>> class Log1PlusExp(torch.autograd.Function):
...     """Implementation of x ↦ log(1 + exp(x))."""
...     @staticmethod
...     def forward(ctx, x):
...         exp = x.exp()
...         ctx.save_for_backward(x)
...         # Where exp(x) overflowed to inf, log(1 + exp(x)) ≈ x.
...         return x.where(torch.isinf(exp), exp.log1p())
...     @staticmethod
...     def backward(ctx, grad_output):
...         x, = ctx.saved_tensors
...         # d/dx log(1 + exp(x)) = 1 / (1 + exp(-x)) = sigmoid(x)
...         return grad_output / (1 + (-x).exp())
...
>>> log_1_plus_exp = Log1PlusExp.apply
>>> x = torch.tensor([0., 1., 100.], requires_grad=True)
>>> log_1_plus_exp(x) # No infs
tensor([ 0.6931, 1.3133, 100.0000], grad_fn=<Log1PlusExpBackward>)
>>> log_1_plus_exp(x).sum().backward()
>>> x.grad # And no nans!
tensor([0.5000, 0.7311, 1.0000])
</code></pre>
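<p>As a sanity check, the custom function can be compared against <code>torch.nn.functional.softplus</code>, PyTorch's built-in numerically stable implementation of the same x ↦ log(1 + exp(x)) function, and the gradient against <code>torch.sigmoid</code>, its analytical derivative. A minimal standalone sketch (assuming a recent PyTorch):</p>

```python
import torch


class Log1PlusExp(torch.autograd.Function):
    """Numerically stable x -> log(1 + exp(x)) with a hand-written backward."""

    @staticmethod
    def forward(ctx, x):
        exp = x.exp()
        ctx.save_for_backward(x)
        # Where exp(x) overflowed to inf, log(1 + exp(x)) is effectively x.
        return x.where(torch.isinf(exp), exp.log1p())

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # d/dx log(1 + exp(x)) = 1 / (1 + exp(-x)) = sigmoid(x)
        return grad_output / (1 + (-x).exp())


x = torch.tensor([0., 1., 100.], requires_grad=True)
y = Log1PlusExp.apply(x)
y.sum().backward()
```

<p>Both <code>y</code> and <code>x.grad</code> stay finite even at <code>x = 100</code>, where a naive <code>(1 + x.exp()).log()</code> would produce <code>inf</code> in the forward pass and <code>nan</code> in the gradient.</p>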