平均张量

2024-10-02 22:25:29 发布

您现在位置:Python中文网/ 问答频道 /正文

我还有一个问题和我的上一个问题(Python tensor product)有关。在那里我发现我的计算有错误。与名词短语tensordot我正在计算以下等式: enter image description here <;。>;应显示平均值。 在python代码中,它确实是这样的(ewp是一个向量,re是一个张量):

q1 = numpy.tensordot(re, ewp, axes=(1, 0))
q2 = numpy.tensordot(q1, ewp, axes=(1, 0))
serc = q2 ** 2

或者

serc = numpy.einsum('im, m -> i', numpy.einsum('ilm, l -> im',
numpy.einsum('iklm, k -> ilm', numpy.einsum('ijklm, j -> iklm',
numpy.einsum('ijk, ilm -> ijklm', re, re), ewp), ewp), ewp), ewp)

现在在这两个python代码中我都忽略了,所有的可能性都是成倍增加的。当然w_jw_k对于j=k来说并不是独立的。在这种情况下,只有j和k是相同的,我们得到< w_j*w_j*w_l*w_m> = <w_j>*<w_l>*<w_m>。对于j=k=l,我们得到:< w_j*w_j*w_j*w_m> = <w_j>*<w_m>。对于j=k=l=m< w_j*w_j*w_j*w_j> = <w_j>。只有当所有变量都不同时,独立性才是真的,我们得到:< w_i*w_j*w_l*w_m> = <w_i>*<w_j>*<w_l>*<w_m>。这就是代码对所有可能性的作用。我希望这能让我的问题可以理解。现在我的问题是如何在我的代码中表示这一点?你知道吗

编辑:我的想法是首先创建一个4dim。表示<w_j w_k w_l w_m>的张量:

wtensor = numpy.einsum('jkl, m -> jklm', numpy.einsum('jk, l -> jkl',
numpy.einsum('j, k -> jk', ewp, ewp), ewp), ewp)

然后我需要改变那些不独立的值。我想他们应该在对角线上?但我真的不太懂张量微积分,所以在这一点上我很挣扎。 在操纵w张量之后,我将通过执行以下操作得到结果:

serc = numpy.einsum('ijklm, jklm -> i', numpy.einsum('ijk, ilm ->
ijklm', re, re), wtensor)

Edit2:在另一篇文章中,我精确地问了如何操作4dim,使它适合这里。Divakar有一个非常好的解决方案,可以在这里看到:Fill a multidimensional array efficiently that have many if else statements

from itertools import product

n_dims = 4 # Number of dims
# Create 2D array of all possible combinations of X's as rows
idx = np.sort(np.array(list(product(np.arange(gn),
repeat=n_dims))),axis=1)
# Get all X's indexed values from ewp array
vals = ewp[idx]
# Set the duplicates along each row as 1s. With the np.prod coming up
next,
#these 1s would not affect the result, which is the expected pattern
here.
vals[:,1:][idx[:,1:] == idx[:,:-1]] = 1
# Perform product along each row and reshape into multi-dim array
out = vals.prod(1).reshape([gn]*n_dims)

我在这里得到的数组是wtensor,我现在可以在上面的代码中使用它:

serc = numpy.einsum('ijklm, jklm -> i', numpy.einsum('ijk, ilm ->
ijklm', re, re), wtensor)

这最终给了我想要的结果,基本上回答了问题。 尽管有一个问题。ewp的长度也定义了张量的大小,不应该大于6。否则代码将占用大量内存。我的目的是使用它,直到一个大小8,所以不幸的是,现在我的下一个问题。你知道吗


Tags: the代码renumpynpproductarrayidx