当包含JIT函数时,如何指定numba JIT类的字段?

2024-10-06 12:28:14 发布

您现在位置:Python中文网/ 问答频道 /正文

我想创建一个具有包含任何jitted函数的属性的numba jitclass

# simple jitted functions defined in another file
@njit
def my_function(x):
    x = x + 1
    return x

@njit
def another_function(x):
    x = x * 2
    return x

spec = [('attribute', ???),
        ('value', float32)]

@jitclass(spec)
class Myclass:
    def __init__(self, fun):
        self.attribute = fun

    def class_fun(self, x):
       value = self.attribute(x)
       return value

当我用numba.typeof(my_function)替换???并用an_object = Myclass(fun=my_function)创建Myclass的实例时,一切正常。但是,我只能传递确切的函数my_function。当我用另一个jitted函数创建Myclass对象时

new_object = Myclass(fun=another_function)
new_object.class_fun(2)

我得到以下错误:

Failed in nopython mode pipeline (step: nopython mode backend)
Cannot cast type(CPUDispatcher(<function another_function at 0x7fb1e3d8ce50>))
to type(CPUDispatcher(<function my_function at 0x7fb1e404de50>))

这对我来说很有意义,因为我为我的函数定义了字段类型。我不知道如何以一种通用的方式定义字段类型attribute,以便传递任何函数

有人知道在创建jitclass Myclass的对象时可以传递任何jitted函数的方法吗


Tags: 函数selfreturnobjectvaluemydefmyclass