在处理矩阵中每列的成对距离时,我尝试使用numba来提高计算效率。简化代码显示如下
### toy example
XX = np.mat(np.random.normal(size = (10, 6)))
### not use njit
def entrymat(XX):
nrow, ncol = XX.shape
entrydist = np.zeros((ncol, nrow, nrow))
for d in range(ncol):
entrydist[d] = np.power(XX[:, d].T, 2) + np.power(XX[:, d], 2) - 2 * XX[:, d] * XX[:, d].T
return entrydist
entrymat(XX)[0]
Out[235]:
array([[0.00000000e+00, 2.58320218e+00, 1.55410224e-01, 1.78789065e-02,
1.04846099e-01, 1.77814001e+00, 1.29929461e+00, 2.47598273e-01,
1.32994371e+00, 9.80701480e-01],
[2.58320218e+00, 0.00000000e+00, 1.47140125e+00, 2.17126797e+00,
1.64720502e+00, 7.49473659e-02, 2.18433107e-01, 1.23130504e+00,
7.62017353e+00, 6.74720397e+00],
[1.55410224e-01, 1.47140125e+00, 0.00000000e+00, 6.78649425e-02,
4.95919632e-03, 8.82187044e-01, 5.55986489e-01, 1.06856550e-02,
2.39461044e+00, 1.91690883e+00],
[1.78789065e-02, 2.17126797e+00, 6.78649425e-02, 0.00000000e+00,
3.61332368e-02, 1.43941718e+00, 1.01234592e+00, 1.32408981e-01,
1.65622455e+00, 1.26341143e+00],
[1.04846099e-01, 1.64720502e+00, 4.95919632e-03, 3.61332368e-02,
0.00000000e+00, 1.01943288e+00, 6.65964658e-01, 3.02040080e-02,
2.18162154e+00, 1.72686723e+00],
[1.77814001e+00, 7.49473659e-02, 8.82187044e-01, 1.43941718e+00,
1.01943288e+00, 0.00000000e+00, 3.74821648e-02, 6.98689833e-01,
6.18368193e+00, 5.39992046e+00],
[1.29929461e+00, 2.18433107e-01, 5.55986489e-01, 1.01234592e+00,
6.65964658e-01, 3.74821648e-02, 0.00000000e+00, 4.12515343e-01,
5.25829799e+00, 4.53762330e+00],
[2.47598273e-01, 1.23130504e+00, 1.06856550e-02, 1.32408981e-01,
3.02040080e-02, 6.98689833e-01, 4.12515343e-01, 0.00000000e+00,
2.72522097e+00, 2.21383512e+00],
[1.32994371e+00, 7.62017353e+00, 2.39461044e+00, 1.65622455e+00,
2.18162154e+00, 6.18368193e+00, 5.25829799e+00, 2.72522097e+00,
0.00000000e+00, 2.65455727e-02],
[9.80701480e-01, 6.74720397e+00, 1.91690883e+00, 1.26341143e+00,
1.72686723e+00, 5.39992046e+00, 4.53762330e+00, 2.21383512e+00,
2.65455727e-02, 0.00000000e+00]])
### use @njit
@njit
def entrymat(XX):
nrow, ncol = XX.shape
entrydist = np.zeros((ncol, nrow, nrow))
for d in range(ncol):
entrydist[d] = np.power(XX[:, d].T, 2) + np.power(XX[:, d], 2) - 2 * XX[:, d] * XX[:, d].T
return entrydist
entrymat(XX)[0]
Out[237]:
array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
如果修改功能并检查细节
@njit
def entrymat(XX):
nrow, ncol = XX.shape
entrydist = np.zeros((ncol, nrow, nrow))
for d in range(ncol):
entrydist[d] = np.power(XX[:, d].T, 2) + np.power(XX[:, d], 2)
return entrydist
entrymat(XX)[0]
Out[239]:
array([[0.08823641, 6.60499623, 0.73027086, 0.2363355 , 0.56997629,
4.76486168, 3.64451071, 1.00149689, 1.77920915, 1.21761256],
[0.08823641, 6.60499623, 0.73027086, 0.2363355 , 0.56997629,
4.76486168, 3.64451071, 1.00149689, 1.77920915, 1.21761256],
[0.08823641, 6.60499623, 0.73027086, 0.2363355 , 0.56997629,
4.76486168, 3.64451071, 1.00149689, 1.77920915, 1.21761256],
[0.08823641, 6.60499623, 0.73027086, 0.2363355 , 0.56997629,
4.76486168, 3.64451071, 1.00149689, 1.77920915, 1.21761256],
[0.08823641, 6.60499623, 0.73027086, 0.2363355 , 0.56997629,
4.76486168, 3.64451071, 1.00149689, 1.77920915, 1.21761256],
[0.08823641, 6.60499623, 0.73027086, 0.2363355 , 0.56997629,
4.76486168, 3.64451071, 1.00149689, 1.77920915, 1.21761256],
[0.08823641, 6.60499623, 0.73027086, 0.2363355 , 0.56997629,
4.76486168, 3.64451071, 1.00149689, 1.77920915, 1.21761256],
[0.08823641, 6.60499623, 0.73027086, 0.2363355 , 0.56997629,
4.76486168, 3.64451071, 1.00149689, 1.77920915, 1.21761256],
[0.08823641, 6.60499623, 0.73027086, 0.2363355 , 0.56997629,
4.76486168, 3.64451071, 1.00149689, 1.77920915, 1.21761256],
[0.08823641, 6.60499623, 0.73027086, 0.2363355 , 0.56997629,
4.76486168, 3.64451071, 1.00149689, 1.77920915, 1.21761256]])
结果并不像预期的那样,因为它应该是对称的。但是,显示的显示仅在同一方向上添加2列或2行。我还将矩阵XX更改为np.array,但仍然失败。有人能帮我解决这个问题吗?多谢各位
目前没有回答
相关问题 更多 >
编程相关推荐