无效使用参数类型为的函数

2024-09-28 23:16:16 发布

您现在位置:Python中文网/ 问答频道 /正文

我使用的是Numba nonpython模式和一些NumPy函数

@njit
def invert(W, copy=True):
    '''
    Inverts elementwise the weights in an input connection matrix.
    In other words, change the from the matrix of internode strengths to the
    matrix of internode distances.

    If copy is not set, this function will *modify W in place.*

    Parameters
    ----------
    W : np.ndarray
        weighted connectivity matrix
    copy : bool

    Returns
    -------
    W : np.ndarray
        inverted connectivity matrix
    '''

    if copy:
        W = W.copy()
    E = np.where(W)
    W[E] = 1. / W[E]
    return W

在这个函数中,W是一个矩阵。但是我犯了以下错误。它可能与W[E] = 1. / W[E]线有关

File "/Users/xxx/anaconda3/lib/python3.7/site-packages/numba/dispatcher.py", line 317, in error_rewrite
    reraise(type(e), e, None)
  File "/Users/xxx/anaconda3/lib/python3.7/site-packages/numba/six.py", line 658, in reraise
    raise value.with_traceback(tb)
numba.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<built-in function getitem>) with argument(s) of type(s): (array(float64, 2d, A), tuple(array(int64, 1d, C) x 2))

那么,使用NumPy和Numba的正确方法是什么?我知道NumPy在矩阵计算方面做得很好。在这种情况下,NumPy的速度是否足够快,以至于Numba不再提供加速


Tags: ofthe函数innumpynpfunction矩阵
1条回答
网友
1楼 · 发布于 2024-09-28 23:16:16

正如FBruzzesi在评论中提到的,代码没有编译的原因是您使用了“奇特的索引”,因为W[E]中的Enp.where的输出,是数组的元组。(这就解释了一个稍微隐晦的错误消息:Numba不知道如何使用getitem,也就是说,当其中一个输入是元组时,它不知道如何在括号中查找某些内容。)

Numbaactually supports fancy indexing (also called "advanced indexing") on a single dimension,只是没有多个维度。在您的例子中,这允许进行一个简单的修改:首先使用ravel几乎无成本地将数组变成一维,然后应用转换,然后使用廉价的reshape返回

@njit
def invert2(W, copy=True):
    if copy:
        W = W.copy()
    Z = W.ravel()
    E = np.where(Z)
    Z[E] = 1. / Z[E]
    return Z.reshape(W.shape)

但这仍然比需要的慢,因为它通过不必要的中间数组传递计算,而不是在遇到非零值时立即修改数组。简单地做一个循环会更快:

@njit 
def invert3(W, copy=True): 
    if copy: 
        W = W.copy() 
    Z = W.ravel() 
    for i in range(len(Z)): 
        if Z[i] != 0: 
            Z[i] = 1/Z[i] 
    return Z.reshape(W.shape) 

无论W的维度如何,此代码都可以工作。如果我们知道W是二维的,那么我们可以直接在这两个维度上迭代,但是因为这两个维度的性能相似,所以我选择更一般的路径

在我的计算机上,假设一个300×300的数组W,其中大约一半的条目是0,其中invert是您没有编译的原始函数,则计时为:

In [80]: %timeit invert(W)                                                                   
2.67 ms ± 49.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [81]: %timeit invert2(W)                                                                  
519 µs ± 24.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In [82]: %timeit invert3(W)                                                                  
186 µs ± 11.1 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

因此,Numba给了我们相当大的加速(在它已经运行一次以消除编译时间之后),特别是在代码以Numba可以利用的高效循环方式重写之后

相关问题 更多 >