找到第一个位置最有效的方法是什么np.nan公司价值?

2024-09-30 03:24:49 发布

您现在位置:Python中文网/ 问答频道 /正文

考虑数组a

a = np.array([3, 3, np.nan, 3, 3, np.nan])

我可以的

^{pr2}$

但这需要找到所有np.nan,只为找到第一个。
有没有更有效的方法?在


我一直在试图弄清楚是否可以将一个参数传递给np.argpartition,这样np.nanget的排序是第一个,而不是最后一个。在


编辑有关[dup]的内容。
这个问题不同有几个原因。在

  1. 这个问题和答案涉及到价值观的平等。这是关于isnan。在
  2. 这些答案都与我的答案所面临的问题相同。注意,我提供了一个完全有效的答案,但强调了它的低效性。我想解决效率低下的问题。在

编辑有关第二个[dup]的内容。在

仍然强调平等和问题/答案是过时的,而且很可能已经过时。在


Tags: 方法答案编辑内容排序np原因数组
3条回答

我会提名的

a.argmax()

使用@fuglede's测试数组:

^{pr2}$

我没有安装numba,所以可以比较一下。但是我相对于short的加速比大于@fuglede's6x

我在Py3中测试,它接受<np.nan,而Py2会发出运行时警告。但是代码搜索表明这并不依赖于这种比较。在

{{{cd6}的操作取决于{cd6}的末尾。在

numpy/core/src/multiarray/arraytypes.c.src中,它看起来像是BOOL_argmax短路,一旦遇到True就会立即返回。在

for (; i < n; i++) {
    if (ip[i]) {
        *max_ind = i;
        return 0;
    }
}

并且@fname@_argmax也在最大nan上短路。np.nan也是argmin中的“最大值”。在

#if @isfloat@
    if (@isnan@(mp)) {
        /* nan encountered; it's maximal */
        return 0;
    }
#endif

欢迎来自经验丰富的c程序员的评论,但在我看来,至少对于np.nan,一个简单的argmax将与我们所能得到的一样快。在

在生成a时使用9999表明a.argmax时间依赖于该值,与短路一致。在

下面是一个使用itertools.takewhile()的python方法:

from itertools import takewhile
sum(1 for _ in takewhile(np.isfinite, a))

在{}方法中使用生成器表达式进行基准测试:1

^{pr2}$

但仍然(到目前为止)慢于numpy方法:

In [119]: %timeit np.isnan(a).argmax()
100000 loops, best of 3: 16.8 µs per loop

1这种方法的问题是使用enumerate函数。它首先从numpy数组返回一个enumerate对象(它是一个类似迭代器的对象),调用生成器函数和迭代器的next属性需要时间。

研究numba.jit;如果没有它,矢量化版本在大多数情况下可能会击败直接的纯Python搜索,但是在编译代码之后,普通搜索将占据主导地位,至少在我的测试中是这样:

In [63]: a = np.array([np.nan if i % 10000 == 9999 else 3 for i in range(100000)])

In [70]: %paste
import numba

def naive(a):
        for i in range(len(a)):
                if np.isnan(a[i]):
                        return i

def short(a):
        return np.isnan(a).argmax()

@numba.jit
def naive_jit(a):
        for i in range(len(a)):
                if np.isnan(a[i]):
                        return i

@numba.jit
def short_jit(a):
        return np.isnan(a).argmax()
## -- End pasted text --

In [71]: %timeit naive(a)
100 loops, best of 3: 7.22 ms per loop

In [72]: %timeit short(a)
The slowest run took 4.59 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 37.7 µs per loop

In [73]: %timeit naive_jit(a)
The slowest run took 6821.16 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 6.79 µs per loop

In [74]: %timeit short_jit(a)
The slowest run took 395.51 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 144 µs per loop

编辑:正如@hpaulj在他们的回答中指出的,numpy实际上提供了一个优化的短路搜索,其性能与上面的JITted搜索相当:

^{pr2}$

相关问题 更多 >

    热门问题