编译函数的njit nopython版本由于数据类型而无法成功

2024-10-06 08:15:30 发布

您现在位置:Python中文网/ 问答频道 /正文

我在njit写了一个函数来加速一个非常慢的水库运行优化代码。该函数根据水库水位和闸门可用性返回溢出释放的最大值。我传入了一个参数大小,它指定要计算的流的数量(在某些调用中它是一个,在一些调用中是多个)。我还路过一个数字0数组,然后用函数输出填充。函数的简化版本如下:

import numpy as np
from numba import njit

@njit(cache=True)
def fncMaxFlow(elev, flag, size, MaxQ):
    if (flag == 1): # SPOG2 running
        if size==0:
            if (elev>367.28):
                return 861.1 
            else: return 0
        else:
            for i in range(size):
                if((elev[i]>367.28) & (elev[i]<385)):
                    MaxQ[i]=861.1
            return MaxQ
    else:
        if size==0: return 0
        else: return MaxQ

fncMaxFlow(np.random.randint(368, 380, 3), 1, 3, np.zeros(3))

我得到的错误是:

^{pr2}$

这是什么原因?有没有什么解决办法或者我遗漏了一些步骤,这样我就可以使用numba来加快速度?这个函数和其他类似的函数被调用了数百万次,所以它们是计算效率的主要因素。任何建议都会有帮助-我对python很陌生。在


Tags: 函数代码importsizereturnifnpelse
1条回答
网友
1楼 · 发布于 2024-10-06 08:15:30

numba函数中的变量必须具有包括返回变量在内的一致类型。在代码中,可以返回MaxQ(数组)、861.1(浮点)或0(整数)。在

您需要重构此代码,以便无论代码路径如何,它始终返回一致的类型。在

还要注意,在将numpy数组与标量(elev > 367.28)进行比较的几个地方,您得到的是一个布尔值数组,这将导致您的问题。由于这个原因,您的示例函数不能作为纯python函数运行(删除numba修饰符)。在

相关问题 更多 >