如何将Numba“@vectorize”ufunc与结构化Numpy数组一起使用?

2024-05-19 08:11:24 发布

您现在位置:Python中文网/ 问答频道 /正文

我无法让矢量化的ufunc运行。常规@njit工作正常,@vectorize documentation表明矢量化装饰器与njit相同。我现在运行的是Windows10,如果这有区别的话

演示程序如下。从下面的输出中,我们可以看到njit函数运行时没有事件,矢量化函数存在类型错误

import sys
import numpy
import numba

Structured = numpy.dtype([("a", numpy.int32), ("b", numpy.float64)])
numba_dtype = numba.from_dtype(Structured)

@numba.njit([numba.float64(numba_dtype)])
def jitted(x):
    x['b'] = 17.5
    return 18.

@numba.vectorize([numba.float64(numba_dtype)], target="cpu", nopython=True)
def vectorized(x):
    x['b'] = 17.5
    return 12.1

print('python version = ', sys.implementation.version)    
print('numpy version = ', numpy.__version__)
print('numba version = ', numba.__version__)
for struct in numpy.empty((3,), dtype=Structured):
    print(jitted(struct))

print(vectorized(numpy.empty((3,), dtype=Structured)))

输出是

python version = sys.version_info(major=3, minor=7, micro=1, releaselevel='final', serial=0)
numpy version = 1.17.3
numba version = 0.48.0
18.0
18.0
18.0
Traceback (most recent call last): File "scratch.py", line 49, in
print(vectorized(numpy.empty((3,), dtype=Structured))) TypeError: ufunc 'vectorized' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''


Tags: theimportnumpyversionsys矢量化emptyprint

热门问题