沿排序二维numpy数组轴寻找第一个非零值

2024-09-27 07:28:52 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图找到最快的方法来找到二维排序数组中每一行的第一个非零值。从技术上讲,数组中唯一的值是0和1,它是“排序的”。

例如,数组可以如下所示:

v=

0 0 0 1 1 1 1 
0 0 0 1 1 1 1 
0 0 0 0 1 1 1 
0 0 0 0 0 0 1 
0 0 0 0 0 0 1 
0 0 0 0 0 0 1 
0 0 0 0 0 0 0

我可以使用argmax函数

argmax(v, axis=1))

当它从0变为1时查找,但我相信这将对每一行进行详尽的搜索。我的阵列大小是合理的(~2000x2000)。对于for循环中的每一行,argmax的性能是否仍然优于仅对其执行searchsorted方法,或者是否有更好的替代方法?

此外,数组始终是这样的:行的第一个位置总是>;=其上一行的第一个位置(但不能保证最后几行中会有一个)。我可以用for循环和“起始索引值”来利用它,每一行的起始索引值等于前一行的第一个1的位置,但我是否正确地认为numpy argmax函数仍将优于用python编写的循环。

我只需要对备选方案进行基准测试,但是阵列的边长可能会有很大的变化(从250到10000)。


Tags: 方法函数gtnumpy利用for排序方案
2条回答

使用np.where相当快:

>>> a
array([[0, 0, 0, 1, 1, 1, 1],
       [0, 0, 0, 1, 1, 1, 1],
       [0, 0, 0, 0, 1, 1, 1],
       [0, 0, 0, 0, 0, 0, 1],
       [0, 0, 0, 0, 0, 0, 1],
       [0, 0, 0, 0, 0, 0, 1],
       [0, 0, 0, 0, 0, 0, 0]])
>>> np.where(a>0)
(array([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 3, 4, 5]), array([3, 4, 5, 6, 3, 4, 5, 6, 4, 5, 6, 6, 6, 6]))

传递值的to坐标大于0的元组。

也可以使用np.where测试每个子数组:

def first_true1(a):
    """ return a dict of row: index with value in row > 0 """
    di={}
    for i in range(len(a)):
        idx=np.where(a[i]>0)
        try:
            di[i]=idx[0][0]
        except IndexError:
            di[i]=None    

    return di       

印刷品:

{0: 3, 1: 3, 2: 4, 3: 6, 4: 6, 5: 6, 6: None}

即,第0行:索引3>;0;第4行:索引4>;0;第6行:没有索引大于0

正如您所怀疑的,argmax可能更快:

def first_true2():
    di={}
    for i in range(len(a)):
        idx=np.argmax(a[i])
        if idx>0:
            di[i]=idx
        else:
            di[i]=None    

    return di       
    # same dict is returned...

如果您可以处理不为所有零行设置None的逻辑,则速度更快:

def first_true3():
    di={}
    for i, j in zip(*np.where(a>0)):
        if i in di:
            continue
        else:
            di[i]=j

    return di      

下面是一个在argmax中使用axis的版本(如您在评论中建议的那样):

def first_true4():
    di={}
    for i, ele in enumerate(np.argmax(a,axis=1)):
        if ele==0 and a[i][0]==0:
            di[i]=None
        else:
            di[i]=ele

    return di          

对于速度比较(在示例数组中),我得到:

            rate/sec usec/pass first_true1 first_true2 first_true3 first_true4
first_true1   23,818    41.986          --      -34.5%      -63.1%      -70.0%
first_true2   36,377    27.490       52.7%          --      -43.6%      -54.1%
first_true3   64,528    15.497      170.9%       77.4%          --      -18.6%
first_true4   79,287    12.612      232.9%      118.0%       22.9%          --

如果我把它扩展到一个2000 X 2000 np阵列,我得到的是:

            rate/sec  usec/pass first_true3 first_true1 first_true2 first_true4
first_true3        3 354380.107          --       -0.3%      -74.7%      -87.8%
first_true1        3 353327.084        0.3%          --      -74.6%      -87.7%
first_true2       11  89754.200      294.8%      293.7%          --      -51.7%
first_true4       23  43306.494      718.3%      715.9%      107.3%          --

argmax()使用C级循环,比Python循环快得多,所以我想即使你用Python编写了一个智能算法,也很难打败argmax(),你可以使用Cython来加速:

@cython.boundscheck(False)
@cython.wraparound(False) 
def find(int[:,:] a):
    cdef int h = a.shape[0]
    cdef int w = a.shape[1]
    cdef int i, j
    cdef int idx = 0
    cdef list r = []
    for i in range(h):
        for j in range(idx, w):
            if a[i, j] == 1:
                idx = j
                r.append(idx)
                break
        else:
            r.append(-1)
    return r

在我的2000x2000矩阵电脑上,是100us对3ms

相关问题 更多 >

    热门问题