我在njit写了一个函数来加速一个非常慢的水库运行优化代码。该函数根据水库水位和闸门可用性返回溢出释放的最大值。我传入了一个参数大小,它指定要计算的流的数量(在某些调用中它是一个,在一些调用中是多个)。我还路过一个数字0数组,然后用函数输出填充。函数的简化版本如下:
import numpy as np
from numba import njit
@njit(cache=True)
def fncMaxFlow(elev, flag, size, MaxQ):
if (flag == 1): # SPOG2 running
if size==0:
if (elev>367.28):
return 861.1
else: return 0
else:
for i in range(size):
if((elev[i]>367.28) & (elev[i]<385)):
MaxQ[i]=861.1
return MaxQ
else:
if size==0: return 0
else: return MaxQ
fncMaxFlow(np.random.randint(368, 380, 3), 1, 3, np.zeros(3))
我得到的错误是:
^{pr2}$这是什么原因?有没有什么解决办法或者我遗漏了一些步骤,这样我就可以使用numba来加快速度?这个函数和其他类似的函数被调用了数百万次,所以它们是计算效率的主要因素。任何建议都会有帮助-我对python很陌生。在
numba函数中的变量必须具有包括返回变量在内的一致类型。在代码中,可以返回
MaxQ
(数组)、861.1(浮点)或0(整数)。在您需要重构此代码,以便无论代码路径如何,它始终返回一致的类型。在
还要注意,在将numpy数组与标量(
elev > 367.28
)进行比较的几个地方,您得到的是一个布尔值数组,这将导致您的问题。由于这个原因,您的示例函数不能作为纯python函数运行(删除numba修饰符)。在相关问题 更多 >
编程相关推荐