加快Python中求和的集成

2024-10-02 14:17:41 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图在Python中加速一个特定的(数值)积分。我在Mathematica进行了评估,需要14秒。在python中,它需要15.6分钟

我要计算的积分的形式如下:

enter image description here

python代码如下所示:

from mpmath import hermite

def light_nm( dipol, n, m, t):
    mat_elem = light_amp(n)*light_amp_conj(m)*coef_ground( dipol, n,t)*np.conj(coef_ground( dipol, m,t)) +  \
              light_amp(n+1)*light_amp_conj(m+1)*coef_excit( dipol, n+1,t)*np.conj(coef_excit( dipol, m+1,t))
    return mat_elem


def light_nm_dmu( dipol, n, m, t):
    mat_elem = light_amp(n)*light_amp_conj(m)*(coef_ground_dmu( dipol, n,t)*conj(coef_ground( dipol, m,t)) + coef_ground( dipol, n,t)*conj(coef_ground_dmu( dipol, m,t)) )+    \
            light_amp(n+1)*light_amp_conj(m+1)*(coef_excit_dmu( dipol, n+1,t)*np.conj(coef_excit( dipol, m+1,t)) + coef_excit( dipol, n+1,t)*conj(coef_excit_dmu( dipol, m+1,t)))
    return mat_elem

def prob(dipol, t, x, thlo, cutoff, n, m):
      temp = complex( light_nm(dipol, n, m, t)* cmath.exp(1j*thlo*(n-m)-x**2)*\
                             hermite(n,x)*hermite(m,x)/math.sqrt(2**(n+m)*math.factorial(m)*math.factorial(n)*math.pi))
      return np.real(temp)

def derprob(dipol, t, x, thlo, cutoff, n, m):
      temp = complex( light_nm_dmu(dipol, n, m, t)* cmath.exp(1j*thlo*(n-m)-x**2)*\
                              hermite(n,x)*hermite(m,x)/math.sqrt(2**(n+m)*math.factorial(m)*math.factorial(n)*math.pi))
      if np.imag(temp)>10**(-6):
          print(t)
      return np.real(temp)

def integrand(dipol, t, thlo, cutoff,x):
    return  1/np.sum(np.array([ prob(dipol,t,x,thlo,cutoff,n,m) for n,m in product(range(cutoff),range(cutoff))]))*\
         np.sum(np.array([ derprob(dipol,t,x,thlo,cutoff,n,m) for n,m in product(range(cutoff),range(cutoff))]))**2

def cfi(dipol, t, thlo, cutoff, a):
    global alpha
    alpha = a
    
    temp_func_real = lambda x: np.real(integrand(dipol,t, thlo, cutoff, x))
    temp_real = integ.quad(temp_func_real, -8, 8)
    return  temp_real[0]

hermite函数是从mpmath库调用的。 有没有办法让这段代码运行得更快

谢谢大家!

更新: 我添加了整个代码。(很抱歉耽搁了) “light_nm_dmu”的功能与“light_nm”类似。 我尝试了这个答案,但在light_amp函数中出现了一个错误“TypeError:只有size-1数组才能转换为Python标量”,因此我对prob和derprob进行了矢量化

相同评估的新时间为886.7085871696472=14.8分钟(cfi(0.1,1,0,40,1))


Tags: defnpmathrealtemphermitelightamp
1条回答
网友
1楼 · 发布于 2024-10-02 14:17:41

建议使用:

  1. 矢量化numpy - evaluate function on a grid of points

  2. 使用缓存来加速一大组数字上的阶乘计算,即Is math.factorial memorized?(由Domenico De Felice修改答案)

更新代码

# use cached factorial function
def prob(dipol, t, x, thlo, cutoff, n, m):
      temp = complex( light_nm(dipol, n, m, t)* cmath.exp(1j*thlo*(n-m)-x**2)*\
                             hermite(n,x)*hermite(m,x)/math.sqrt(2**(n+m)*factorial(m)*factorial(n)*math.pi))
      return np.real(temp)

# Vectorize computation
def integrand(dipol, t, thlo, cutoff,x):
    xaxis = np.arange(0, cutoff)
    yaxis = np.arange(0, cutoff)

    return  1/np.sum(prob(dipol,t,x,thlo,cutoff,xaxis[:, None] , yaxis[None, :]))*\
         np.sum(derprob(dipol,t,x,thlo,cutoff,xaxis[:, None] , yaxis[None, :]))**2

# unchanged
def cfi(dipol, t, thlo, cutoff, a):
    global alpha
    alpha = a
    
    temp_func_real = lambda x: np.real(integrand(dipol,t, thlo, cutoff, x))
    temp_real = integ.quad(temp_func_real, -8, 8)
    return  temp_real[0]

# Cached factorial
def factorial(num, fact_memory = {0: 1, 1: 1, 'max': 1}):
    ' Cached factorial since we're computing on lots of numbers '
    # Factorial is defined only for non-negative numbers
    assert num >= 0

    if num <= fact_memory['max']:
        return fact_memory[num]

    for x in range(fact_memory['max']+1, num+1):
        fact_memory[x] = fact_memory[x-1] * x
        
    fact_memory['max'] = num
    return fact_memory[num]

相关问题 更多 >

    热门问题