scipy stats zmap函数的替代方法

2024-10-04 01:37:34 发布

您现在位置:Python中文网/ 问答频道 /正文

zmap函数的scipy stats模块是否有其他替代方案?我目前正在使用它来获得两个非常大的阵列的zmap分数,这需要相当长的时间

是否有任何库或替代方案可以提高其性能?或者甚至是另一种获得zmap函数功能的方法

您的想法和意见将不胜感激

下面是我的最小可复制代码:

from scipy import stats
import numpy as np

FeatureData = np.random.rand(483, 1)
goodData = np.random.rand(4640, 483)
FeatureNorm= stats.zmap(FeatureData, goodData)

下面是scipy stats.zmap在引擎盖下的功能:

def zmap(scores, compare, axis=0, ddof=0):
    scores, compare = map(np.asanyarray, [scores, compare])
    mns = compare.mean(axis=axis, keepdims=True)
    sstd = compare.std(axis=axis, ddof=ddof, keepdims=True)
    return (scores - mns) / sstd

关于如何针对我的用例优化它,有什么想法吗?我可以使用像numba或JAX这样的库来进一步增强这一点吗


Tags: 函数import功能statsnp方案randomscipy
1条回答
网友
1楼 · 发布于 2024-10-04 01:37:34

幸运的是,zmap代码非常简单。然而,numpy的开销将来自它必须实例化中间数组这一事实。如果使用numbajax中提供的数值编译器,它可以融合这些操作并以较少的开销进行计算

不幸的是,NUBA不支持对{{CD4}}和^ {CD5>}的可选参数,所以让我们看看JAX。以下是在Google Colab CPU运行时计算的scipy和原始numpy版本函数的基准,供参考:

import numpy as np
from scipy import stats

FeatureData = np.random.rand(483, 1)
goodData = np.random.rand(4640, 483)

%timeit stats.zmap(FeatureData, goodData)
# 100 loops, best of 3: 13.9 ms per loop

def np_zmap(scores, compare, axis=0, ddof=0):
    scores, compare = map(np.asanyarray, [scores, compare])
    mns = compare.mean(axis=axis, keepdims=True)
    sstd = compare.std(axis=axis, ddof=ddof, keepdims=True)
    return (scores - mns) / sstd

%timeit np_zmap(FeatureData, goodData)
# 100 loops, best of 3: 13.8 ms per loop

以下是在JAX中执行的等效代码,包括急切模式和JIT编译:

import jax.numpy as jnp
from jax import jit

def jnp_zmap(scores, compare, axis=0, ddof=0):
    scores, compare = map(jnp.asarray, [scores, compare])
    mns = compare.mean(axis=axis, keepdims=True)
    sstd = compare.std(axis=axis, ddof=ddof, keepdims=True)
    return (scores - mns) / sstd

jit_jnp_zmap = jit(jnp_zmap)

FeatureData = jnp.array(FeatureData)
goodData = jnp.array(goodData)
%timeit jnp_zmap(FeatureData, goodData).block_until_ready()
# 100 loops, best of 3: 8.59 ms per loop

jit_jnp_zmap(FeatureData, goodData)  # trigger compilation
%timeit jit_jnp_zmap(FeatureData, goodData).block_until_ready()
# 100 loops, best of 3: 2.78 ms per loop

JIT编译版本大约比scipy或numpy代码快5倍。在Colab T4 GPU运行时上,编译版本获得另一个因子10:

%timeit jit_jnp_zmap(FeatureData, goodData).block_until_ready()
1000 loops, best of 3: 286 µs per loop

如果这种操作是分析中的瓶颈,那么像JAX这样的编译器可能是一个不错的选择

相关问题 更多 >