为什么这段代码不能用Numba编译?

2024-09-30 06:33:14 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个示例代码来说明我的问题。如果您运行:

import numpy as np
from numba import jit


@jit(nopython=True)
def test():
    arr = np.array([[[11, 12, 13], [11, 12, 13]], [[21, 22, 23], [21, 22, 23]]])

    arr2 = arr[:, 0, :]

    arr3 = arr2.argsort()

    print(arr3)

test()

它将在以下情况下失败:

numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of BoundFunction(array.argsort for array(int64, 2d, A)) with parameters ()
During: resolving callee type: BoundFunction(array.argsort for array(int64, 2d, A))
During: typing of call at /home/stark/Work/mmr6/test.py (41)


File "test.py", line 41:
def test():
    <source elided>

    arr3 = arr2.argsort()
    ^

argsort应在最后一个轴上进行argsort。基本上,它应该给我:

>>>
[[0 1 2]
 [0 1 2]]

我认为复制arr2数组(使用copy())可以解决这个问题,因为它会使数组在内存中连续(而不是视图),但是它失败了,因为消息中的arr2类型现在是array(int64, 2d, C),与预期的一样

为什么它会失败?我如何修复它


Tags: oftestimportdefnparrayjitarr
1条回答
网友
1楼 · 发布于 2024-09-30 06:33:14

不幸的是,这是目前已知的麻木的一个限制。见this issue。到目前为止,仅支持1D阵列。但是,在您的案例中有一个简单的解决方法:

import numpy as np
from numba import jit


@jit(nopython=True)
def test():
    arr = np.array([[[11, 12, 13], [11, 12, 13]], [[21, 22, 23], [21, 22, 23]]])

    arr2 = arr[:, 0, :]

    arr3 = np.empty(arr2.shape, dtype=arr2.dtype)
    for i in range(arr2.shape[0]):
        arr3[i] = arr2[i, :].argsort()

    print(arr3)

test()

请注意,即使已实现,也不会更快。见this issue。事实上,对于任何给定的Numpy原语,Numba都没有理由更快。但是,您可以使用Numba手动编写自己版本的Numpy原语,有时由于算法专门化、并行性或数学优化(如快速数学),速度会有所提高。当您想要执行Numpy中尚未/直接提供的有效操作时,Numba通常非常有用,并且可以使用循环轻松实现此操作

实际上,您可以使用Numba的prange和JIT参数parallel=True来加快计算速度,前提是argsort尚未并行运行(假设它应该是顺序的)。这应该比大型阵列上的Numpy实现(不应该按顺序运行)快一点(在小型阵列上,生成多个线程的成本可能比实际计算要高)

相关问题 更多 >

    热门问题