<p>Here's one approach using <a href="https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.cdist.html" rel="noreferrer"><code>cdist</code></a> -</p>
<pre><code>from scipy.spatial.distance import cdist

def closest_rows(a):
    # Get squared euclidean distances as 2D array
    dists = cdist(a, a, 'sqeuclidean')

    # Fill diagonals with something greater than all elements as we intend
    # to get argmin indices later on and then index into input array with those
    # indices to get the closest rows
    dists.ravel()[::dists.shape[1]+1] = dists.max()+1
    return a[dists.argmin(1)]
</code></pre>
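<p>The strided assignment on the flattened view is what writes to the diagonal: in a C-contiguous <code>(n, n)</code> array, consecutive diagonal elements sit <code>n + 1</code> positions apart. A minimal sketch checking this against <code>np.fill_diagonal</code>, using a small made-up array purely for illustration -</p>

```python
import numpy as np

# Small stand-in for a distance matrix (values are arbitrary)
dists = np.arange(16, dtype=float).reshape(4, 4)
alt = dists.copy()

# ravel() returns a view for a contiguous array, so this writes in place;
# stepping by ncols + 1 lands on positions (0,0), (1,1), (2,2), (3,3)
dists.ravel()[::dists.shape[1] + 1] = -1.0

# Same effect with the dedicated NumPy helper
np.fill_diagonal(alt, -1.0)

same = np.array_equal(dists, alt)
print(same)
```

<p>Note that the trick relies on <code>ravel()</code> returning a view, which holds for contiguous arrays; for a non-contiguous array it would return a copy and the assignment would be lost.</p>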
<p>Sample run -</p>
<pre><code>In [72]: a
Out[72]:
array([[1, 2, 8],
       [7, 4, 2],
       [9, 1, 7],
       [0, 1, 5],
       [6, 4, 3]])

In [73]: closest_rows(a)
Out[73]:
array([[0, 1, 5],
       [6, 4, 3],
       [6, 4, 3],
       [1, 2, 8],
       [7, 4, 2]])
</code></pre>
<p><strong>Runtime test</strong></p>
<p>Other working approach -</p>
<pre><code>import numpy as np

def norm_app(a): # @Psidom's soln
    dist = np.linalg.norm(a - a[:,None], axis=-1)
    dist[np.arange(dist.shape[0]), np.arange(dist.shape[0])] = np.nan
    return a[np.nanargmin(dist, axis=0)]
</code></pre>
<p>Timings with <code>10,000</code> points -</p>
<pre><code>In [79]: a = np.random.randint(0,9,(10000,3))
In [80]: %timeit norm_app(a) # @Psidom's soln
1 loop, best of 3: 3.83 s per loop
In [81]: %timeit closest_rows(a)
1 loop, best of 3: 392 ms per loop
</code></pre>
<hr/>
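<p><strong>Further performance boost</strong></p>
<p>There's the <a href="https://github.com/droyed/eucl_dist" rel="noreferrer"><code>eucl_dist</code></a> package (disclaimer: I am its author) that contains various methods to compute euclidean distances that are more efficient than <code>SciPy's cdist</code>, especially for large arrays.</p>
<p>Thus, making use of it, we would have a more performant one, like so -</p>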
<pre><code>from eucl_dist.cpu_dist import dist

def closest_rows_v2(a):
    dists = dist(a,a, matmul="gemm", method="ext")
    dists.ravel()[::dists.shape[1]+1] = dists.max()+1
    return a[dists.argmin(1)]
</code></pre>
<p>Timings -</p>
<pre><code>In [162]: a = np.random.randint(0,9,(10000,3))
In [163]: %timeit closest_rows(a)
1 loop, best of 3: 394 ms per loop
In [164]: %timeit closest_rows_v2(a)
1 loop, best of 3: 229 ms per loop
</code></pre>
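<p>The <code>matmul</code> argument above suggests a matrix-multiplication based formulation. As a rough, NumPy-only sketch of that general idea (not the package's actual implementation), squared distances can be obtained from the expansion <code>|x - y|^2 = |x|^2 + |y|^2 - 2*x.y</code>, so the bulk of the work becomes a single matrix product. The names <code>sq_dists_matmul</code> and <code>closest_rows_matmul</code> are hypothetical helpers made up for this illustration -</p>

```python
import numpy as np

def sq_dists_matmul(a, b):
    """Hypothetical helper: squared euclidean distances via one matrix product."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    a_sq = (a * a).sum(axis=1)[:, None]   # |x|^2 as a column
    b_sq = (b * b).sum(axis=1)[None, :]   # |y|^2 as a row
    d = a_sq + b_sq - 2.0 * a.dot(b.T)    # |x|^2 + |y|^2 - 2*x.y
    # Floating-point round-off can leave tiny negative values; clip to zero
    np.maximum(d, 0, out=d)
    return d

def closest_rows_matmul(a):
    """Hypothetical helper: for each row, return the nearest other row."""
    a = np.asarray(a)
    dists = sq_dists_matmul(a, a)
    # Exclude self-matches so argmin never picks the row itself
    np.fill_diagonal(dists, np.inf)
    return a[dists.argmin(1)]

a = np.array([[1, 2, 8],
              [7, 4, 2],
              [9, 1, 7],
              [0, 1, 5],
              [6, 4, 3]])
result = closest_rows_matmul(a)
print(result)
```

<p>On the sample array from earlier this picks the same rows as <code>closest_rows</code>. For floating-point inputs the expansion can lose some precision when points are very close together, so results may differ slightly from <code>cdist</code> in such cases.</p>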