我最近被介绍给Pythorch,开始浏览图书馆的文档和教程。 在“使用numpy和scipy创建扩展”教程中( http://pytorch.org/tutorials/advanced/numpy_extensions_tutorial.html),在“无参数示例”下,使用名为“BadFFTFunction”的numpy创建一个示例函数。在
功能说明如下:
"This layer doesn’t particularly do anything useful or mathematically correct.
It is aptly named BadFFTFunction"
函数及其用法如下:
from numpy.fft import rfft2, irfft2
class BadFFTFunction(Function):
def forward(self, input):
numpy_input = input.numpy()
result = abs(rfft2(numpy_input))
return torch.FloatTensor(result)
def backward(self, grad_output):
numpy_go = grad_output.numpy()
result = irfft2(numpy_go)
return torch.FloatTensor(result)
def incorrect_fft(input):
return BadFFTFunction()(input)
input = Variable(torch.randn(8, 8), requires_grad=True)
result = incorrect_fft(input)
print(result.data)
result.backward(torch.randn(result.size()))
print(input.grad)
不幸的是,我也是最近才被介绍信号处理的,我不确定(可能是明显的)这个函数的错误在哪里。在
我想知道,如何修复这个函数,使它的正向和反向输出都是正确的? 如何固定BadFFTFunction以便在PyTorch中使用可微FFT函数?在
任何帮助都将不胜感激。 谢谢您。在
我认为错误是:首先,尽管函数名为FFT,但它只返回FFT输出的振幅/绝对值,而不是完整的复系数。另外,仅仅使用逆FFT来计算振幅梯度在数学上可能没有多大意义。在
有一个名为pytorch-fft的包试图在pytorch中提供一个FFT函数。您可以看到一些自动加载功能的实验性代码here。还要注意这个issue中的讨论。在
相关问题 更多 >
编程相关推荐