Optimizing a computation in Python with NumPy and Numba

Posted 2024-10-06 12:39:36

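I am trying to use numba and numpy to make a standard-deviation function run faster in Python. The problem is that the for loop is very slow, and I need another approach to speed up the code. I applied numba to my existing numpy version, but performance barely improved. My original list_ has millions of values, so computing the standard deviation takes a long time. The list_ below is a very short numpy array that serves as an example of my problem, because I cannot post the original numbers. The for loop in the function below computes the standard deviation of every n consecutive numbers in list_, where n is defined by the variable number. How can I make the current function run faster?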

import numpy as np
from numba import njit,jit,vectorize

number = 5
list_= np.array([457.334015,424.440002,394.795990,408.903992,398.821014,402.152008,435.790985,423.204987,411.574005,
404.424988,399.519989,377.181000,375.467010,386.944000,383.614990,375.071991,359.511993,328.865997,
320.510010,330.079010,336.187012,352.940002,365.026001,361.562012,362.299011,378.549011,390.414001,
400.869995,394.773010,382.556000])

Plain Python code:

def std_():
    std = np.array([list_[i:i+number].std() for i in range(0, len(list_)-number)])
    print(std)
std_()

Numba code:

jitted_func = njit()(std_)
jitted_func()

Performance results: [image not included]


1 answer
User
#1 · posted 2024-10-06 12:39:36

You can do this in a vectorized way:

def rolling_window(a, window):
    shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
    strides = a.strides + (a.strides[-1],)
    return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)

def std_():
    std = np.array([list_[i:i+number].std() for i in range(0, len(list_)-number)])
    return std

std1 = np.std(rolling_window(list_, 5), axis=1)
print(np.allclose(std1[:-1], std_()))

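This prints True. std1[:-1] is used in the comparison because the loop in std_ stops one window before the end of the array, so it returns one fewer value than the vectorized version. The rolling_window code is taken from another answer.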
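As a possible alternative to the hand-written rolling_window helper, NumPy 1.20 and later provide sliding_window_view, which builds the same kind of windowed view without manual stride arithmetic. The sketch below assumes NumPy >= 1.20; the function name rolling_std and the short sample array are illustrative and not part of the original answer.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def rolling_std(values, window):
    """Population standard deviation (ddof=0) of every full window.

    Hypothetical helper for illustration; returns len(values) - window + 1 values.
    """
    # Read-only view of shape (len(values) - window + 1, window); no data is copied.
    windows = sliding_window_view(values, window)
    return windows.std(axis=-1)


# First few values of the example array from the question.
values = np.array([457.334015, 424.440002, 394.795990,
                   408.903992, 398.821014, 402.152008])
print(rolling_std(values, 5))
```

Unlike the loop in std_, this returns a value for the last full window as well, so it yields one more element for the same input.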

Compared with Numba:

import numpy as np
from numba import njit,jit,vectorize

number = 5
list_= np.random.rand(10000)

def rolling_window(a, window):
    shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
    strides = a.strides + (a.strides[-1],)
    return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)

def std_():
    std = np.array([list_[i:i+number].std() for i in range(0, len(list_)-number)])
    return std

# Vectorized version (line magic)
%timeit np.std(rolling_window(list_, 5), axis=1)

%%timeit
jitted_func = njit()(std_)
jitted_func()

This gives (vectorized version first, then the Numba version):

499 µs ± 3.98 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
106 ms ± 2.87 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
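The Numba timing above wraps the list comprehension unchanged and calls njit inside the timed cell, so the measurement likely includes compilation. A minimal sketch of a different way to use Numba follows: explicit loops, the array passed in as an argument, and a warm-up call before any timing. The name rolling_std_loop and the plain-Python fallback for a missing Numba install are assumptions for illustration; no speedup figure is claimed here.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption for illustration: run as plain Python if Numba is not installed.
    def njit(func):
        return func


@njit
def rolling_std_loop(values, window):
    """Population standard deviation (ddof=0) of every full window."""
    n_out = values.shape[0] - window + 1
    out = np.empty(n_out, dtype=np.float64)
    for i in range(n_out):
        # First pass over the window: mean.
        mean = 0.0
        for j in range(window):
            mean += values[i + j]
        mean /= window
        # Second pass: mean of squared deviations.
        acc = 0.0
        for j in range(window):
            diff = values[i + j] - mean
            acc += diff * diff
        out[i] = np.sqrt(acc / window)
    return out


data = np.random.rand(10_000)
rolling_std_loop(data, 5)            # first call triggers compilation
result = rolling_std_loop(data, 5)   # later calls reuse the compiled code
```

To compare fairly with the vectorized version, time only calls made after the warm-up, for example with %timeit rolling_std_loop(data, 5) in its own cell.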
