我正在用python创建一个应用程序来计算管道胶带的重叠(为分配器建模将产品应用到旋转的鼓上)。在
我有一个程序可以正常工作,但是非常慢。我正在寻找一个优化用来填充numpy数组的for
循环的解决方案。有人能帮我把下面的代码矢量化吗?在
import numpy as np
import matplotlib.pyplot as plt
# Some parameters
width = 264
bbddiam = 940
accuracy = 4 #2 points per pixel
drum = np.zeros(accuracy**2 * width * bbddiam).reshape((bbddiam * accuracy , width * accuracy))
# The "slow" function
def line_mask(drum, coef, intercept, upper=True, accuracy=accuracy):
"""Masks a half of the array"""
to_return = np.zeros(drum.shape)
for index, v in np.ndenumerate(to_return):
if upper == True:
if index[0] * coef + intercept > index[1]:
to_return[index] = 1
else:
if index[0] * coef + intercept <= index[1]:
to_return[index] = 1
return to_return
def get_band(drum, coef, intercept, bandwidth):
"""Calculate a ribbon path on the drum"""
to_return = np.zeros(drum.shape)
t1 = line_mask(drum, coef, intercept + bandwidth / 2, upper=True)
t2 = line_mask(drum, coef, intercept - bandwidth / 2, upper=False)
to_return = t1 + t2
return np.where(to_return == 2, 1, 0)
single_band = get_band(drum, 1 / 10, 130, bandwidth=15)
# Visualize the result !
plt.imshow(single_band)
plt.show()
Numba为我的代码创造了奇迹,将运行时从5.8秒缩短到86毫秒(特别感谢@Maarten vd Sande):
^{pr2}$使用numpy的更好的解决方案仍然是受欢迎的;-)
这里根本不需要任何循环。您实际上有两个不同的}来重写它,而这两个循环会被多次计算。在
line_mask
函数。这两个都不需要显式地循环,但是只要在if
和else
中使用一对for
循环重写它,而不是在for
循环中使用if
和{真正的numpythonic要做的是正确地将代码矢量化,以便在整个数组上操作而不产生任何循环。以下是
line_mask
的矢量化版本:将},从而得到{}的结果称为{a1},这是numpy中矢量化的主要内容。在
r
和c
的形状设置为(m, 1)
和{更新的
^{pr2}$line_mask
的结果是一个布尔掩码(顾名思义),而不是一个浮点数组。这使得它变得更小,并且有望完全绕过浮动操作。现在可以重写get_band
以使用屏蔽而不是加法:程序的其余部分应该保持不变,因为这些函数保留所有接口。在
如果您愿意,您可以用三行代码重写程序的大部分内容(仍然有些易读):
相关问题 更多 >
编程相关推荐