cod的Cython优化

2024-09-28 20:42:49 发布

您现在位置:Python中文网/ 问答频道 /正文

我正努力用Cython来提高python粒子跟踪代码的性能。在

下面是我的纯Python代码:

from scipy.integrate import odeint
import numpy as np
from numpy import sqrt, pi, sin, cos
from time import time as Time
import multiprocessing as mp
from functools import partial

cLight = 299792458.
Dim = 6

class Integrator:
    def __init__(self, ring):
        self.ring = ring

    def equations(self, X, s):
        dXds = np.zeros(Dim)

        E, B = self.ring.getEMField( [X[0], X[2], s], X[4] )

        h = 1 + X[0]/self.ring.ringRadius
        p_s = np.sqrt(X[5]**2 - self.ring.particle.mass**2 - X[1]**2 - X[3]**2)
        dtds = h*X[5]/p_s
        gamma = X[5]/self.ring.particle.mass
        beta = np.array( [X[1], X[3], p_s] ) / X[5]

        dXds[0] = dtds*beta[0]
        dXds[2] = dtds*beta[1]
        dXds[1] = p_s/self.ring.ringRadius + self.ring.particle.charge*(dtds*E[0] + dXds[2]*B[2] - h*B[1])
        dXds[3] = self.ring.particle.charge*(dtds*E[1] + h*B[0] - dXds[0]*B[2])
        dXds[4] = dtds
        dXds[5] = self.ring.particle.charge*(dXds[0]*E[0] + dXds[2]*E[1] + h*E[2])
        return dXds

    def odeSolve(self, X0, sRange):
        sol = odeint(self.equations, X0, sRange)
        return sol

class Ring:
    def __init__(self, particle):
        self.particle = particle
        self.ringRadius = 7.112
        self.magicB0 = self.particle.magicMomentum/self.ringRadius

    def getEMField(self, pos, time):
        x, y, s = pos
        theta = (s/self.ringRadius*180/pi) % 360
        r = sqrt(x**2 + y**2)
        arg = 0 if r == 0 else np.angle( complex(x/r, y/r) )
        rn = r/0.045

        k2 = 37*24e3
        k10 = -4*24e3

        E = np.zeros(3)
        B = np.array( [ 0, self.magicB0, 0 ] )

        for i in range(4):
            if ((21.9+90*i < theta < 34.9+90*i or 38.9+90*i < theta < 64.9+90*i) and (-0.05 < x < 0.05 and -0.05 < y < 0.05)):
                E = np.array( [ k2*x/0.045 + k10*rn**9*cos(9*arg), -k2*y/0.045 -k10*rn**9*sin(9*arg), 0] )
                break
        return E, B

class Particle:
    def __init__(self):
        self.mass = 105.65837e6
        self.charge = 1.
        self.gm2 = 0.001165921 

        self.magicMomentum = self.mass/sqrt(self.gm2)
        self.magicEnergy = sqrt(self.magicMomentum**2 + self.mass**2)
        self.magicGamma = self.magicEnergy/self.mass
        self.magicBeta = self.magicMomentum/(self.magicGamma*self.mass)


def runSimulation(nParticles, tEnd):
    particle = Particle()
    ring = Ring(particle)
    integrator = Integrator(ring)

    Xs = np.array( [ np.array( [45e-3*(np.random.rand()-0.5)*2, 0, 0, 0, 0, particle.magicEnergy] ) for i in range(nParticles) ] )
    sRange = np.arange(0, tEnd, 1e-9)*particle.magicBeta*cLight 

    ode = partial(integrator.odeSolve, sRange=sRange)

    t1 = Time()

    pool = mp.Pool()
    sol = np.array(pool.map(ode, Xs))

    t2 = Time()
    print ("%.3f sec" %(t2-t1))

    return t2-t1

显然,最耗时的过程是集成ODE,在class Integrator中定义为odeSolve()和equations()。另外,在求解过程中,类环中的getEMField()方法与equations()方法调用的次数一样多。 我试图使用Cython获得显著的速度提升(至少10倍~20倍),但通过以下Cython脚本,我只获得了~1.5倍的速度提升:

^{pr2}$

我该怎么做才能让Cython得到最大的效果? (我尝试了Numba而不是Cython,实际上Numba的性能提升是巨大的(大约20倍的加速)。但是我很难在python类实例中使用Numba,所以我决定使用Cython而不是Numba)。在

以下是cython对其编译的注释,供参考: enter image description here

enter image description here

enter image description here

enter image description here


Tags: fromimportselfdefnpsqrtarraycython
1条回答
网友
1楼 · 发布于 2024-09-28 20:42:49

这是一个非常不完整的答案,因为我没有分析或计时任何东西,甚至没有检查它是否给出相同的答案。不过,以下是一些减少Cython生成的Python代码的建议:

  • 添加@cython.cdivision(True)编译指令。这意味着在浮点除法时不会引发ZeroDivisionError,而是得到一个NaN值。(只有在不希望引发错误时才执行此操作)。

  • p_s = np.sqrt(...)更改为p_s = sqrt(...)。这将删除只对单个值进行操作的numpy调用。你好像在别的地方做过这件事,所以我不知道你为什么漏掉这条线。

  • 尽可能使用固定大小的C数组而不是numpy数组:

    cdef double beta[3]
    # ...
    beta[0] = X[1]/X[5]
    beta[1] = X[3]/X[5]
    beta[2] = p_s/X[5]
    

    当编译时知道大小(并且相当小)并且不想返回它时,可以这样做。这避免了对np.zeros的调用和随后的一些类型检查,以便为其分配类型化numpy数组。我认为beta是你唯一能做到的地方。

  • np.angle( complex(x/r, y/r) )可以替换为atan2(y/r, x/r)(使用atan2中的libc.math。您也可以通过r

  • cdef int i有助于在getEMField中加快for循环(Cython通常擅长自动获取循环变量的类型,但在这里似乎失败了)

  • 我怀疑逐元素分配E元素比作为一个整体数组更快:

            E[0] = k2*x/0.045 + k10*rn**9*cos(9*arg)
            E[1] = -k2*y/0.045 -k10*rn**9*sin(9*arg)
    
  • 指定listtuple之类的类型并没有太大的价值,它实际上可能会使代码稍微慢一点(因为它会浪费时间检查类型)。

  • 更大的变化是将E和{}作为指针传递到GetEMField中,而不是使用分配它们np.zeros。这将允许您在equationscdef double E[3])中将它们分配为静态C数组。缺点是GetEMField必须是cdef,因此不再可以从Python调用(但是如果愿意,也可以创建一个Python可调用包装函数)。

相关问题 更多 >