Python:使用预先计算的元素加速大的双和运算

2024-10-01 19:29:57 发布

您现在位置:Python中文网/ 问答频道 /正文

我需要计算一个双和的形式:

wignersum{ell1}=sum{ell1}sum{ell2}(2*ell1+1)(2*ell2+1)*W{ell1,ell1,ell2}^2*C1(ell1)*C2(ell2)

其中wignersum是由ell索引的数组,而ell、ell1和ell2都从0到ellmax运行。W{ell,ell1,ell2}^2是一组已知的系数,我已经计算过了(称为w3j),存储在一个shape数组(ellmax,ellmax,ellmax)中,作为一个全局变量由这个函数调用(这些系数的计算非常耗时,我发现从numpy文件加载它们会更快)。C1和C2是形状系数(ellmax)的数组

我通过使用双for循环,从每个已有数组中获取适当的元素,并在每次迭代中更新wignersum数组,成功地计算了这个和。我想有更好的方法来矢量化这个问题,以加快计算速度。我考虑过将C1和C2数组制作成与w3j数组形状相同的数组,然后在ell1和ell2轴上使用np.sum之前将这些数组元素相乘。我不确定这是否真的是一个很好的矢量化方法,如果是,如何真正做到这一点

目前的代码是

import numpy as np
ell_max = 400
w3j = np.ones((ell_max, ell_max, ell_max))
C1 = np.arange(ell_max)
C2 = np.arange(ell_max)

def function(ell_max)
ells = np.arange(ell_max)
wignersum = np.zeros(ell_max)

factor = np.array([2*i+1 for i in range(384)])

for ell1 in ells:
    A = factor[ell1]
    B = C1[ell1]
    for ell2 in ells:
        D = factor[ell2] * C2[ell2] * w3j[:,ell1,ell2]
        wignersum += A * B * D
return wignersum

(请注意,实际上C1C2不是全局变量,而是局部变量,必须从提供给function的一组参数进行计算。但这不是代码速度的限制因素)

对于双for循环,运行ellu max~400需要约1.5秒,这对于我使用它的目的来说太长了。我想矢量化这尽可能提高速度


Tags: fornp数组矢量化maxsumc2系数
1条回答
网友
1楼 · 发布于 2024-10-01 19:29:57

您可以使用einsum或矩阵乘法进行~20倍的加速:

import numpy as np
ell_max = 400
w3j = np.random.randint(1,10,(ell_max, ell_max, ell_max))
C1 = np.random.randint(1,10,ell_max)
C2 = np.random.randint(1,10,ell_max)

def function(ell_max):
    ells = np.arange(ell_max)
    wignersum = np.zeros(ell_max)

    factor = np.array([2*i+1 for i in range(ell_max)])

    for ell1 in ells:
        A = factor[ell1]
        B = C1[ell1]
        for ell2 in ells:
            D = factor[ell2] * C2[ell2] * w3j[:,ell1,ell2]
            wignersum += A * B * D
    return wignersum

def pp_es(l_mx):
    l = np.arange(l_mx)
    f = 2*l+1
    return np.einsum("i,i,j,j,kij",f,C1,f,C2,w3j,optimize=True)

def pp_mm(l_mx):
    l = np.arange(l_mx)
    f = 2*l+1
    return w3j.reshape(l_mx,-1)@np.outer(f*C1,f*C2).ravel()

from timeit import timeit

print(timeit(lambda:pp_es(400),number=10))
print(timeit(lambda:pp_mm(400),number=10))
print(timeit(lambda:function(400),number=10))

print((pp_mm(400)==pp_es(400)).all())
print((function(400)==pp_mm(400)).all())

运行示例:

0.6061844169162214 # einsum
0.6111843499820679 # matrix x vector
12.233918005018495 # OP
True # einsum == matrix x vector
True # OP == matrix x vector

相关问题 更多 >

    热门问题