作为一个学习项目,我正在做一个个人项目,用Python编写一个quadcopter仿真(和控制)。我使用的是scipy
积分器odeint
,在长时间的计算中我非常失望。所以我希望使用numba
来加速我的集成。我调用odeint
每个timestep,因为我必须在每个模拟的timestep之后创建命令。在
起初,当我要集成的函数(state_dot
)是Quadcopter
类的方法时,我遇到了一些问题。所以我把它作为一个单独的函数,但是当我用@jit
修饰我的函数时,我在定义正确的类型时遇到了问题。state_dot
函数有一个dictionary(params
)作为输入参数(我读过numba支持字典),但它也是一个自定义类(wind
),因为我的wind模型是该类的一个方法。如果我暂时排除wind
,那么使用numba.typed.Dict
导入字典似乎不起作用。在
为了在函数中导入wind
对象,我看到使用了numba类型object_
,但是Python在numba中没有找到object_
。在
我使用的是numba版本0.45.0和python3.7。在
import numpy as np
from scipy.integrate import odeint
from numba import jit, void, float_, int_
import numba
class Quadcopter:
def __init__(self):
# Quad Params
# ---------------------------
mB = 1.2 # mass (kg)
params = {}
params["mB"] = mB
self.params = params
# Initial State
# ---------------------------
self.state = np.zeros(3)
def update(self, t, Ts, cmd, wind):
self.state = odeint(state_dot, self.state, [t,t+Ts], args = (cmd, self.params, wind))[1]
@jit(void(float_[:], float_, float_[:], numba.typed.Dict )) #(nopython = True)
def state_dot(state, t, cmd, params, wind):
# Import Params
# ---------------------------
mB = params["mB"]
# Import State Vector
# ---------------------------
x = state[0]
y = state[1]
z = state[2]
# Motor Dynamics and Rotor forces (Second Order System: https://apmonitor.com/pdc/index.php/Main/SecondOrderSystems)
# ---------------------------
print(cmd)
# Wind Model
# ---------------------------
[velW, qW1, qW2] = wind.randomWind(t)
print(velW)
# State Derivative Vector
# ---------------------------
sdot = np.zeros(3)
sdot[0] = x*t + 0.1
sdot[1] = y*t + 0.1
sdot[2] = z*t + 0.1
return sdot
class Wind:
def __init__(self):
# Normally, average wind would be randomly set here
self.velW_med = 5.0
self.qW1_med = 0.2
self.qW2_med = 0.1
def randomWind(self, t):
# Normally, wind values would be a sine function dependant of current time
velW = self.velW_med
qW1 = self.qW1_med
qW2 = self.qW2_med
return velW, qW1, qW2
# Set time
Ti = 0
Ts = 0.005
Tf = 10
# Initialize quadcopter and wind
quad = Quadcopter()
wind = Wind()
# Simulation
t = Ti
while round(t,3) < Tf:
cmd = np.array([1,2,1,3])
quad.update(t, Ts, cmd, wind)
print(quad.state)
t += Ts
收到的错误是
^{pr2}$完整的代码可以在这里查看:https://github.com/bobzwik/Quadcopter_SimCon/blob/dev_numba/Simulation/quadFiles/quad.py
如果我遗漏了任何信息,请随时询问。在
编辑:更改了链接的完整代码,以链接到另一个分支。在
我注意到的第一件事是——至少在您这里展示的代码中——您的jit签名有四种类型,但是您要装饰的函数有五个参数:
所以很明显你需要解决这个问题。最简单的办法就是去掉签名,让numba弄明白:
^{pr2}$当然,即使您这样做,numba仍然抱怨它不知道如何键入所有内容,并指向
mB = params["mB"]
的行。它仍然执行“循环提升”,这意味着它能够编译一些东西,但速度不会尽可能快。在所以第二件要注意的事情是,虽然numba说它支持
dict
s,但随后又提出了许多警告。基本上,使用dict仍然不是一个好主意。我也看不出有什么好的理由让你使用dict,为什么不让mB
成为你类的一员,就像self.mB = mB
一样?我知道在完整的Quadcopter
类中会有更复杂的事情,但是可以有很多成员。在现在,要注意的第三件事是,自从我写了that gist you pointed out elsewhere以来,numba已经变得更好了,现在可以处理类了,所以您可能需要研究一下^{} 。一般来说,当你把一个jitclass对象传递给你想要jit的函数时,numba会知道如何处理它。在
但可能比所有这些更重要的是,} 对象来在步骤之间保持所有的设置,并保持它在周围,以便在步骤之间使用相同的对象。较新的接口^{} 和^{} (以及类似的)大致上分别相当于},只是{}有我的首选解算器
update
方法对每一步都调用odeint
。我猜这是你代码中最慢的部分。这个函数应该被调用一次,这样它就可以从头到尾解决你的整个问题,因此它有很多(相对缓慢的)开销,这些开销与理解传递的参数、分配内存、初始化事物有关,一个更好的方法是构造一个^{odeint
和{dop853
。如果您只需要OdeSolver
子类中的一个,那么您可能更喜欢这些接口。还请注意,如果您在步骤之间实际更改了状态中的任何内容,则可能需要再次调用set_initial_value
,否则可能会在您不注意的情况下出错。在更一般地说,如果你担心速度,你能做的最好的事情就是分析你的代码。这里的第一步是在ipython中使用^{} 。在
相关问题 更多 >
编程相关推荐