我有一个使用for循环的函数,我想用numpy来提高速度。但这似乎并没有做到这一点,因为凹凸不平的版本似乎慢了2倍。代码如下:
import numpy as np
import itertools
import timeit
def func():
sample = np.random.random_sample((100, 2))
disc1 = 0
disc2 = 0
n_sample = len(sample)
dim = sample.shape[1]
for i in range(n_sample):
prod = 1
for k in range(dim):
sub = np.abs(sample[i, k] - 0.5)
prod *= 1 + 0.5 * sub - 0.5 * sub ** 2
disc1 += prod
for i, j in itertools.product(range(n_sample), range(n_sample)):
prod = 1
for k in range(dim):
a = 0.5 * np.abs(sample[i, k] - 0.5)
b = 0.5 * np.abs(sample[j, k] - 0.5)
c = 0.5 * np.abs(sample[i, k] - sample[j, k])
prod *= 1 + a + b - c
disc2 += prod
c2 = (13 / 12) ** dim - 2 / n_sample * disc1 + 1 / (n_sample ** 2) * disc2
def func_numpy():
sample = np.random.random_sample((100, 2))
disc1 = 0
disc2 = 0
n_sample = len(sample)
dim = sample.shape[1]
disc1 = np.sum(np.prod(1 + 0.5 * np.abs(sample - 0.5) - 0.5 * np.abs(sample - 0.5) ** 2, axis=1))
for i, j in itertools.product(range(n_sample), range(n_sample)):
disc2 += np.prod(1 + 0.5 * np.abs(sample[i] - 0.5) + 0.5 * np.abs(sample[j] - 0.5) - 0.5 * np.abs(sample[i] - sample[j]))
c2 = (13 / 12) ** dim - 2 / n_sample * disc1 + 1 / (n_sample ** 2) * disc2
print('Normal function time: ' , timeit.repeat('func()', number=20, repeat=5, setup="from __main__ import func"))
print('numpy function time: ', timeit.repeat('func_numpy()', number=20, repeat=5, setup="from __main__ import func_numpy"))
定时输出为:
^{pr2}$我错过了什么?我知道瓶颈是itertools部分,因为我以前有一个100x100x2的循环,而不是100x2的循环。 你有别的办法吗?在
正如我在评论中提到的,你的解决方案并不是真正的最优,比较不理想的方法也没有实际意义。在
一方面,迭代或索引NumPy数组的单个元素非常慢。我最近回答了一个包含很多细节的问题(如果你感兴趣,你可以看看:"convert np array to a set takes too long")。因此,Python方法只需将
array
转换为list
,就可以更快:我还将
np.abs
调用替换为正常的abs
。正常的abs
开销更低!也改变了其他部分。最后,这比你最初的“正常”方法快了10-20倍。在我还没有时间检查NumPy方法,@Divarkar已经包含了一个非常好和优化的方法。比较两种方法:
^{pr2}$因此,一个优化的NumPy方法绝对可以击败“优化的”Python方法。它几乎快了100倍。如果您想要更快,您可以在纯python代码稍微修改过的版本上使用numba:
这几乎是一个因素8-10快于NumPy方法。在
我们当然可以在这里用纽托里看看。在
仔细看一下循环部分,我们将沿着输入数据的第一个轴
samples
进行两次循环启动:一旦我们让^{} 处理这些操作,我们就可以将这些迭代转换为矢量化操作。在
现在,为了得到一个完全矢量化的解决方案,我们需要更多的内存空间,特别是}是输入数据的形状。在
(N,N,M)
,其中{这里另一个值得注意的方面是,在每次迭代中,我们没有做太多的工作,因为我们对每一行执行操作,并且每一行只包含给定示例的
2
元素。因此,我们可以沿着M
运行一个循环,这样在每次迭代中,我们将计算prod
并累加。因此,对于给定的样本,它只是两个循环迭代。在脱离循环,我们将得到累加的
prod
,它只需要disc2
的求和作为最终输出。在下面是实现上述想法的一个实现-
^{pr2}$运行时测试
原始方法的循环部分的精简版本和作为方法的修改版本如下所示:
时间安排和验证-
关于
900x
加速!好吧,这应该是足够的激励,希望能尽可能地把事情矢量化。在相关问题 更多 >
编程相关推荐