在numba的jitclass中索引多维numpy数组

2024-06-24 12:59:02 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图将一个小的多维数组插入到numba jitclass中的一个较大的数组中。小数组是由索引列表定义的大数组的特定位置。在

下面的MWE显示了没有numba的问题-一切正常

import numpy as np

class NumbaClass(object):

    def __init__(self, n, m):
        self.A = np.zeros((n, m))

    # solution 1 using pure python
    def nonNumbaFunction1(self, idx, values):
        self.A[idx[:, None], idx] = values

    # solution 2 using pure python
    def nonNumbaFunction2(self, idx, values):
        self.A[np.ix_(idx, idx)] = values

if __name__ == "__main__":
    n = 6
    m = 8
    obj = NumbaClass(n, m)
    print(f'A =\n{obj.A}')

    idx = np.array([0, 2, 5])
    values = np.arange(len(idx)**2).reshape(len(idx), len(idx))
    print(f'values =\n{values}')

    obj.nonNumbaFunction1(idx, values)
    print(f'A =\n{obj.A}')

    obj.nonNumbaFunction2(idx, values)
    print(f'A =\n{obj.A}')

函数nonNumbaFunction1nonNumbaFunction2在numba类中都不工作。所以我目前的解决方案是这样的,在我看来并不好

^{pr2}$

所以我的问题是:

  • 在numba中有人知道这个索引的解决方案吗?还是有其他矢量化的解决方案?在
  • 有没有更快的解决方案nonNumbaFunction1?在

知道插入的数组很小(4x4到10x10)可能很有用,但是这个索引出现在嵌套循环中,所以它必须非常快!后来我需要一个类似的三维对象索引。在


Tags: selfobjlendefnp数组解决方案values
1条回答
网友
1楼 · 发布于 2024-06-24 12:59:02

由于numba对索引支持的限制,我不认为您能比自己编写for循环更好。要使其跨维度通用,可以使用^{}装饰器进行专门化。像这样:

def set_2d(target, values, idx):
    for i in range(values.shape[0]):
        for j in range(values.shape[1]):
            target[idx[i], idx[j]] = values[i, j]

def set_3d(target, values, idx):
    for i in range(values.shape[0]):
        for j in range(values.shape[1]):
            for k in range(values.shape[2]):
                target[idx[i], idx[j], idx[k]] = values[i, j, l]

@numba.generated_jit
def set_nd(target, values, idx):
    if target.ndim == 2:
        return set_2d
    elif target.ndim == 3:
        return set_3d

然后,可以在jitclass中使用它

^{pr2}$

相关问题 更多 >