<p><strong>Approach #1</strong></p>
<p>The simplest one, using <a href="https://numpy.org/doc/stable/reference/generated/numpy.triu_indices.html" rel="nofollow noreferrer"><code>np.triu_indices</code></a> -</p>
<pre><code>In [45]: a
Out[45]:
array([[1, 2, 3, 4],
       [4, 5, 6, 7],
       [7, 8, 9, 1]])

In [46]: r,c = np.triu_indices(a.shape[1],1)

In [47]: a[:,c]*a[:,r]
Out[47]:
array([[ 2,  3,  4,  6,  8, 12],
       [20, 24, 28, 30, 35, 42],
       [56, 63,  7, 72,  8,  9]])
</code></pre>
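<p>To see which column pairs end up where, here is a small sketch using the same sample array. It checks that <code>np.triu_indices(n, 1)</code> lists the pairs <code>(i, j)</code> with <code>i &lt; j</code> in the same order as <code>itertools.combinations</code>; the names <code>pairs</code> and <code>combos</code> are only for illustration.</p>

```python
import itertools
import numpy as np

a = np.array([[1, 2, 3, 4],
              [4, 5, 6, 7],
              [7, 8, 9, 1]])

n = a.shape[1]
# Row and column indices of the strictly upper triangle (diagonal excluded)
r, c = np.triu_indices(n, 1)

# Index pairs in the order they appear along the output columns
pairs = list(zip(r.tolist(), c.tolist()))
# Same pairs generated with the standard library, for comparison
combos = list(itertools.combinations(range(n), 2))

# Pairwise column products, one output column per (i, j) pair
out = a[:, r] * a[:, c]
```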
<p><strong>Approach #2</strong></p>
<p>A memory-efficient one for large arrays -</p>
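<pre><code>m,n = a.shape
# Start/stop offsets of each column's block of products in the output
s = np.r_[0,np.arange(n-1,-1,-1).cumsum()]
out = np.empty((m, n*(n-1)//2), dtype=a.dtype)
for i,(s0,s1) in enumerate(zip(s[:-1], s[1:])):
    # Column i times every column to its right
    out[:,s0:s1] = a[:,i,None] * a[:,i+1:]
</code></pre>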
<p><strong>Approach #3</strong></p>
<p>A masking-based one -</p>
<pre><code>m,n = a.shape
mask = ~np.tri(n,dtype=bool)
m3D = np.broadcast_to(mask, (m,n,n))
b1 = np.broadcast_to(a[...,None], (m,n,n))
b2 = np.broadcast_to(a[:,None,:], (m,n,n))
out = (b1[m3D]* b2[m3D]).reshape(m,-1)
</code></pre>
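<p>As a sanity check, the three NumPy approaches can be wrapped into functions and compared. This is only a sketch: the names <code>app1</code>, <code>app2</code> and <code>app3</code> are my assumption about how the snippets above map to the functions used in the benchmark further down, and the random test array is made up.</p>

```python
import numpy as np

def app1(a):
    # Approach #1: index with the strictly upper-triangular pairs
    r, c = np.triu_indices(a.shape[1], 1)
    return a[:, c] * a[:, r]

def app2(a):
    # Approach #2: fill the output one source column at a time
    m, n = a.shape
    s = np.r_[0, np.arange(n - 1, -1, -1).cumsum()]
    out = np.empty((m, n * (n - 1) // 2), dtype=a.dtype)
    for i, (s0, s1) in enumerate(zip(s[:-1], s[1:])):
        out[:, s0:s1] = a[:, i, None] * a[:, i + 1:]
    return out

def app3(a):
    # Approach #3: boolean mask over broadcast (m, n, n) views
    m, n = a.shape
    mask = ~np.tri(n, dtype=bool)
    m3D = np.broadcast_to(mask, (m, n, n))
    b1 = np.broadcast_to(a[..., None], (m, n, n))
    b2 = np.broadcast_to(a[:, None, :], (m, n, n))
    return (b1[m3D] * b2[m3D]).reshape(m, -1)

# Small integer input so the comparison is exact
a = np.random.default_rng(0).integers(1, 10, size=(5, 6))
ref = app1(a)
same = np.array_equal(ref, app2(a)) and np.array_equal(ref, app3(a))
```

<p>All three should give the same <code>(m, n*(n-1)//2)</code> result. The difference is in the intermediates: approach #1 materializes two full-size indexed copies, while approach #2 only creates one column block at a time.</p>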
<p><strong>Approach #4</strong></p>
<p>Extending approach #2 to a <code>numba</code> one -</p>
<pre><code>import numpy as np
from numba import njit

def numba_app(a):
    m,n = a.shape
    out = np.empty((m, n*(n-1)//2), dtype=a.dtype)
    return numba_func(a,out,m,n)

@njit
def numba_func(a,out,m,n):
    for p in range(m):
        I = 0  # running output column index for row p
        for i in range(n):
            for j in range(i+1,n):
                out[p,I] = a[p,i] * a[p,j]
                I += 1
    return out
</code></pre>
<p>Then, leveraging <code>parallel</code> processing (as pointed out in the comments by @max9111), like so -</p>
<pre><code>import numpy as np
from numba import njit, prange

def numba_app_parallel(a):
    m,n = a.shape
    out = np.empty((m, n*(n-1)//2), dtype=a.dtype)
    return numba_func_parallel(a,out,m,n)

@njit(parallel=True)
def numba_func_parallel(a,out,m,n):
    for p in prange(m):  # rows are independent, so they run in parallel
        I = 0
        for i in range(n):
            for j in range(i+1,n):
                out[p,I] = a[p,i] * a[p,j]
                I += 1
    return out
</code></pre>
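<h3>Benchmarking</h3>
<p>Using the <a href="https://github.com/droyed/benchit" rel="nofollow noreferrer"><code>benchit</code></a> package (a few benchmarking tools packaged together; disclaimer: I am its author) to benchmark the proposed solutions.</p>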
<pre><code>import benchit
in_ = [np.random.rand(5000, 80), np.random.rand(10000, 100), np.random.rand(20000, 120)]
funcs = [ehsan, app1, app2, app3, numba_app, numba_app_parallel]
t = benchit.timings(funcs, in_, indexby='shape')
t.rank()
t.plot(logx=False, save='timings.png')
</code></pre>
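<p>If <code>benchit</code> is not installed, a rough comparison can be sketched with the standard library's <code>timeit</code>. This is not what produced the plot below; the helper <code>bench</code>, the smaller input shape, and the choice of comparing only approaches #1 and #2 are assumptions made to keep the example short.</p>

```python
import timeit
import numpy as np

def app1(a):
    # Approach #1: np.triu_indices-based indexing
    r, c = np.triu_indices(a.shape[1], 1)
    return a[:, c] * a[:, r]

def app2(a):
    # Approach #2: column-by-column fill of a preallocated output
    m, n = a.shape
    s = np.r_[0, np.arange(n - 1, -1, -1).cumsum()]
    out = np.empty((m, n * (n - 1) // 2), dtype=a.dtype)
    for i, (s0, s1) in enumerate(zip(s[:-1], s[1:])):
        out[:, s0:s1] = a[:, i, None] * a[:, i + 1:]
    return out

def bench(func, a, repeat=3, number=1):
    # Hypothetical helper: best-of-`repeat` time in seconds per call
    return min(timeit.repeat(lambda: func(a), repeat=repeat, number=number)) / number

a = np.random.rand(500, 40)
timings = {f.__name__: bench(f, a) for f in (app1, app2)}
for name, t in timings.items():
    print(f"{name}: {t * 1e3:.3f} ms")
```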
<p><a href="https://i.stack.imgur.com/185gu.png" rel="nofollow noreferrer"><img src="https://i.stack.imgur.com/185gu.png" alt="enter image description here"/></a></p>
<p>Conclusion: the <code>Numba</code> ones seem to be doing pretty well, and <code>app2</code> does best among the NumPy ones.</p>