使用numba加速求解ivp

2024-09-29 23:28:34 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在使用带有“solve_ivp”的线的方法来解非线性PDE

@njit(fastmath=True,error_model="numpy",cache=True)
def thinFilmEq(t,h,dx,Ma,phiFun,tempFun):
    phi = phiFun(h)
    temperature = tempFun(h)
    hxx = (np.roll(h,1) - 2*h + np.roll(h,-1))/dx**2  # use np.roll as I'm implementing periodic BC
    p = phi - hxx
    px = (np.roll(p,-1) - np.roll(p,1))/(2*dx)
    Tx = (np.roll(temperature,-1) - np.roll(temperature,1))/(2*dx)

    flux = h**3*px/3 + Ma*h**2*Tx/2
    dhdt = (np.roll(flux,-1) - np.roll(flux,1))/(2*dx)

    return dhdt

我得到以下错误:TypingError: non-precise type pyobject [1] During: typing of argument at C:/Users/yhcha/method_of_lines/test_01_thinFilmEq.py (28)我怀疑这是由于phiFuntempFun造成的。它们是我在调用时提供的函数。我将函数参数设置为dhdt函数,只是为了让事情更一般。当我试图删除phiFuntempFun并显式地给出thinFilmEq内的函数形式时,错误消失了

然后,我看到下面的错误TypingError: Use of unsupported NumPy function 'numpy.roll' or unsupported use of the function.,我认为可能np.roll不受支持,尽管它包含在官方的website中。在处理周期BC的有限差分时,我试图“放大”数组,以某种方式手动应用与np.roll相同的东西:

    def augment(x):
        x2 = np.empty(len(x)+2)
        x2[1:-1] = x
        x2[0] = x[-1]
        x2[-1] = x[0]
        return x2

    H = augment(x)
    hx = (H[2:]-[H:-2])/dx   # use this instead of hx=(roll(h,-1)-roll(h,1))/dx

我的问题是:

  1. 似乎我可以让numba工作,但代价是使代码不那么通用(不能提供像phiFun这样的任意函数和优雅(例如,不能使用带有np.roll的一行程序)。有没有办法绕过它,或者这只是我使用numba来“编译”代码时需要付出的代价?

  2. 没有numba的原始版本比我编写的Matlab版本慢近10倍,而numba版本仍然比Matlab慢约3-4倍。我真的不希望scipy的性能比Matlab好,但是还有其他方法可以加速代码以弥补差距吗?


Tags: of函数use错误npx2rolltemperature

热门问题