使用np.min.最小值在numba函数中使用列表输入

2024-10-01 07:28:30 发布

您现在位置:Python中文网/ 问答频道 /正文

在这里使用np.min有什么问题?为什么numba不喜欢在那个函数中使用列表,有没有其他方法可以让np.min工作?在

from numba import njit
import numpy as np

@njit
def availarray(length):
    out=np.ones(14)
    if length>0:
        out[0:np.min([int(length),14])]=0
    return out

availarray(3)

该函数与min配合使用很好,但是np.min应该更快。。。在


Tags: 方法函数fromimportnumpy列表defas
2条回答

问题是np.min的numba版本需要一个array作为输入。在

from numba import njit
import numpy as np

@njit
def test_numba_version_of_numpy_min(inp):
    return np.min(inp)

>>> test_numba_version_of_numpy_min(np.array([1, 2]))  # works
1

>>> test_numba_version_of_numpy_min([1, 2]) # doesn't work
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<function amin at 0x000001B5DBDEE598>) with argument(s) of type(s): (reflected list(int64))
 * parameterized
In definition 0:
    All templates rejected with literals.
In definition 1:
    All templates rejected without literals.
This error is usually caused by passing an argument of a type that is unsupported by the named function.

更好的解决方案是只使用numba版的Pythonmin

^{pr2}$

因为np.min和{}实际上都是这些函数的Numba版本(至少在njitted函数中)min在这种情况下也应该快得多。但是这不太可能引起注意,因为数组的分配和将一些元素设置为零将是这里主要的运行时贡献者。在

请注意,这里甚至不需要min调用,因为即使使用更大的停止索引,切片也会隐式地停止在数组的末尾:

from numba import njit
import numpy as np

@njit
def availarray(length):
    out = np.ones(14)
    if length > 0:
        out[0:length] = 0
    return out

要使您的代码与numba一起工作,您必须将np.min应用于NumPy数组,这意味着您必须将列表[int(length),14]转换为NumPy数组,如下所示

from numba import njit
import numpy as np

@njit
def availarray(length):
    out=np.ones(14)
    if length>0:
        out[0:np.min(np.array([int(length),14]))]=0   
    return out

availarray(3)
# array([0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

相关问题 更多 >