<p>我也在寻找解决办法。不幸的是,@Carbon的建议不起作用,因为函数<code>bar</code>的<code>numba.typeof</code>返回的类型与函数<code>baz</code>的类型不同,即使<code>bar</code>和<code>baz</code>的签名相同</p>
<p>例如:</p>
<pre class="lang-py prettyprint-override"><code>import numba
@numba.jit(
numba.int32(numba.int32),
nopython=True,
nogil=True,
)
def bar(a):
return 2 * a
@numba.jit(
numba.int32(numba.int32),
nopython=True,
nogil=True,
)
def baz(a):
return 3 * a
@numba.jit(
numba.int32(numba.typeof(bar), numba.int32),
nopython=True,
nogil=True,
)
def foo(fn, a):
return fn(a)
</code></pre>
<p><code>foo(bar, 2)</code>返回4</p>
<p><code>foo(baz, 2)</code>返回以下异常:</p>
<pre><code>Traceback (most recent call last):
File "test_numba.py", line 33, in <module>
print(foo(baz, 2))
File "<snip>\Python38\lib\site-packages\numba\core\dispatcher.py", line 656, in _explain_matching_error
raise TypeError(msg)
TypeError: No matching definition for argument type(s) type(CPUDispatcher(<function baz at 0x000001DFA8C2D1F0>)), int64
</code></pre>
<p>我发现的唯一解决办法是完全省略<code>foo</code>的函数签名,让numba来解决这个问题。我不知道这会给你的代码带来什么负面影响(如果有的话)</p>
<p>例如:</p>
<pre class="lang-py prettyprint-override"><code>import numba
@numba.jit(
numba.int32(numba.int32),
nopython=True,
nogil=True,
)
def bar(a):
return 2 * a
@numba.jit(
numba.int32(numba.int32),
nopython=True,
nogil=True,
)
def baz(a):
return 3 * a
@numba.jit(
nopython=True,
nogil=True,
)
def foo(fn, a):
return fn(a)
</code></pre>
<p><code>foo(bar, 2)</code>返回4</p>
<p><code>foo(baz, 2)</code>返回6</p>