将numpy.bmat与numb一起使用

2024-09-19 22:20:36 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试在我的numba优化python程序中使用np.bmat。为此,我必须手动定义jitted函数bmat,因为不支持numpy的本机函数:

@njit
def _bmat_2d(matrices):
    arr_rows = []
    for row in matrices:
        arr_rows.append(np.concatenate(row, axis=-1))
    return np.array(np.concatenate(arr_rows, axis=0))

(此代码或多或少是numpy代码的简化副本)

但是:

  1. numba只接受np.concatenate[1]输入中的元组
  2. numba不擅长将任意列表转换为元组[2]

你对此有什么想法吗

参考文献:


Tags: 函数代码httpsnumpygithubnprowsrow
1条回答
网友
1楼 · 发布于 2024-09-19 22:20:36

你认为下列方法行得通吗

import numpy as np
import numba as nb

@nb.njit
def _bmat_2d(m):
    out = np.hstack(m[0])
    for row in m[1:]:
        x = np.hstack(row)
        out = np.vstack((out, x))

    return out

A = np.random.randint(10, size=(3,2))
B = np.random.randint(10, size=(3,1))
C = np.random.randint(10, size=(3,3))
D = np.random.randint(10, size=(4,6))

a = np.bmat(((A, B, C), (D,)))
b = _bmat_2d(((A, B, C), (D,)))

print(np.allclose((a, b))  # True

请注意,您必须传入一个元组的元组,而不是列表列表,否则您将得到一个“反射列表”错误,因为当前版本中的Numba无法处理列表列表

相关问题 更多 >