如何在PyTorch中正确使用Numpy的FFT函数?

2024-09-30 05:31:42 发布

您现在位置:Python中文网/ 问答频道 /正文

我最近被介绍给Pythorch,开始浏览图书馆的文档和教程。 在“使用numpy和scipy创建扩展”教程中( http://pytorch.org/tutorials/advanced/numpy_extensions_tutorial.html),在“无参数示例”下,使用名为“BadFFTFunction”的numpy创建一个示例函数。在

功能说明如下:

"This layer doesn’t particularly do anything useful or mathematically correct.

It is aptly named BadFFTFunction"

函数及其用法如下:

from numpy.fft import rfft2, irfft2

class BadFFTFunction(Function):

    def forward(self, input):
        numpy_input = input.numpy()
        result = abs(rfft2(numpy_input))
        return torch.FloatTensor(result)

    def backward(self, grad_output):
        numpy_go = grad_output.numpy()
        result = irfft2(numpy_go)
        return torch.FloatTensor(result)

def incorrect_fft(input):
    return BadFFTFunction()(input)

input = Variable(torch.randn(8, 8), requires_grad=True)
result = incorrect_fft(input)
print(result.data)
result.backward(torch.randn(result.size()))
print(input.grad)

不幸的是,我也是最近才被介绍信号处理的,我不确定(可能是明显的)这个函数的错误在哪里。在

我想知道,如何修复这个函数,使它的正向和反向输出都是正确的? 如何固定BadFFTFunction以便在PyTorch中使用可微FFT函数?在

任何帮助都将不胜感激。 谢谢您。在


Tags: 函数selffftnumpy示例inputreturndef
1条回答
网友
1楼 · 发布于 2024-09-30 05:31:42

我认为错误是:首先,尽管函数名为FFT,但它只返回FFT输出的振幅/绝对值,而不是完整的复系数。另外,仅仅使用逆FFT来计算振幅梯度在数学上可能没有多大意义。在

有一个名为pytorch-fft的包试图在pytorch中提供一个FFT函数。您可以看到一些自动加载功能的实验性代码here。还要注意这个issue中的讨论。在

相关问题 更多 >

    热门问题