NumPy-ndArray中基于布尔的最长数列的更有效求解方法

2024-06-27 02:29:20 发布

您现在位置:Python中文网/ 问答频道 /正文

我搜索我的数组,以找到基于真值的最长序列。有没有一个选项可以在不循环数组的情况下找到最长的序列?你知道吗

我已经写了我自己的解决方案numpy.非零,但可能还有更好的。你知道吗

import numpy as np
arr = np.array([[[1,2,3,4,5],
                [6,7,8,9,10],
                [11,12,13,14,15],
                [16,17,18,19,20],
                [21,22,23,24,25]],
                [[True,True,True,False,True],
                [True,True,True,True,False],
                [True,True,False,True,True],
                [True,True,True,False,True],
                [True,True,True,False,True]]])

def getIndices(arr):
    arr_to_search = np.nonzero(arr)
    arrs = []
    prev_el0 = 0
    prev_el1 = -1
    activ_long = []
    for i in range(len(arr_to_search[0])):
        if arr_to_search[0][i] == prev_el0:
            if arr_to_search[1][i] != prev_el1 + 1:
                arrs.append(activ_long)
                activ_long = []
        else:
            arrs.append(activ_long)
            activ_long = []
        activ_long.append((arr_to_search[0][i],arr_to_search[1][i]))
        prev_el0 = arr_to_search[0][i]
        prev_el1 = arr_to_search[1][i]

    max_len = len(max(arrs,key=len))
    longest_arr_list = [a for a in arrs if len(a) == max_len]
    return longest_arr_list

print(getIndices(arr[1,:,:]))
print(getIndices(arr[1,:,:].T))


[[(1, 0), (1, 1), (1, 2), (1, 3)]]
[[(0, 0), (0, 1), (0, 2), (0, 3), (0, 4)], [(1, 0), (1, 1), (1, 2), (1, 3), (1, 4)]]

Tags: tofalsetruesearchlenifnplong
1条回答
网友
1楼 · 发布于 2024-06-27 02:29:20

下面是一个numpy解决方案,它避免了基于this previous question.的显式循环

我假设布尔数组名为a。本质上,我们找到行从0变为1或从1变为0的索引,并查看它们之间的差异。通过在frong和back上填充0,我们确保从0到1的每个转换都有从1到0的另一个转换。你知道吗

为了方便起见,我同时处理aa.T,但是如果您愿意,可以分别处理它们。你知道吗

m,n = a.shape
A = np.zeros((2*m,n+2))
A[:m,1:-1] = a
A[m:,1:-1] = a.T

dA = np.diff(A)

start = np.where(dA>0)
end = np.where(dA<0)

argmax_run = np.argmax(end[1]-start[1])

row = start[0][argmax_run]
col_start = start[1][argmax_run]
col_end= end[1][argmax_run]-1

max_len = col_end - col_start + 1

print('max run of length {}'.format(max_len))
print('in '+('row' if row<m else'col')+' {}'.format(row%m)+' from '+('col' if row<m else'row')+' {} to {}'.format(col_start,col_end))

为了提高性能和存储,我们可以将A更改为布尔数组。由于上述dA中的-11总是成对出现,我们可以发现startend如下。你知道吗

nz = np.nonzero(dA)
start = (nz[0][::2], nz[1][::2])
end = (nz[0][1::2], nz[1][1::2])

请注意,然后可以完全删除变量startend,因为它们不是真正需要的。你知道吗

相关问题 更多 >