将for循环矢量化以计算管道胶带重叠

2024-09-30 20:17:26 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在用python创建一个应用程序来计算管道胶带的重叠(为分配器建模将产品应用到旋转的鼓上)。在

我有一个程序可以正常工作,但是非常慢。我正在寻找一个优化用来填充numpy数组的for循环的解决方案。有人能帮我把下面的代码矢量化吗?在

import numpy as np
import matplotlib.pyplot as plt

# Some parameters
width = 264
bbddiam = 940
accuracy = 4 #2 points per pixel

drum = np.zeros(accuracy**2 * width * bbddiam).reshape((bbddiam * accuracy , width * accuracy))

# The "slow" function
def line_mask(drum, coef, intercept, upper=True, accuracy=accuracy):
    """Masks a half of the array"""
    to_return = np.zeros(drum.shape)
    for index, v in np.ndenumerate(to_return):
        if upper == True:
            if index[0] * coef + intercept > index[1]:
                to_return[index] = 1
        else:
            if index[0] * coef + intercept <= index[1]:
                to_return[index] = 1
    return to_return


def get_band(drum, coef, intercept, bandwidth):
    """Calculate a ribbon path on the drum"""
    to_return = np.zeros(drum.shape)
    t1 = line_mask(drum, coef, intercept + bandwidth / 2, upper=True)
    t2 = line_mask(drum, coef, intercept - bandwidth / 2, upper=False)
    to_return = t1 + t2
    return np.where(to_return == 2, 1, 0)

single_band = get_band(drum, 1 / 10, 130, bandwidth=15)

# Visualize the result !
plt.imshow(single_band)
plt.show()

Numba为我的代码创造了奇迹,将运行时从5.8秒缩短到86毫秒(特别感谢@Maarten vd Sande):

^{pr2}$

使用numpy的更好的解决方案仍然是受欢迎的;-)


Tags: tonumpybandindexreturnnppltwidth
1条回答
网友
1楼 · 发布于 2024-09-30 20:17:26

这里根本不需要任何循环。您实际上有两个不同的line_mask函数。这两个都不需要显式地循环,但是只要在ifelse中使用一对for循环重写它,而不是在for循环中使用if和{}来重写它,而这两个循环会被多次计算。在

真正的numpythonic要做的是正确地将代码矢量化,以便在整个数组上操作而不产生任何循环。以下是line_mask的矢量化版本:

def line_mask(drum, coef, intercept, upper=True, accuracy=accuracy):
    """Masks a half of the array"""
    r = np.arange(drum.shape[0]).reshape(-1, 1)
    c = np.arange(drum.shape[1]).reshape(1, -1)
    comp = c.__lt__ if upper else c.__ge__
    return comp(r * coef + intercept)

rc的形状设置为(m, 1)和{},从而得到{}的结果称为{a1},这是numpy中矢量化的主要内容。在

更新的line_mask的结果是一个布尔掩码(顾名思义),而不是一个浮点数组。这使得它变得更小,并且有望完全绕过浮动操作。现在可以重写get_band以使用屏蔽而不是加法:

^{pr2}$

程序的其余部分应该保持不变,因为这些函数保留所有接口。在

如果您愿意,您可以用三行代码重写程序的大部分内容(仍然有些易读):

coeff = 1/10
intercept = 130
bandwidth = 15

r, c = np.ogrid[:drum.shape[0], :drum.shape[1]]
check = r * coeff + intercept
single_band = ((check + bandwidth / 2 > c) & (check - bandwidth /  2 <= c))

相关问题 更多 >