在numpy python中,如何仅使用矩阵运算计算两个矩阵之间的欧几里德距离(循环不使用)?

2024-06-26 13:29:59 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图在numpy python中仅使用矩阵运算计算两个矩阵之间的欧几里德距离,但不使用任何for循环

如果我只需要为两个单一向量计算这个值,那就很简单了,因为我只需要使用欧几里德距离的公式:

D(x, y) = ∥y – x∥ = √ ( xT x + yT y – 2 xT y )

把它转换成numpy就像做这样的事情一样简单

a@a.T - 2*a@b.T + b@b.T

然而,在不使用for循环的情况下,对大小为PxQ和RxQ的两个矩阵执行此操作被证明是很棘手的

如果我有一个叫做a的矩阵,给我一个更好的例子来说明我要做什么:

[[ 4  9  9]
 [11  8  1]
 [ 2  6  4]
 [ 4  7 11]
 [ 6  7  9]]

和一个称为B的矩阵:

[[ 5  8  8]
 [10  5  1]
 [ 6  6  9]
 [ 2  1  2]
 [ 9  1  3]
 [ 6  1  6]
 [ 9 10  3]
 [10  4  8]]

然后我希望能够计算出结果矩阵C:

[[ 1.73205081 10.77032961  3.60555128 10.81665383 11.18033989  8.77496439
   7.87400787  7.87400787]
 [ 9.21954446  3.16227766  9.64365076 11.44552314  7.54983444  9.94987437
   3.46410162  8.1240384 ]
 [ 5.38516481  8.60232527  6.40312424  5.38516481  8.66025404  6.70820393
   8.1240384   9.16515139]
 [ 3.31662479 11.83215957  3.         11.         11.18033989  8.06225775
   9.89949494  7.34846923]
 [ 1.73205081  9.16515139  1.         10.04987562  9.          6.70820393
   7.34846923  5.09901951]]

我相信这个线程solution的堆栈溢出顶级解决方案正是这样做的,但它是在matlab中,我很难将它转换为Python numpy解决方案


Tags: numpy证明距离for情况矩阵解决方案事情
2条回答

试试numpy broadcasting

dist_mat = np.sum((a[:,None] - b)**2, axis=-1)**.5

输出:

array([[ 1.73205081, 10.77032961,  3.60555128, 10.81665383, 11.18033989,
         8.77496439,  7.87400787,  7.87400787],
       [ 9.21954446,  3.16227766,  9.64365076, 11.44552314,  7.54983444,
         9.94987437,  3.46410162,  8.1240384 ],
       [ 5.38516481,  8.60232527,  6.40312424,  5.38516481,  8.66025404,
         6.70820393,  8.1240384 ,  9.16515139],
       [ 3.31662479, 11.83215957,  3.        , 11.        , 11.18033989,
         8.06225775,  9.89949494,  7.34846923],
       [ 1.73205081,  9.16515139,  1.        , 10.04987562,  9.        ,
         6.70820393,  7.34846923,  5.09901951]])

在最后一个轴上使用np.linalg.norm

>>> np.linalg.norm((a[:,None] - b), axis=-1)
array([[ 1.73205081, 10.77032961,  3.60555128, 10.81665383, 11.18033989,
         8.77496439,  7.87400787,  7.87400787],
       [ 9.21954446,  3.16227766,  9.64365076, 11.44552314,  7.54983444,
         9.94987437,  3.46410162,  8.1240384 ],
       [ 5.38516481,  8.60232527,  6.40312424,  5.38516481,  8.66025404,
         6.70820393,  8.1240384 ,  9.16515139],
       [ 3.31662479, 11.83215957,  3.        , 11.        , 11.18033989,
         8.06225775,  9.89949494,  7.34846923],
       [ 1.73205081,  9.16515139,  1.        , 10.04987562,  9.        ,
         6.70820393,  7.34846923,  5.09901951]])

相关问题 更多 >