二维元素的Numpy矩阵乘法

2024-09-28 20:52:47 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个anumpy ndarray3x3矩阵,看起来像这样

a =  ([[ uu, uv, uw],
       [ uv, vv, vw],
       [ uw, vw, ww]])

每个组件本身就是一个大小为(N,M)的二维数组,因此a矩阵具有(3,3,N,M)形状。你知道吗

我怎样才能以pythonic的方式执行a*a的矩阵乘法呢? 使用a@a抛出以下错误(对于N=1218和M=540):

ValueError: shapes (3,3,1218,540) and (3,3,1218,540) not aligned: 540 (dim 3) != 1218 (dim 2)

我希望能够像执行a的元素一样执行此操作,其中只有标量值a@a不会抛出与其形状相关的错误,因为这是一个简单的3x3矩阵乘法。你知道吗

谢谢。


Tags: 错误组件矩阵数组uvuw形状vw
1条回答
网友
1楼 · 发布于 2024-09-28 20:52:47

假设您希望沿最后两个轴对每个元素执行矩阵乘法,我们可以使用^{}-

np.einsum('ijkl,jmkl->imkl',a,a)

样品运行验证-

In [43]: np.random.seed(0)

In [44]: a = np.random.rand(3,3,4,5)

In [45]: a[:,:,0,0].dot(a[:,:,0,0])
Out[45]: 
array([[0.71750146, 1.17057872, 1.11135764],
       [0.62938365, 0.86437796, 0.74541383],
       [1.04636618, 1.62011127, 1.35483565]])

In [46]: np.einsum('ijkl,jmkl->imkl',a,a)[:,:,0,0]
Out[46]: 
array([[0.71750146, 1.17057872, 1.11135764],
       [0.62938365, 0.86437796, 0.74541383],
       [1.04636618, 1.62011127, 1.35483565]])

相关问题 更多 >