计算lis中成对点积的Pythonic方法

2024-09-30 04:30:47 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个由元组的所有组合组成的列表,每个元素只能是-1或1。列表可以生成为:

N=2
list0 = [p for p in itertools.product([-1, 1], repeat=N)]

例如,如果元组有N=2个元素:

list0 = [(-1, -1), (-1, 1), (1, -1), (1, 1)]

因此,元组的总数是2^2=4。你知道吗

如果元组有N=3个元素:

list0 = [(-1, -1, -1), (-1, -1, 1), (-1, 1, -1), (-1, 1, 1), (1, -1, -1), (1, -1, 1), (1, 1, -1), (1, 1, 1)]

我关心的是:

现在我想得到列表中任意一对元组(包括元组本身的元组)之间的点积的所有结果。所以对于N=2会有6(pairs) + 4(itself) = 10 combinations;对于N=3会有28(pairs) + 8(itself) = 36 combinations.

对于小型N,我可以做如下操作:

for x in list0:
    for y in list0:
        print(np.dot(x,y)) 

但是,假设我已经有了list0,如果N很大,比如~50,那么计算所有点积可能性的最佳方法是什么?你知道吗


Tags: in元素列表forproduct元组repeatprint
2条回答

你可以坚持用numpy

import numpy as np
import random


vals = []
num_vecs = 3
dimension = 4
for n in range(num_vecs):
    val = []
    for _ in range(dimension):
        val.append(random.random())
    vals.append(val)

# make into numpy array
vals = np.stack(vals)
print(vals.shape == (num_vecs, dimension))

# multiply every vector with every other using broadcastin
every_with_every_mult = vals[:, None] * vals[None, :]
print(every_with_every_mult.shape == (num_vecs, num_vecs, dimension))

# sum the final dimension
every_with_every_dot = np.sum(every_with_every_mult, axis=every_with_every_mult.ndim - 1)
print(every_with_every_dot.shape == (num_vecs, num_vecs))

# check it works
for i in range(num_vecs):
    for j in range(num_vecs):
        assert every_with_every_dot[i,j] == np.sum(vals[i]*vals[j])

您可以使用np.dot本身:

import numpy as np

list0 = [(-1, -1, -1), (-1, -1, 1), (-1, 1, -1), (-1, 1, 1), (1, -1, -1), (1, -1, 1), (1, 1, -1), (1, 1, 1)]

# approach using np.dot
a = np.array(list0)
result = np.dot(a, a.T)

# brute force approach
brute = []
for x in list0:
    brute.append([np.dot(x, y) for y in list0])
brute = np.array(brute)

print((brute == result).all())

输出

True

你要问的是a与自身的矩阵乘法,从documentation

if both a and b are 2-D arrays, it is matrix multiplication,

请注意,最具pythonic解决方案是使用操作符@

import numpy as np

list0 = [(-1, -1, -1), (-1, -1, 1), (-1, 1, -1), (-1, 1, 1), (1, -1, -1), (1, -1, 1), (1, 1, -1), (1, 1, 1)]

# approach using np.dot
a = np.array(list0)
result = a @ a.T

# brute force approach
brute = []
for x in list0:
    brute.append([np.dot(x, y) for y in list0])
brute = np.array(brute)

print((brute == result).all())

输出

True

注意:代码是在Python3.5中运行的

相关问题 更多 >

    热门问题