Python Numba/NumPy中实现的摊余O(1)滚动最小值

2024-10-03 06:18:52 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图实现一个滚动最小值,它有一个摊余的O(1)get_min()。{O)来自于算法^(O)


原始功能:

import pandas as pd
import numpy as np
from numba import njit, prange

def rolling_min_original(data, n):
    return pd.Series(data).rolling(n).min().to_numpy()

我尝试实现摊余的O(1)get_min()算法:(对于非小的n,这个函数有很好的性能)

^{pr2}$

n很小时,这是一个幼稚的实现:

@njit(parallel= True)
def rolling_min_smalln(data, n):
    result= np.empty(len(data), dtype= data.dtype)

    for i in prange(n-1):
        result[i]= np.nan

    for i in prange(n-1, len(data)):
        result[i]= data[i-n+1: i+1].min()

    return result

一些用于测试的小代码

def remove_nan(arr):
    return arr[~np.isnan(arr)]

if __name__ == '__main__':

    np.random.seed(0)
    data_size = 200000
    data = np.random.uniform(0,1000, size = data_size)+29000

    w_size = 37

    r_min_original= rolling_min_original(data, w_size)
    rmin1 = rollin_min(data, w_size)

    r_min_original = remove_nan(r_min_original)
    rmin1 = remove_nan(rmin1)

    print(np.array_equal(r_min_original,rmin1))

函数rollin_min()具有几乎恒定的运行时间,并且在n较大时比rolling_min_original()低,这是很好的。但是当n较低时(在我的电脑中,n < 37左右,在这个范围内{}很容易被一个幼稚的实现rolling_min_smalln()击败)时,它的性能就很差。在

我正在努力寻找改善rollin_min()的方法,但到目前为止我被卡住了,这就是为什么我在这里寻求帮助。在


我的问题如下:

我正在实现的算法是滚动/滑动窗口最小值/最大值的最佳算法吗?在

如果没有,最好/更好的算法是什么?如果是这样,我如何从算法的角度进一步改进函数?在

除了算法本身,还有什么方法可以进一步提高函数rollin_min()的性能?在


编辑:根据多个请求将我的最新答案移至“答案”部分


Tags: 函数import算法datasizedefnpresult
2条回答

根据多个请求将其从问题编辑部分移到此处。在

受Matt Timmermans在答案中给出的更简单实现的启发,我制作了一个cpu多核版本的rolling min。代码如下:

@njit(parallel= True)
def rollin_min2(data, n):
    """
    1) make a loop that iterates over K sections of n elements; each section is independent so that it can benefit from multicores cpu 
    2) for each iteration of sections, generate backward local minimum(sec_min2) and forward minimum(sec_min1)

    say m=8, len(data)= 23, then we only need the idx= (reversed to 7,6,5,...,1,0 (0 means minimum up until idx= 0)),
    1st iter
    result[7]= min_until 0,
    result[8]= min(min(data[7:9]) and min_until 1),
    result[9]= min(min(data[7:10]) and m_til 2)
    ...
    result[14]= min(min(data[7:15]) and m_til 7) 

    2nd iter
    result[15]= min_until 8,
    result[16]= min(min(data[15:17]) and m_til 9),
    result[17]= min(min(data[15:18]) and m_til 10)
    ...
    result[22]= min(min(data[15:23]) and m_til 15) 


    """
    ar_len= len(data)

    sec_min1= np.empty(ar_len, dtype = data.dtype)
    sec_min2= np.empty(ar_len, dtype = data.dtype)

    for i in prange(n-1):
        sec_min1[i]= np.nan

    for sec in prange(ar_len//n):
        s2_min= data[n*sec+ n-1]
        s1_min= data[n*sec+ n]

        for i in range(n-1,-1,-1):
            if data[n*sec+i] < s2_min:
                s2_min= data[n*sec+i]
            sec_min2[n*sec+i]= s2_min

        sec_min1[n*sec+ n-1]= sec_min2[n*sec]

        for i in range(n-1):
            if n*sec+n+i < ar_len:
                if data[n*sec+n+i] < s1_min:
                    s1_min= data[n*sec+n+i]
                sec_min1[n*sec+n+i]= min(s1_min, sec_min2[n*sec+i+1])

            else:
                break

    return sec_min1 

实际上,我花了一个小时测试了rolling min的各种实现,在我的6C/12T笔记本电脑中,这个多核版本在n是“中等尺寸”时效果最好。但是,当n至少是源数据长度的30%时,其他实现就开始显得格外突出。一定有更好的方法来改进这个功能,但是在编辑的时候我还不知道这些方法。在

代码缓慢的主要原因可能是在mingetter_rev中分配了一个新数组。您应该在整个过程中重用相同的存储。在

然后,因为实际上不必实现队列,所以可以进行更多的优化。例如,两个堆栈的大小最多(通常)为n,因此您可以将它们保持在同一个数组中,大小为n。从开始处增加一个,从末尾增加一个。在

你会注意到有一个非常规则的模式-从头到尾依次填充数组,从末尾重新计算最小值,在填充数组时生成输出,重复。。。在

这导致了一个实际上更简单的算法,其解释更简单,根本不涉及堆栈。下面是一个实现,并对其工作方式进行了注释。请注意,我没有费心在开头加上NaNs:

def rollin_min(data, n):

    #allocate the result.  Note the number valid windows is len(data)-(n-1)
    result = np.empty(len(data)-(n-1), data.dtype)

    #every nth position is a "mark"
    #every window therefore contains exactly 1 mark
    #the minimum in the window is the minimum of:
    #  the minimum from the window start to the following mark; and
    #  the minimum from the window end the the preceding (same) mark

    #calculate the minimum from every window start index to the next mark
    for mark in range(n-1, len(data), n):
        v = data[mark]
        if (mark < len(result)):
            result[mark] = v
        for i in range(mark-1, mark-n, -1):
            v = min(data[i],v)
            if (i < len(result)):
                result[i] = v

    #for each window, calculate the running total from the preceding mark
    # to its end.  The first window ends at the first mark
    #then combine it with the first distance to get the window minimum

    nextMarkPos = 0
    for i in range(0,len(result)):
        if i == nextMarkPos:
             v = data[i+n-1]
             nextMarkPos += n
        else:
            v = min(data[i+n-1],v)
        result[i] = min(result[i],v)

    return result

相关问题 更多 >