Vectorizing a function that returns a tuple using numba guvectorize

Posted 2024-10-03 11:22:28


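I am trying to use guvectorize to vectorize a simple function that returns a tuple. The numba documentation does not seem to contain any working guvectorize example in which the function returns a tuple.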

Originally, what I wanted to do was:

z = (x+y, x-y)

Then I changed it based on a Stack Overflow answer.

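However, I still get errors that are hard to decipher. What I want is to vectorize a function that takes several arrays of the same dimensions and returns an array of tuples with the same dimensions as the input arrays. For example, suppose the input arrays of the sample function are: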

a = array([[4, 7, 9],
           [7, 1, 2]])
b = array([[5, 6, 6],
           [2, 5, 6]])

Then the output should be:

c = array([[ (9, -1), (13, 1), (15, 3)],
           [ (9, 5),  (6, -4),  (8, -4)]], dtype=object)
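If the goal is literally the object array of tuples shown above, NumPy's `np.frompyfunc` can build it without numba. This is a minimal sketch, not a numba solution: the resulting ufunc calls a Python function once per element, so it runs at interpreted speed. The helper name `add_sub` is hypothetical.

```python
import numpy as np

# Input arrays from the question (integer dtype).
a = np.array([[4, 7, 9],
              [7, 1, 2]])
b = np.array([[5, 6, 6],
              [2, 5, 6]])

def add_sub(x, y):
    # Scalar function returning a tuple, as in the original idea z = (x+y, x-y).
    return (x + y, x - y)

# frompyfunc(func, nin=2, nout=1) builds an object-dtype ufunc; with nout=1
# each returned tuple is stored as a single object element of the result.
add_sub_ufunc = np.frompyfunc(add_sub, 2, 1)

c = add_sub_ufunc(a, b)
print(c)
```

Asking for `nout=2` instead would make the ufunc unpack each tuple into two separate object arrays, which is closer to what the answers below do with numba.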

My sample code and the resulting error are shown below:

from numba import void, float64, UniTuple, guvectorize
@guvectorize(['void(float64[:], float64[:], UniTuple(float64[:], 2))'], '(n), (n) -> (n)') 
def fun(x, y, z): 
    z[:] = (x+y, x-y)

<ipython-input-24-6920fb0e2a76>:2: NumbaWarning: 
Compilation is falling back to object mode WITHOUT looplifting enabled because Function "fun" failed type inference due to: Invalid use of Function(<built-in function setitem>) with argument(s) of type(s): (tuple(array(float64, 1d, A) x 2), slice<a:b>, tuple(array(float64, 1d, C) x 2))
 * parameterized
In definition 0:
    All templates rejected with literals.
In definition 1:
    All templates rejected without literals.
In definition 2:
    All templates rejected with literals.
In definition 3:
    All templates rejected without literals.
In definition 4:
    All templates rejected with literals.
In definition 5:
    All templates rejected without literals.
In definition 6:
    All templates rejected with literals.
In definition 7:
    All templates rejected without literals.
In definition 8:
    All templates rejected with literals.
In definition 9:
    All templates rejected without literals.
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: typing of staticsetitem at <ipython-input-24-6920fb0e2a76> (4)

File "<ipython-input-24-6920fb0e2a76>", line 4:
def fun(x, y, z):
    z[:] = (x+y, x-y)
    ^

  @nb.guvectorize(['void(float64[:], float64[:], UniTuple(float64[:], 2))'], '(n), (n) -> (n)')
/home/user/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/compiler.py:742: NumbaWarning: Function "fun" was compiled in object mode without forceobj=True.

File "<ipython-input-24-6920fb0e2a76>", line 3:
@nb.guvectorize(['void(float64[:], float64[:], UniTuple(float64[:], 2))'], '(n), (n) -> (n)')
def fun(x, y, z):
^

  self.func_ir.loc))
/home/user/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/compiler.py:751: NumbaDeprecationWarning: 
Fall-back from the nopython compilation path to the object mode compilation path has been detected, this is deprecated behaviour.

For more information visit http://numba.pydata.org/numba-doc/latest/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit

File "<ipython-input-24-6920fb0e2a76>", line 3:
@nb.guvectorize(['void(float64[:], float64[:], UniTuple(float64[:], 2))'], '(n), (n) -> (n)')
def fun(x, y, z):
^

  warnings.warn(errors.NumbaDeprecationWarning(msg, self.func_ir.loc))
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-24-6920fb0e2a76> in <module>
      1 from numba.types import UniTuple
----> 2 @nb.guvectorize(['void(float64[:], float64[:], UniTuple(float64[:], 2))'], '(n), (n) -> (n)')
      3 def fun(x, y, z):
      4     z[:] = (x+y, x-y)

~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/npyufunc/decorators.py in wrap(func)
    178         for fty in ftylist:
    179             guvec.add(fty)
--> 180         return guvec.build_ufunc()
    181 
    182     return wrap

~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/compiler_lock.py in _acquire_compile_lock(*args, **kwargs)
     30         def _acquire_compile_lock(*args, **kwargs):
     31             with self:
---> 32                 return func(*args, **kwargs)
     33         return _acquire_compile_lock
     34 

~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/npyufunc/ufuncbuilder.py in build_ufunc(self)
    304         for sig in self._sigs:
    305             cres = self._cres[sig]
--> 306             dtypenums, ptr, env = self.build(cres)
    307             dtypelist.append(dtypenums)
    308             ptrlist.append(utils.longint(ptr))

~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/npyufunc/ufuncbuilder.py in build(self, cres)
    328         info = build_gufunc_wrapper(
    329             self.py_func, cres, self.sin, self.sout,
--> 330             cache=self.cache, is_parfors=False,
    331         )
    332 

~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/npyufunc/wrappers.py in build_gufunc_wrapper(py_func, cres, sin, sout, cache, is_parfors)
    501                else _GufuncWrapper)
    502     return wrapcls(
--> 503         py_func, cres, sin, sout, cache, is_parfors=is_parfors,
    504     ).build()
    505 

~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/compiler_lock.py in _acquire_compile_lock(*args, **kwargs)
     30         def _acquire_compile_lock(*args, **kwargs):
     31             with self:
---> 32                 return func(*args, **kwargs)
     33         return _acquire_compile_lock
     34 

~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/npyufunc/wrappers.py in build(self)
    454     def build(self):
    455         wrapper_name = "__gufunc__." + self.fndesc.mangled_name
--> 456         wrapperlib = self._compile_wrapper(wrapper_name)
    457         return _wrapper_info(
    458             library=wrapperlib, env=self.env, name=wrapper_name,

~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/npyufunc/wrappers.py in _compile_wrapper(self, wrapper_name)
    445                 wrapperlib.enable_object_caching()
    446                 # Build wrapper
--> 447                 self._build_wrapper(wrapperlib, wrapper_name)
    448                 # Cache
    449                 self.cache.save_overload(self.cres.signature, wrapperlib)

~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/npyufunc/wrappers.py in _build_wrapper(self, library, name)
    399                                            self.sin + self.sout)):
    400             ary = GUArrayArg(self.context, builder, arg_args,
--> 401                              arg_steps, i, step_offset, typ, sym, sym_dim)
    402             step_offset += len(sym)
    403             arrays.append(ary)

~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/npyufunc/wrappers.py in __init__(self, context, builder, args, steps, i, step_offset, typ, syms, sym_dim)
    656             if syms:
    657                 raise TypeError("scalar type {0} given for non scalar "
--> 658                                 "argument #{1}".format(typ, i + 1))
    659             self._loader = _ScalarArgLoader(dtype=typ, stride=core_step)
    660 

TypeError: scalar type tuple(array(float64, 1d, A) x 2) given for non scalar argument #3
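Reading the last line of the traceback: the layout `(n), (n) -> (n)` says argument #3 is an array with a core dimension, but its declared type `UniTuple(float64[:], 2)` is not an array type, so the gufunc wrapper reports it as a "scalar type ... given for non scalar argument #3". In the generalized-ufunc model each part of the result is normally its own output array. The pure-NumPy sketch below illustrates that layout with `np.vectorize(..., signature=...)`; it only mimics the semantics (it is a Python-level loop), and `add_sub_core` is a hypothetical name.

```python
import numpy as np

def add_sub_core(x, y):
    # Works on one core slice of shape (n,) and returns two (n,) arrays,
    # i.e. one array per element of the intended tuple.
    return x + y, x - y

# Same core layout as the numba attempt, but with two outputs instead of
# a single tuple-typed output.
add_sub = np.vectorize(add_sub_core, signature='(n),(n)->(n),(n)')

a = np.array([[4., 7., 9.],
              [7., 1., 2.]])
b = np.array([[5., 6., 6.],
              [2., 5., 6.]])

plus, minus = add_sub(a, b)  # loops over the leading dimension of length 2
print(plus)
print(minus)
```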

2 answers

This seems to do what you want:

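from numba import guvectorize, void, float64

@guvectorize([void(float64[:], float64[:], float64[:], float64[:])], '(n), (n) -> (n), (n)')
def fun(x, y, addition, subtraction):
    # Each "tuple element" is a separate output array, filled in place.
    addition[:] = x + y
    subtraction[:] = x - y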


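Below is a Numba example that returns a tuple of two 2D NumPy arrays. For this particular case I think you could simply use NumPy addition and subtraction (if two separate arrays are acceptable), but here is a working Numba example. I apply the decorator as a function call, as shown below, because I find it convenient; switching back to the usual `@jit(...)` decorator syntax is fully equivalent.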
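As the answer notes, plain NumPy already vectorizes these operations. A minimal sketch, assuming only the two result arrays are needed; `numpy_function` is a hypothetical name mirroring `numba_function` below.

```python
import numpy as np

a = np.array([[4, 7, 9],
              [7, 1, 2]], dtype=float)
b = np.array([[5, 6, 6],
              [2, 5, 6]], dtype=float)

def numpy_function(a, b):
    # Element-wise sum and difference, returned as a tuple of two arrays.
    return a + b, a - b

p, m = numpy_function(a, b)
print(p)
print(m)
```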

import numpy as np

try:
    from numba import jit, prange
except ImportError:
    numba_opt = False
else:
    numba_opt = True

a = np.array([[4, 7, 9],
             [7, 1, 2]], dtype=float)
b = np.array([[5, 6, 6],
             [2, 5, 6]], dtype=float)

def numba_function(a: np.ndarray, b: np.ndarray):
    l0 = np.shape(a)[0]
    l1 = np.shape(a)[1]
    p = np.zeros_like(a)
    m = np.zeros_like(a)
    for i in range(l0):
        for j in range(l1):
            p[i, j] = a[i, j] + b[i, j]
            m[i, j] = a[i, j] - b[i, j]
    return(p, m)

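if numba_opt:
    # Compile eagerly for 2D float64 arrays; returns a tuple of two 2D arrays.
    fun_rec = jit(signature_or_function='UniTuple(float64[:,:],2)(float64[:,:],float64[:,:])',
                  nopython=True, parallel=False, cache=True, fastmath=True, nogil=True)(numba_function)
else:
    # Fall back to the plain Python function when numba is not installed.
    fun_rec = numba_function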


p, m = fun_rec(a, b)
print(p)
print(m)
