使用numba加速odeint,在试图传递字典和自定义obj时出现问题

2024-10-04 01:35:22 发布

您现在位置:Python中文网/ 问答频道 /正文

作为一个学习项目,我正在做一个个人项目,用Python编写一个quadcopter仿真(和控制)。我使用的是scipy积分器odeint,在长时间的计算中我非常失望。所以我希望使用numba来加速我的集成。我调用odeint每个timestep,因为我必须在每个模拟的timestep之后创建命令。在

起初,当我要集成的函数(state_dot)是Quadcopter类的方法时,我遇到了一些问题。所以我把它作为一个单独的函数,但是当我用@jit修饰我的函数时,我在定义正确的类型时遇到了问题。state_dot函数有一个dictionary(params)作为输入参数(我读过numba支持字典),但它也是一个自定义类(wind),因为我的wind模型是该类的一个方法。如果我暂时排除wind,那么使用numba.typed.Dict导入字典似乎不起作用。在

为了在函数中导入wind对象,我看到使用了numba类型object_,但是Python在numba中没有找到object_。在

我使用的是numba版本0.45.0和python3.7。在

import numpy as np
from scipy.integrate import odeint
from numba import jit, void, float_, int_
import numba

class Quadcopter:

    def __init__(self):

        # Quad Params
        # ---------------------------
        mB  = 1.2       # mass (kg)
        params = {}
        params["mB"]   = mB
        self.params = params


        # Initial State
        # ---------------------------
        self.state = np.zeros(3)

    def update(self, t, Ts, cmd, wind):

        self.state = odeint(state_dot, self.state, [t,t+Ts], args = (cmd, self.params, wind))[1]


@jit(void(float_[:], float_, float_[:], numba.typed.Dict )) #(nopython = True)
def state_dot(state, t, cmd, params, wind):

    # Import Params
    # ---------------------------    
    mB   = params["mB"]

    # Import State Vector
    # ---------------------------  
    x      = state[0]
    y      = state[1]
    z      = state[2]

    # Motor Dynamics and Rotor forces (Second Order System: https://apmonitor.com/pdc/index.php/Main/SecondOrderSystems)
    # ---------------------------
    print(cmd)

    # Wind Model
    # ---------------------------
    [velW, qW1, qW2] = wind.randomWind(t)
    print(velW)

    # State Derivative Vector
    # ---------------------------
    sdot     = np.zeros(3)
    sdot[0]  = x*t + 0.1
    sdot[1]  = y*t + 0.1
    sdot[2]  = z*t + 0.1


    return sdot


class Wind:

    def __init__(self):

        # Normally, average wind would be randomly set here
        self.velW_med = 5.0
        self.qW1_med  = 0.2
        self.qW2_med  = 0.1

    def randomWind(self, t):

        # Normally, wind values would be a sine function dependant of current time
        velW = self.velW_med
        qW1  = self.qW1_med
        qW2  = self.qW2_med

        return velW, qW1, qW2

# Set time
Ti = 0
Ts = 0.005
Tf = 10

# Initialize quadcopter and wind
quad = Quadcopter()
wind = Wind()

# Simulation
t = Ti
while round(t,3) < Tf:
    cmd = np.array([1,2,1,3])
    quad.update(t, Ts, cmd, wind)
    print(quad.state)
    t += Ts

收到的错误是

^{pr2}$

完整的代码可以在这里查看:https://github.com/bobzwik/Quadcopter_SimCon/blob/dev_numba/Simulation/quadFiles/quad.py

如果我遗漏了任何信息,请随时询问。在

编辑:更改了链接的完整代码,以链接到另一个分支。在


Tags: 函数selfcmddefmedmbparamsstate
1条回答
网友
1楼 · 发布于 2024-10-04 01:35:22

我注意到的第一件事是——至少在您这里展示的代码中——您的jit签名有四种类型,但是您要装饰的函数有五个参数:

@jit(void(float_[:], float_, float_[:], numba.typed.Dict))
def state_dot(state, t, cmd, params, wind):

所以很明显你需要解决这个问题。最简单的办法就是去掉签名,让numba弄明白:

^{pr2}$

当然,即使您这样做,numba仍然抱怨它不知道如何键入所有内容,并指向mB = params["mB"]的行。它仍然执行“循环提升”,这意味着它能够编译一些东西,但速度不会尽可能快。在

所以第二件要注意的事情是,虽然numba说它支持dicts,但随后又提出了许多警告。基本上,使用dict仍然不是一个好主意。我也看不出有什么好的理由让你使用dict,为什么不让mB成为你类的一员,就像self.mB = mB一样?我知道在完整的Quadcopter类中会有更复杂的事情,但是可以有很多成员。在

现在,要注意的第三件事是,自从我写了that gist you pointed out elsewhere以来,numba已经变得更好了,现在可以处理类了,所以您可能需要研究一下^{}。一般来说,当你把一个jitclass对象传递给你想要jit的函数时,numba会知道如何处理它。在

但可能比所有这些更重要的是,update方法对每一步都调用odeint。我猜这是你代码中最慢的部分。这个函数应该被调用一次,这样它就可以从头到尾解决你的整个问题,因此它有很多(相对缓慢的)开销,这些开销与理解传递的参数、分配内存、初始化事物有关,一个更好的方法是构造一个^{}对象来在步骤之间保持所有的设置,并保持它在周围,以便在步骤之间使用相同的对象。较新的接口^{}^{}(以及类似的)大致上分别相当于odeint和{},只是{}有我的首选解算器dop853。如果您只需要OdeSolver子类中的一个,那么您可能更喜欢这些接口。还请注意,如果您在步骤之间实际更改了状态中的任何内容,则可能需要再次调用set_initial_value,否则可能会在您不注意的情况下出错。在

更一般地说,如果你担心速度,你能做的最好的事情就是分析你的代码。这里的第一步是在ipython中使用^{}。在

相关问题 更多 >