在某些条件下,如何有效地使用numpy遍历数组来查找模式?

2024-10-04 01:30:33 发布

您现在位置:Python中文网/ 问答频道 /正文

假设我有一个简单的一维数组,比如:

[0,0,1,0,1,0,0,0,0,1,1,0,0,0,0,1,0,0,0,1,0,0,1,1,1,1,0,1,0,1,0,1,0,0,0,0,1,0,1,1,1,0,1,1,0,0,0]

我想知道使用numpy查找某个模式结束处的索引的最有效或最快的方法。我想找到的模式由两部分组成。第一部分包括所有内容,直到找到至少第n个连续的1(假设该阈值为3)。在那之后,这个模式应该被认为是在另一个连续的0被发现之后完成的。发生错误时,模式结束处的索引应存储在数组中以供以后使用。你知道吗

我可能没有很好地描述它,所以这里有一些代码可以实现我对上面相同数组的期望。你知道吗

import numpy as np

arr = np.array([0,0,1,0,1,0,0,0,0,1,1,1,0,0,0,0,0,1,0,0,0,1,0,0,1,1,1,1,0,1,0,1,0,0,0,0,0,1,0,1,1,0,1,1,0,1,1,0,0,0])

patternFound = False
threshold = 3
nonzerosCount = 0
zerosCount = 0
split_indexes=[]

for i in range(len(arr)):
    if patternFound:
        if arr[i] <= 0:
            zerosCount += 1
        else:
            zerosCount = 0

        if zerosCount >= threshold and i+1 != len(arr):
            zerosCount = 0
            patternFound=False
            split_indexes.append(i+1)
    else:
        if arr[i] >= 1:
            nonzerosCount += 1
        else:
            nonzerosCount = 0

        if nonzerosCount >= threshold:
            nonzerosCount = 0
            patternFound = True

print "Indexes:",  split_indexes            
print "Split:", 
for arr in np.split(arr, split_indexes):
    print arr,',',

结果是:

索引:[15,35] 拆分:[0 0 1 0 1 0 0 0 0 0 1 1 1 0 0 0 0],[0 0 1 0 0 0 1 0 1 1 1 1 0 1 0 0 0],[0 0 1 0 1 0 1 0 1 0 0 0]

这对于像我示例中的小数组很好。但是,我想知道使用numpy实现这一点的更有效的方法。例如,如果我只想对一个更大的数组求和

arr = np.random.uniform(size=1000000)

我只是重复了一遍:

total = 0
for i in arr:
    total += i

它比:

np.sum(arr)

Tags: innumpyforthresholdifnp模式数组
2条回答

不确定这是否有助于提高速度,但您可以尝试使用:

np.logical_or(np.logical_or(arr[:-2], arr[1:-1]), arr[2:])

检测3个连续的0(查找False

以及

np.logical_and(np.logical_and(arr[:-2], arr[1:-1]), arr[2:])

检测3个连续的1(寻找True

可以使用Pythran自动将代码转换为本机高效版本(显式迭代NumPy数组元素是一个性能瓶颈)。你知道吗

比如:

#pythran export pattern(bool [])
import numpy as np
def pattern(arr):
    patternFound = False
    threshold = 3
    nonzerosCount = 0
    zerosCount = 0
    split_indexes=[]

    for i in range(len(arr)):
        if patternFound:
            if arr[i] <= 0:
                zerosCount += 1
            else:
                zerosCount = 0

            if zerosCount >= threshold and i+1 != len(arr):
                zerosCount = 0
                patternFound=False
                split_indexes.append(i+1)
        else:
            if arr[i] >= 1:
                nonzerosCount += 1
            else:
                nonzerosCount = 0

            if nonzerosCount >= threshold:
                nonzerosCount = 0
                patternFound = True
    split_indexes = np.asarray(split_indexes)
    return split_indexes, np.split(arr, split_indexes)

pythran pattern.py编译。 很好用。你知道吗

没有Pythran:

% python -m timeit -s 'import pattern, numpy; arr = numpy.asarray(numpy.random.choice([0, 1], size=1000000), dtype=bool)' 'pattern.pattern(arr)'
10 loops, best of 3: 3.11 sec per loop

与Pytran:

% python -m timeit -s 'import pattern, numpy; arr = numpy.asarray(numpy.random.choice([0, 1], size=100000), dtype=bool)' 'pattern.pattern(arr)'
1000 loops, best of 3: 880 usec per loop

相关问题 更多 >