我在2000万个不同的参数组合上运行下面的函数compare
,其中sample
是由100个1和0组成的一维数组
compare
将其他数组与sample
一起使用,并使用它们执行一些点积,对这些点积求幂,然后将它们相互比较。其他数组保持不变。你知道吗
在我的笔记本电脑上,运行2000万个组合大约需要一个小时。你知道吗
我在想办法让它走得更快。我对改进下面的代码和使用像dask这样利用并行处理的库持开放态度。你知道吗
备注:
compare
中每一行上的注释显示了该行在我的机器上花费的时间的一个非常粗略的估计。它们是%%timeit在函数外部自己联机的结果。你知道吗compare
的输入实际上不是随机生成的def compare(sample, competition_exp_dot, preferences): # 140 µs
sample_exp_dot = np.exp(preferences @ sample) #30.3 µs
all_competitors = np.append(sample_exp_dot.reshape(-1, 1), competition_exp_dot, 1) # 5 µs
all_results = all_products/all_competitors.sum(axis=1)[:,None] #27.4 µs
return all_results.mean(axis=0) #20.6 µs
#these inputs to the above function stay the same
preferences = np.random.random((1000,100))
competition = np.array([np.random.randint(0,2,100), np.random.randint(0,2,100)])
competition_exp_dot = np.exp(preferences @ competition.T)
# the function is run with 20,000,000 variations of sample
population = np.random.randint(0,2,(20000000,100))
result = [share_calc(sample, competition_exp_dot, preferences) for sample in population]
我按照洛克林先生的建议实现了numba。结果在我的机器上快了4倍。你知道吗
修改的Numba版本
可比Numpy版本
时间比较
设置:
注意事项
Numba不支持可选参数,比如axis fornp.平均值并返回一个打字机错误。在我的numba代码中,我使用了
np_mean
的callbelow版本。你知道吗记入joelrich
您可以考虑以下几点:
有许多方法可以加速这样的简单数组编程代码:
你也可以做任何上述的混合。你知道吗
相关问题 更多 >
编程相关推荐