如何以矢量形式编写这个numpy代码?

2024-05-19 23:02:07 发布

您现在位置:Python中文网/ 问答频道 /正文

我在python中有下面的函数,我不知道如何用矢量形式表示。 对我来说,cov是形状(2,2)的numpy数组,mu是形状(2,2)的平均向量,xtp是形状(~50000,2)。 我知道scipy提供scipy.stats.u正常但我正在努力学习如何编写高效的矢量化代码。求你了

def mvnpdf(xtp, mu, cov):
    temp = np.zeros(xtp.shape[0])
    i = 0
    length = xtp.shape[0]
    const = 1 / ( ((2* np.pi)**(len(mu)/2)) * (np.linalg.det(cov)**(1/2)) )
    inv = np.linalg.inv(cov)
    while i < length:
        x = xtp[i]-mu
        exponent = (-1/2) * (x.dot(inv).dot(x))
        temp[i] =  (const * np.exp(exponent))
        i+=1
    return temp

Tags: 函数npscipycovlengthtempdot形状
1条回答
网友
1楼 · 发布于 2024-05-19 23:02:07

矢量化唯一棘手的部分是double.dot。让我们把它隔离开来:

x = xtp - mu  # move this out of the loop
ddot = [i.dot(inv).dot(i) for i in x]
temp = const * np.exp(-0.5 * ddot)

把它放到你的代码里,看看它是否产生同样的结果。你知道吗

有几种“矢量化”adot的方法。我想先试试的是einsum。在我的测试中,这相当于:

ddot = np.einsum('ij,jk,ik->i',x,inv,x)

我建议尝试一下,看看它是否能起作用,加快速度。在交互式shell中使用较小的数组(而不是~50000)进行这些计算。你知道吗

我在测试东西

In [225]: x
Out[225]: 
array([[  0.,   2.],
       [  1.,   3.],
       ...
       [  7.,   9.],
       [  8.,  10.],
       [  9.,  11.]])
In [226]: inv
Out[226]: 
array([[ 1.,  0.],
       [ 0.,  1.]])

由于这是一个学习练习,我将把细节留给你。你知道吗

使用(2,2),一个cov的计算可能比使用detinv函数更快。但是length迭代是时间消费者。你知道吗

相关问题 更多 >