我尝试使用花哨的索引而不是循环来加速Numpy中的函数。据我所知,我已经正确地实现了花哨的索引版本。问题是这两个函数(loop和fancy indexed)返回的结果不同。我不知道为什么。值得指出的是,如果使用较小的数组(例如20 x 20 x 20),函数返回的结果确实相同。在
下面我列出了重现错误所需的一切。{cdes>如果返回一个完整的数组},那么返回的结果应该是相同的。在
from numpy import *
def rms(data, axis=0):
return sqrt(mean(data ** 2, axis))
def find_maxdiff(data):
samples, channels, epochs = shape(data)
window_size = 50
maxdiff = zeros(epochs)
for epoch in xrange(epochs):
signal = rms(data[:, :, epoch], axis=1)
for t in xrange(window_size, alen(signal) - window_size):
amp_a = mean(signal[t-window_size:t], axis=0)
amp_b = mean(signal[t:t+window_size], axis=0)
the_diff = abs(amp_b - amp_a)
if the_diff > maxdiff[epoch]:
maxdiff[epoch] = the_diff
return maxdiff
def find_maxdiff_fancy(data):
samples, channels, epochs = shape(data)
window_size = 50
maxdiff = zeros(epochs)
signal = rms(data, axis=1)
for t in xrange(window_size, alen(signal) - window_size):
amp_a = mean(signal[t-window_size:t], axis=0)
amp_b = mean(signal[t:t+window_size], axis=0)
the_diff = abs(amp_b - amp_a)
maxdiff[the_diff > maxdiff] = the_diff
return maxdiff
data = random.random((600, 20, 100))
find_maxdiff(data) - find_maxdiff_fancy(data)
data = random.random((20, 20, 20))
find_maxdiff(data) - find_maxdiff_fancy(data)
问题是这条线:
左侧只选择maxdiff的一些元素,而右侧包含了_diff的所有元素。这样应该可以:
^{pr2}$或者简单地说:
至于为什么20x20x20大小似乎有效:这是因为您的窗口大小太大,所以没有执行任何操作。在
首先,在幻想中,如果我理解正确的话,你的信号现在是2D的-所以我认为显式地索引它会更清楚(例如amp_a=mean(signal[t-window_尺寸:t,:],轴=0)。类似于alen(signal)——这两种情况下都应该是样本,所以我认为使用它会更清楚。在
当您在
t
循环中实际执行某个操作时,这是错误的—当samples < window_lenght
时,就像20x20x20示例中那样,该循环永远不会执行。一旦循环被多次执行(即samples > 2 *window_length+1
),错误就会出现。不知道为什么-他们看起来和我一样。在相关问题 更多 >
编程相关推荐