用于SciPy积分和插值的Numba

2024-09-28 17:21:34 发布

您现在位置:Python中文网/ 问答频道 /正文

我使用了Numba来加速我的代码。它工作得很好,提供了2-3倍的因数改善。然而,在我的代码中花费的主要时间(大约90%)是scipy四次积分和插值(线性和三次样条曲线)。我做了几百次这些积分,所以我认为这是Numba可以提高的。 看起来Numba不支持这些吗?我听说过Numba Scipy,它应该让Numba识别Scipy,但这似乎仍然不起作用。 有没有办法让Numba优化我的积分/插值


Tags: 代码时间线性scipy曲线样条插值花费
1条回答
网友
1楼 · 发布于 2024-09-28 17:21:34

刚刚为cquadpack编写了一个名为NumbaQuadpack的包装器,它应该做您想做的事情:https://github.com/Nicholaswogan/NumbaQuadpack。cquadpack是Quadpack的C版本,这是scipy.integrate.quad使用的

from NumbaQuadpack import quadpack_sig, dqags
import numpy as np
import numba as nb
import timeit

@nb.cfunc(quadpack_sig)
def f(x, data):
    return x**2 + 2 + np.log(x)
funcptr = f.address
a = 0
b = 1
sol, abserr, success = dqags(funcptr, a, b)
print(sol) # definite integral solution

# test speed
@nb.njit()
def timetest_nb():
    sol, abserr, success = dqags(funcptr, a, b)
timetest_nb()
n_time=10000
print(timeit.Timer(timetest_nb).timeit(number=n_time)/n_time) 

在我的计算机上,这个小积分需要4.2µs,而当我使用scipy.integrate.quad做同样的事情时,需要68.1µs

对于插值,只需使用np.interp(1d插值)。它可以在numba jitted函数中使用

通常,任何C/C++或fortran代码都可以用CType包装,并从numba jitted函数中调用

相关问题 更多 >