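<p>I am trying to use <code>guvectorize</code> to vectorize a simple function that returns a tuple. As far as I can tell, the <code>numba</code> documentation does not include any working example of <code>guvectorize</code> in which the function returns a <code>tuple</code>.</p>
<p>Initially, what I wanted to do was:</p>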
<pre><code>z = (x+y, x-y)
</code></pre>
<p>I then changed it to the following, based on a Stack Overflow answer.</p>
^{pr2}$
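<p>However, I still get errors that are hard to decipher. What I want is to vectorize a function that takes several arrays of the same dimensions and returns an array of tuples with the same dimensions as the input arrays. For example, suppose the input arrays to the sample function are:</p>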
<pre><code>a = array([[4, 7, 9],
           [7, 1, 2]])
b = array([[5, 6, 6],
           [2, 5, 6]])
</code></pre>
<p>Then the output should be:</p>
<pre><code>c = array([[(9, -1), (13, 1), (15, 3)],
           [(9, 5), (6, -4), (8, -4)]], dtype=object)
</code></pre>
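<p>For reference, here is a minimal plain-NumPy sketch (no numba involved) of the result I am after. The loop over <code>np.ndindex</code> and the name <code>stacked</code> are illustrative assumptions only, not part of my actual code. The second form is an alternative layout that holds the same numbers in a regular numeric array with a trailing axis of length 2 instead of an object array of tuples.</p>

```python
import numpy as np

a = np.array([[4, 7, 9],
              [7, 1, 2]])
b = np.array([[5, 6, 6],
              [2, 5, 6]])

# Reference result: an object array holding one (sum, difference) tuple per element.
c = np.empty(a.shape, dtype=object)
for idx in np.ndindex(a.shape):
    c[idx] = (int(a[idx] + b[idx]), int(a[idx] - b[idx]))

# Alternative layout (assumption: a trailing axis is acceptable instead of tuples).
# stacked[i, j, 0] is the sum and stacked[i, j, 1] is the difference.
stacked = np.stack((a + b, a - b), axis=-1)

print(c)
print(stacked.shape)
```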
<p>My sample code and the resulting error are shown below:</p>
<pre><code>from numba import void, float64, UniTuple, guvectorize
@guvectorize(['void(float64[:], float64[:], UniTuple(float64[:], 2))'], '(n), (n) -> (n)')
def fun(x, y, z):
    z[:] = (x+y, x-y)
</code></pre>
<pre><code><ipython-input-24-6920fb0e2a76>:2: NumbaWarning:
Compilation is falling back to object mode WITHOUT looplifting enabled because Function "fun" failed type inference due to: Invalid use of Function(<built-in function setitem>) with argument(s) of type(s): (tuple(array(float64, 1d, A) x 2), slice<a:b>, tuple(array(float64, 1d, C) x 2))
* parameterized
In definition 0:
All templates rejected with literals.
In definition 1:
All templates rejected without literals.
In definition 2:
All templates rejected with literals.
In definition 3:
All templates rejected without literals.
In definition 4:
All templates rejected with literals.
In definition 5:
All templates rejected without literals.
In definition 6:
All templates rejected with literals.
In definition 7:
All templates rejected without literals.
In definition 8:
All templates rejected with literals.
In definition 9:
All templates rejected without literals.
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: typing of staticsetitem at <ipython-input-24-6920fb0e2a76> (4)
File "<ipython-input-24-6920fb0e2a76>", line 4:
def fun(x, y, z):
z[:] = (x+y, x-y)
^
@nb.guvectorize(['void(float64[:], float64[:], UniTuple(float64[:], 2))'], '(n), (n) -> (n)')
/home/user/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/compiler.py:742: NumbaWarning: Function "fun" was compiled in object mode without forceobj=True.
File "<ipython-input-24-6920fb0e2a76>", line 3:
@nb.guvectorize(['void(float64[:], float64[:], UniTuple(float64[:], 2))'], '(n), (n) -> (n)')
def fun(x, y, z):
^
self.func_ir.loc))
/home/user/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/compiler.py:751: NumbaDeprecationWarning:
Fall-back from the nopython compilation path to the object mode compilation path has been detected, this is deprecated behaviour.
For more information visit http://numba.pydata.org/numba-doc/latest/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit
File "<ipython-input-24-6920fb0e2a76>", line 3:
@nb.guvectorize(['void(float64[:], float64[:], UniTuple(float64[:], 2))'], '(n), (n) -> (n)')
def fun(x, y, z):
^
warnings.warn(errors.NumbaDeprecationWarning(msg, self.func_ir.loc))
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-24-6920fb0e2a76> in <module>
1 from numba.types import UniTuple
----> 2 @nb.guvectorize(['void(float64[:], float64[:], UniTuple(float64[:], 2))'], '(n), (n) -> (n)')
3 def fun(x, y, z):
4 z[:] = (x+y, x-y)
~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/npyufunc/decorators.py in wrap(func)
178 for fty in ftylist:
179 guvec.add(fty)
--> 180 return guvec.build_ufunc()
181
182 return wrap
~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/compiler_lock.py in _acquire_compile_lock(*args, **kwargs)
30 def _acquire_compile_lock(*args, **kwargs):
31 with self:
---> 32 return func(*args, **kwargs)
33 return _acquire_compile_lock
34
~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/npyufunc/ufuncbuilder.py in build_ufunc(self)
304 for sig in self._sigs:
305 cres = self._cres[sig]
--> 306 dtypenums, ptr, env = self.build(cres)
307 dtypelist.append(dtypenums)
308 ptrlist.append(utils.longint(ptr))
~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/npyufunc/ufuncbuilder.py in build(self, cres)
328 info = build_gufunc_wrapper(
329 self.py_func, cres, self.sin, self.sout,
--> 330 cache=self.cache, is_parfors=False,
331 )
332
~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/npyufunc/wrappers.py in build_gufunc_wrapper(py_func, cres, sin, sout, cache, is_parfors)
501 else _GufuncWrapper)
502 return wrapcls(
--> 503 py_func, cres, sin, sout, cache, is_parfors=is_parfors,
504 ).build()
505
~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/compiler_lock.py in _acquire_compile_lock(*args, **kwargs)
30 def _acquire_compile_lock(*args, **kwargs):
31 with self:
---> 32 return func(*args, **kwargs)
33 return _acquire_compile_lock
34
~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/npyufunc/wrappers.py in build(self)
454 def build(self):
455 wrapper_name = "__gufunc__." + self.fndesc.mangled_name
--> 456 wrapperlib = self._compile_wrapper(wrapper_name)
457 return _wrapper_info(
458 library=wrapperlib, env=self.env, name=wrapper_name,
~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/npyufunc/wrappers.py in _compile_wrapper(self, wrapper_name)
445 wrapperlib.enable_object_caching()
446 # Build wrapper
--> 447 self._build_wrapper(wrapperlib, wrapper_name)
448 # Cache
449 self.cache.save_overload(self.cres.signature, wrapperlib)
~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/npyufunc/wrappers.py in _build_wrapper(self, library, name)
399 self.sin + self.sout)):
400 ary = GUArrayArg(self.context, builder, arg_args,
--> 401 arg_steps, i, step_offset, typ, sym, sym_dim)
402 step_offset += len(sym)
403 arrays.append(ary)
~/.conda/envs/Py3DevLocal/lib/python3.7/site-packages/numba/npyufunc/wrappers.py in __init__(self, context, builder, args, steps, i, step_offset, typ, syms, sym_dim)
656 if syms:
657 raise TypeError("scalar type {0} given for non scalar "
--> 658 "argument #{1}".format(typ, i + 1))
659 self._loader = _ScalarArgLoader(dtype=typ, stride=core_step)
660
TypeError: scalar type tuple(array(float64, 1d, A) x 2) given for non scalar argument #3
</code></pre>