擅长:python、mysql、java
<p>我可能会计算<em>M=v<sup>T</sup>v</em>,然后展平该矩阵的较低或较高三角形部分</p>
<pre><code>def pairwise_products(v: np.ndarray):
assert len(v.shape) == 1
n = v.shape[0]
m = v.reshape(n, 1) @ v.reshape(1, n)
return m[np.tril_indices_from(m)].ravel()
</code></pre>
<p>我还想提到<a href="http://numba.pydata.org/" rel="nofollow noreferrer">^{<cd1>}</a>,这将使您的“幼稚”方法很可能比此方法更快</p>
<pre><code>import numba
@numba.njit
def pairwise_products_numba(vec: np.ndarray):
k, size = 0, vec.size
output = np.empty(size * (size + 1) // 2)
for i in range(size):
for j in range(i, size):
output[k] = vec[i] * vec[j]
k += 1
return output
</code></pre>
<p>仅测试上述<code>pairwise_products(np.arange(5000))</code>需要约0.3秒,而numba版本需要约0.05秒(忽略用于及时编译函数的第一次运行)</p>