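<p>This works for tensors of any rank, i.e. tensors holding (…, N, d) batches of vectors. Note that it computes distances within a single set of vectors (like <code>scipy.spatial.distance.pdist</code>), not between two different sets (unlike <code>scipy.spatial.distance.cdist</code>). Also, where <code>pdist</code> returns a condensed 1-D vector, this returns the full (…, N, N) matrix.</p>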
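<p>To make the SciPy comparison concrete, here is a small sketch for a single (N, d) array using only the real SciPy functions <code>pdist</code>, <code>squareform</code> and <code>cdist</code> (the random data is purely illustrative). The square matrix from <code>squareform(pdist(x))</code> is what the TensorFlow function returns for each batch entry, and <code>cdist(x, x)</code> gives the same matrix by treating one set as two.</p>
<pre><code>import numpy as np
from scipy.spatial.distance import cdist, pdist, squareform

# A single (N, d) batch: N = 4 points in d = 3 dimensions (illustrative data).
rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))

# pdist compares vectors within one set and returns the condensed
# upper triangle: N * (N - 1) / 2 = 6 distances.
condensed = pdist(x)
print(condensed.shape)  # (6,)

# squareform expands it to the full symmetric (N, N) matrix with a zero diagonal.
square = squareform(condensed)
print(square.shape)  # (4, 4)

# cdist compares two (possibly different) sets; passing the same set twice
# yields the same (N, N) matrix.
print(np.allclose(square, cdist(x, x)))  # True
</code></pre>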
<pre><code>import string

import tensorflow as tf


def pdist(arr):
    """Pairwise Euclidean distances between vectors contained at the back of tensors.

    Uses the expansion: (x - y)^T (x - y) = x^Tx - 2x^Ty + y^Ty
    The rank and the last two dimensions of arr must be known statically.

    :param arr: (..., N, d) tensor
    :returns: (..., N, N) tensor of pairwise distances between vectors in the second-to-last dim.
    :rtype: tf.Tensor
    """
    shape = tuple(arr.get_shape().as_list())
    rank_ = len(shape)
    N, d = shape[-2:]
    # einsum prefix for the leading (batch) axes, one letter per axis. Letters start
    # at 'a', so up to 8 batch axes stay clear of the 'i', 'm', 'n' used below.
    pref = string.ascii_lowercase[:rank_ - 2]
    # Gram matrix of pairwise inner products x_n^T x_m: (..., N, N)
    xxT = tf.einsum('{0}ni,{0}mi->{0}nm'.format(pref), arr, arr)
    # Squared norm of each point x_n^T x_n: (..., N)
    xTx = tf.einsum('{0}ni,{0}ni->{0}n'.format(pref), arr, arr)
    # (..., N, N): squared norms tiled along the last axis, so entry [n, m] is x_n^T x_n.
    xTx_tile = tf.tile(xTx[..., None], (1,) * (rank_ - 1) + (N,))
    # Permutation swapping the last two axes (same effect as tf.linalg.matrix_transpose).
    permute = list(range(rank_))
    permute[-2], permute[-1] = permute[-1], permute[-2]
    # dists = (x^Tx - 2x^Ty + y^Ty)^(1/2). Transposing the tiled norms puts y^Ty at entry [n, m].
    sq_dists = xTx_tile - 2 * xxT + tf.transpose(xTx_tile, permute)
    # Round-off can make some squared distances slightly negative; clamp to 0 so sqrt gives no NaN.
    return tf.sqrt(tf.maximum(sq_dists, 0.0))
</code></pre>
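<p>The same expansion can be checked without TensorFlow. Below is a minimal NumPy sketch; <code>pdist_np</code> is a hypothetical name, not part of any library, and it swaps the <code>tf.tile</code> plus transpose step for plain broadcasting, which also avoids needing N at graph-construction time. It compares the result with a brute-force computation over explicit pairwise differences.</p>
<pre><code>import numpy as np


def pdist_np(arr):
    """Hypothetical NumPy counterpart of pdist: (..., N, d) -> (..., N, N)."""
    # Gram matrix of inner products x_n . x_m: (..., N, N)
    xxT = arr @ np.swapaxes(arr, -1, -2)
    # Squared norms x_n . x_n: (..., N)
    xTx = np.sum(arr * arr, axis=-1)
    # Broadcasting (..., N, 1) against (..., 1, N) plays the role of tile + transpose.
    sq = xTx[..., :, None] - 2 * xxT + xTx[..., None, :]
    # Clamp tiny negative values caused by round-off before taking the root.
    return np.sqrt(np.maximum(sq, 0.0))


rng = np.random.default_rng(0)
batch = rng.normal(size=(2, 5, 3, 4))  # two batch axes, N = 3 points, d = 4

fast = pdist_np(batch)
# Brute force: explicit differences of every pair, shape (..., N, N, d).
diff = batch[..., :, None, :] - batch[..., None, :, :]
slow = np.sqrt(np.sum(diff ** 2, axis=-1))

print(fast.shape)  # (2, 5, 3, 3)
# sqrt amplifies round-off near zero (the diagonal), hence the looser atol.
print(np.allclose(fast, slow, atol=1e-6))  # True
</code></pre>
<p>The looser tolerance is the same effect the clamp guards against: on the diagonal the expanded form subtracts nearly equal numbers, so the squared distance is a tiny (possibly negative) residue rather than an exact zero.</p>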