带比较算子的Numpy广播;循环迭代

2024-10-04 05:29:18 发布

您现在位置:Python中文网/ 问答频道 /正文

我用两种方法实现了循环迭代函数:

def Spin1(n, N) :           # n - current state, N - highest state
    value = n + 1   
    case1 = (value > N) 
    case2 = (value <= N)
    return case1 * 0 + case2 * value

def Spin2(n, N) :
    value = n + 1   
    if value > N :
        return 0
    else : return value

对于返回的结果,这些函数是相同的。然而,第二个功能不能对numpy数组进行广播。为了测试第一个函数,我运行以下命令:

^{pr2}$

神奇的是,它起作用了,真是太甜蜜了。所以我明白我想要什么:

[[0 0 0 0]
 [0 0 5 0]
 [0 0 0 0]]
[[1 1 1 1]
 [1 1 0 1]
 [1 1 1 1]]

现在使用第二个函数print Spin2(AR1,5),它失败,并出现以下错误:

if value > N
ValueError: The truth value of an array with more than one element is ambiguous. 
Use a.any() or a.all()

很清楚,因为if Array语句是无意义的。所以现在我只使用了第一种变体。但是当我看这些函数的时候,我有一种强烈的感觉,在第一个函数中,有更多的数学运算,所以我不会失去对它进行优化的希望。

问题:
1是否可以优化函数Spin1以减少操作如何在广播模式下使用函数Spin2(可能不会使代码太难看)?另外一个问题:用数组进行这种操作的最快方法是什么?
2。有没有一些标准的Python函数可以进行相同的计算(不能隐式地进行广播),它是如何正确地称为“循环增量”的?在


Tags: 方法函数returnifvaluedef数组current
3条回答

您的Spin1遵循面向数组语言(例如APL、MATLAB)中的一种成熟模式,用于将Spin2之类的函数“矢量化”。创建一个或多个布尔(或0/1数组)来表示数组元素可以采用的各种状态,然后通过乘法和求和构造输出。在

例如,为了避免被零除的问题,我使用了:

1/(x+(x==0))

另一种方法是使用布尔索引数组来选择应该更改的数组元素。在本例中,您希望返回value,但所选元素为“rollover”。在

^{pr2}$

在这种情况下,索引方法更简单,似乎更适合程序逻辑。可能更快,但我不能保证。最好把这两种方法都记在心里。在

我在这里提供了一些反馈作为答案,只是不想把问题搞砸。所以我对各种函数做了计时测试,结果发现在这种情况下,通过布尔掩码赋值是最快的变体(hpaulj的答案)。np.where慢了1.4倍,而{}慢了15倍。出于好奇,我想用循环来测试这一点,所以我设计了一个测试算法:

AR1 = numpy.zeros((rows, cols), dtype = numpy.uint32)
while d <= 100:
    Buf = numpy.zeros_like(AR1)
    r = 0
    c = 0
    while (r < rows) :
        while (c < cols) :
            temp = AR1[r, c] + 1
            if temp > 5 : 
                Buf[r, c] = 1
            else : Buf[r, c] = temp 
            c += 1
        r += 1
        c = 0
    AR1 = Buf
    d += 1

我不确定,但似乎所有上述功能的实现都非常简单。但它太慢了,几乎慢了300倍。我读过类似的问题,但还是不明白,为什么会这样?到底是什么导致了经济放缓。在这里,我特意设置了一个缓冲区,以避免对同一个元素执行读写函数,并且不进行内存清理。所以还有什么可以更简单的,我很困惑。不想打开一个新的问题,因为它已经被问了几次,所以可能有人会提出意见或有良好的链接澄清这一点?在

这里有一个numpy函数:np.where

In [590]: AR1
Out[590]: 
array([[0, 0, 0, 0],
       [0, 0, 5, 0],
       [0, 0, 0, 0]], dtype=uint32)

In [591]: np.where(AR1 >= 5, 0, 1)
Out[591]: 
array([[1, 1, 1, 1],
       [1, 1, 0, 1],
       [1, 1, 1, 1]])

所以,你可以定义:

^{pr2}$

NumPy还提供了一种将普通Python函数转换为ufuncs的方法:

def Spin2(n, N) :
    value = n + 1   
    if value > N :
        return 0
    else : return value

Spin2 = np.vectorize(Spin2)

以便您现在可以对数组调用Spin2

In [595]: Spin2(AR1, 5)
Out[595]: 
array([[1, 1, 1, 1],
       [1, 1, 0, 1],
       [1, 1, 1, 1]])

但是,np.矢量化主要提供语法糖分。对于每个数组元素仍然有一个Python函数调用,它使np.vectorizedufuncsno faster than equivalent code using Python for-loops。在

相关问题 更多 >