一种更有效的在数组中剪裁值的方法?

2024-05-27 11:18:16 发布

您现在位置:Python中文网/ 问答频道 /正文

我有以下代码:

import numpy as np

MIN = -3
MAX = 3

def bound(array):
    return np.array([(0 if v == 0 else (MIN if v < 0 else 1)) for v in array]),\
           np.array([(0 if v == 0 else (-1 if v < 0 else MAX)) for v in array])

print(bound(np.array(range(MIN, MAX + 1))))

返回:

(array([-3, -3, -3,  0,  1,  1,  1]), array([-1, -1, -1,  0,  3,  3,  3]))

我的实际数组比这个大得多,但它由从最小到最大的整数组成(在本例中为-3到3)

最小值和最大值不被视为与0对称,但应保持0的值

有没有一种更高效/更快(cpu时间)的方法可以做到这一点? 非常感谢时间比较

谢谢


Tags: 代码inimportnumpyforifdefas
2条回答

使用np.where

def bound_where(array):
  return np.where(array==0, 0, np.where(array<0, MIN, 1)), \
           np.where(array==0, 0, np.where(array<0, -1, MAX))

与其他方法的比较

测试

方法

  1. 张贴方法
  2. 选择方法(来自@QuangHoang answer的边界2)
  3. Where方法(当前答案)

测试代码

arr = np.random.randint(-10,10,1000)
count = 1000
print(f'Posted method: {timeit(lambda:bound(arr), number=count):.4f}')
print(f'Select method: {timeit(lambda:bound2(arr), number=count):.4f}')
print(f'Where mehtod: {timeit(lambda:bound_where(arr), number=count):.4f}')

结果(秒)

Posted method: 6.1951
Select method: 0.3959
Where method: 0.1466

方法最快的地方在哪里

我会使用np.select

def bound2(arr):
    pos_neg = arr>0, arr<0
    return (
        np.select(pos_neg, (1,MIN),0),
        np.select(pos_neg, (MAX,-1),0)
    )

测试时间:

# sample data
arr = np.random.randint(-10,10,1000)

%%timeit -n 100
bound(arr)
# 858 µs ± 28.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%%timeit -n 100
bound2(arr)
# 59.9 µs ± 4.45 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

相关问题 更多 >

    热门问题