如何在numpy中有效地展开因子张量?

2024-09-30 01:26:20 发布

您现在位置:Python中文网/ 问答频道 /正文

我把一个三维张量分解为三个二维矩阵,就像本文中的方程22:http://www.iro.umontreal.ca/~memisevr/pubs/pami_relational.pdf

我的问题是,如果我想显式地计算张量,有没有比numpy更好的方法?你知道吗

W = np.zeros((100,100,100))
for i in range(100):
    for j in range(100):
        for k in range(100):
            W[i,j,k] = np.sum([wxf[i,f]*wyf[j,f]*wzf[k,f] for f in range(100)]) 

Tags: inhttpforwwwnprange矩阵ca
2条回答

您的示例使使用np.einsum()提出解决方案变得非常简单:

W = np.einsum('ij,jf,kf->ijk', wxf, wyf, wzf)

我倾向于用^{}来写这些东西,因为它通常是最容易写的:

def fast(wxf, wyf, wzf):
    return np.einsum('if,jf,kf->ijk', wxf, wyf, wzf)

def slow(wxf, wyf, wzf):
    N = len(wxf)
    W = np.zeros((N, N, N))
    for i in range(N):
        for j in range(N):
            for k in range(N):
                W[i,j,k] = np.sum([wxf[i,f]*wyf[j,f]*wzf[k,f] for f in range(N)]) 
    return W

def gen_ws(N):
    wxf = np.random.random((N,N))
    wyf = np.random.random((N,N))
    wzf = np.random.random((N,N))
    return wxf, wyf, wzf

给予

>>> ws = gen_ws(25)
>>> via_slow = slow(*ws)
>>> via_fast = fast(*ws)
>>> np.allclose(via_slow, via_fast)
True

以及

>>> ws = gen_ws(100)
>>> %timeit fast(*ws)
10 loops, best of 3: 91.6 ms per loop

相关问题 更多 >

    热门问题