该矩阵计算是否可以在没有中间3D矩阵的情况下实现或近似?

2024-10-01 07:47:00 发布

您现在位置:Python中文网/ 问答频道 /正文

给定一个NxN矩阵W,我想计算一个NxN矩阵C,由这个链接中的方程给出:https://i.stack.imgur.com/dY7rY.png,或者在LaTeX中

$$C_{ij} = \max_k \bigg\{ \sum_l \bigg( W_{ik}W_{kl}W_{lj} - W_{ik}W_{kj} \bigg)  \bigg\}.$$

我曾尝试在PyTorch中实现这一点,但我要么通过构建中间NxN 3D矩阵(对于较大的N,这会导致我的GPU内存不足)遇到内存问题,要么使用了速度非常慢的for循环。我想不出怎样才能绕过这些。如果没有这样一个大的中间矩阵,我该如何实现这个计算,或者它的近似值

任何语言的建议、伪代码或Python/Numpy/PyTorch的任何实现都将不胜感激


Tags: httpscompngstack链接矩阵pytorchik
2条回答

该公式可简化为:

C_ij = max_k ( W_ik M_kj)

在哪里

M = W * W - N * W

N矩阵的大小WW * W通常的乘积

然后,在上面的公式中,对于每一个ij都有一个独立的最大值需要计算。如果不知道W的进一步性质,通常不可能进一步简化问题。因此,在计算矩阵M之后,可以在ij上进行循环,并计算最大值

使用Numba的第一个解决方案(您可以使用Cython或普通C做同样的事情)是使用简单的循环来描述问题

import numpy as np
import numba as nb

@nb.njit(fastmath=True,parallel=True)
def calc_1(W):
    C=np.empty_like(W)
    N=W.shape[0]

    for i in nb.prange(N):
        TMP=np.empty(N,dtype=W.dtype)
        for j in range(N):
            for k in range(N):
                acc=0
                for l in range(N):
                    acc+=W[i,k]*W[k,l]*W[l,j]-W[i,k]*W[k,j]
                TMP[k]=acc
            C[i,j]=np.max(TMP)
    return C

Francesco提供了一种简化方法,该方法可以更好地扩展较大的阵列大小。这导致了下面的内容,我还优化了一个小的临时数组

@nb.njit(fastmath=True,parallel=True)
def calc_2(W):
    C=np.empty_like(W)
    N=W.shape[0]
    M = np.dot(W,W) - N * W

    for i in nb.prange(N):
        for j in range(N):
            val=W[i,0]*M[0,j]
            for k in range(1,N):
                TMP=W[i,k]*M[k,j]
                if TMP>val:
                    val=TMP
            C[i,j]=val
    return C

这可以通过部分循环展开和优化阵列访问来进一步优化。有些编译器可能会自动执行此操作

@nb.njit(fastmath=True,parallel=True)
def calc_3(W):
    C=np.empty_like(W)
    N=W.shape[0]
    W=np.ascontiguousarray(W)
    M = np.dot(W.T,W.T) - W.shape[0] * W.T

    for i in nb.prange(N//4):
        for j in range(N):
            val_1=W[i*4+0,0]*M[j,0]
            val_2=W[i*4+1,0]*M[j,0]
            val_3=W[i*4+2,0]*M[j,0]
            val_4=W[i*4+3,0]*M[j,0]
            for k in range(1,N):
                TMP_1=W[i*4+0,k]*M[j,k]
                TMP_2=W[i*4+1,k]*M[j,k]
                TMP_3=W[i*4+2,k]*M[j,k]
                TMP_4=W[i*4+3,k]*M[j,k]
                if TMP_1>val_1:
                    val_1=TMP_1
                if TMP_2>val_2:
                    val_2=TMP_2
                if TMP_3>val_3:
                    val_3=TMP_3
                if TMP_4>val_4:
                    val_4=TMP_4

            C[i*4+0,j]=val_1
            C[i*4+1,j]=val_2
            C[i*4+2,j]=val_3
            C[i*4+3,j]=val_4

    #Remainder
    for i in range(N//4*4,N):
        for j in range(N):
            val=W[i,0]*M[j,0]
            for k in range(1,N):
                TMP=W[i,k]*M[j,k]
                if TMP>val:
                    val=TMP
            C[i,j]=val
    return C

计时

W=np.random.rand(100,100)
%timeit calc_1(W)
#16.8 ms ± 131 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit calc_2(W)
#449 µs ± 25.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit calc_3(W)
#259 µs ± 47.4 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

W=np.random.rand(2000,2000)
#Temporary array would be 64GB in this case
%timeit calc_2(W)
#5.37 s ± 174 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit calc_3(W)
#596 ms ± 30.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

相关问题 更多 >