Numba不能使用转置或处理3d阵列

2024-10-03 13:27:13 发布

您现在位置:Python中文网/ 问答频道 /正文

在处理矩阵中每列的成对距离时,我尝试使用numba来提高计算效率。简化代码显示如下

### toy example
XX = np.mat(np.random.normal(size = (10, 6)))

### not use njit
def entrymat(XX):
    nrow, ncol = XX.shape
    entrydist = np.zeros((ncol, nrow, nrow))
    for d in range(ncol):
        entrydist[d] = np.power(XX[:, d].T, 2) + np.power(XX[:, d], 2) - 2 * XX[:, d] * XX[:, d].T
    return entrydist
entrymat(XX)[0]
Out[235]: 
array([[0.00000000e+00, 2.58320218e+00, 1.55410224e-01, 1.78789065e-02,
        1.04846099e-01, 1.77814001e+00, 1.29929461e+00, 2.47598273e-01,
        1.32994371e+00, 9.80701480e-01],
       [2.58320218e+00, 0.00000000e+00, 1.47140125e+00, 2.17126797e+00,
        1.64720502e+00, 7.49473659e-02, 2.18433107e-01, 1.23130504e+00,
        7.62017353e+00, 6.74720397e+00],
       [1.55410224e-01, 1.47140125e+00, 0.00000000e+00, 6.78649425e-02,
        4.95919632e-03, 8.82187044e-01, 5.55986489e-01, 1.06856550e-02,
        2.39461044e+00, 1.91690883e+00],
       [1.78789065e-02, 2.17126797e+00, 6.78649425e-02, 0.00000000e+00,
        3.61332368e-02, 1.43941718e+00, 1.01234592e+00, 1.32408981e-01,
        1.65622455e+00, 1.26341143e+00],
       [1.04846099e-01, 1.64720502e+00, 4.95919632e-03, 3.61332368e-02,
        0.00000000e+00, 1.01943288e+00, 6.65964658e-01, 3.02040080e-02,
        2.18162154e+00, 1.72686723e+00],
       [1.77814001e+00, 7.49473659e-02, 8.82187044e-01, 1.43941718e+00,
        1.01943288e+00, 0.00000000e+00, 3.74821648e-02, 6.98689833e-01,
        6.18368193e+00, 5.39992046e+00],
       [1.29929461e+00, 2.18433107e-01, 5.55986489e-01, 1.01234592e+00,
        6.65964658e-01, 3.74821648e-02, 0.00000000e+00, 4.12515343e-01,
        5.25829799e+00, 4.53762330e+00],
       [2.47598273e-01, 1.23130504e+00, 1.06856550e-02, 1.32408981e-01,
        3.02040080e-02, 6.98689833e-01, 4.12515343e-01, 0.00000000e+00,
        2.72522097e+00, 2.21383512e+00],
       [1.32994371e+00, 7.62017353e+00, 2.39461044e+00, 1.65622455e+00,
        2.18162154e+00, 6.18368193e+00, 5.25829799e+00, 2.72522097e+00,
        0.00000000e+00, 2.65455727e-02],
       [9.80701480e-01, 6.74720397e+00, 1.91690883e+00, 1.26341143e+00,
        1.72686723e+00, 5.39992046e+00, 4.53762330e+00, 2.21383512e+00,
        2.65455727e-02, 0.00000000e+00]])
### use @njit
@njit
def entrymat(XX):
    nrow, ncol = XX.shape
    entrydist = np.zeros((ncol, nrow, nrow))
    for d in range(ncol):
        entrydist[d] = np.power(XX[:, d].T, 2) + np.power(XX[:, d], 2) - 2 * XX[:, d] * XX[:, d].T
    return entrydist
entrymat(XX)[0]
Out[237]: 
array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

如果修改功能并检查细节

@njit
def entrymat(XX):
    nrow, ncol = XX.shape
    entrydist = np.zeros((ncol, nrow, nrow))
    for d in range(ncol):
        entrydist[d] = np.power(XX[:, d].T, 2) + np.power(XX[:, d], 2) 
    return entrydist
entrymat(XX)[0]
Out[239]: 
array([[0.08823641, 6.60499623, 0.73027086, 0.2363355 , 0.56997629,
        4.76486168, 3.64451071, 1.00149689, 1.77920915, 1.21761256],
       [0.08823641, 6.60499623, 0.73027086, 0.2363355 , 0.56997629,
        4.76486168, 3.64451071, 1.00149689, 1.77920915, 1.21761256],
       [0.08823641, 6.60499623, 0.73027086, 0.2363355 , 0.56997629,
        4.76486168, 3.64451071, 1.00149689, 1.77920915, 1.21761256],
       [0.08823641, 6.60499623, 0.73027086, 0.2363355 , 0.56997629,
        4.76486168, 3.64451071, 1.00149689, 1.77920915, 1.21761256],
       [0.08823641, 6.60499623, 0.73027086, 0.2363355 , 0.56997629,
        4.76486168, 3.64451071, 1.00149689, 1.77920915, 1.21761256],
       [0.08823641, 6.60499623, 0.73027086, 0.2363355 , 0.56997629,
        4.76486168, 3.64451071, 1.00149689, 1.77920915, 1.21761256],
       [0.08823641, 6.60499623, 0.73027086, 0.2363355 , 0.56997629,
        4.76486168, 3.64451071, 1.00149689, 1.77920915, 1.21761256],
       [0.08823641, 6.60499623, 0.73027086, 0.2363355 , 0.56997629,
        4.76486168, 3.64451071, 1.00149689, 1.77920915, 1.21761256],
       [0.08823641, 6.60499623, 0.73027086, 0.2363355 , 0.56997629,
        4.76486168, 3.64451071, 1.00149689, 1.77920915, 1.21761256],
       [0.08823641, 6.60499623, 0.73027086, 0.2363355 , 0.56997629,
        4.76486168, 3.64451071, 1.00149689, 1.77920915, 1.21761256]])

结果并不像预期的那样,因为它应该是对称的。但是,显示的显示仅在同一方向上添加2列或2行。我还将矩阵XX更改为np.array,但仍然失败。有人能帮我解决这个问题吗?多谢各位


Tags: infordefnpzerosrangearraypower