我正在尝试优化我的python代码。当 我试图根据每个元素的值对numpy数组应用一个函数。例如,我有一个包含数千个元素的数组,我对大于公差的值应用一个函数,对其余的值应用另一个函数(泰勒级数)。我做了掩蔽,但仍然很慢,至少我调用了以下函数6400万次。在
EPSILONZETA = 1.0e-6
ZETA1_12 = 1.0/12.0
ZETA1_720 = 1.0/720.0
def masked_condition_zero(array, tolerance):
""" Return the indices where values are lesser (and greater) than tolerance
"""
# search indices where array values < tolerance
indzeros_ = np.where(np.abs(array) < tolerance)[0]
# create mask
mask_ = np.ones(np.shape(array), dtype=bool)
mask_[[indzeros_]] = False
return (~mask_, mask_)
def bernoulli_function1(zeta):
""" Returns the Bernoulli function of zeta, vector version
"""
# get the indices according to condition
zeros_, others_ = masked_condition_zero(zeta, EPSILONZETA)
# create an array filled with zeros
fb_ = np.zeros(np.shape(zeta))
# Apply the original function to the values greater than EPSILONZETA
fb_[others_] = zeta[others_]/(np.exp(zeta[others_])-1.0)
# computes series for zeta < eps
zeta0_ = zeta[zeros_]
zeta2_ = zeta0_ * zeta0_
zeta4_ = zeta2_ * zeta2_
fb_[zeros_] = 1.0 - 0.5*zeta0_ + ZETA1_12 * zeta2_ - ZETA1_720 * zeta4_
return fb_
现在假设你有一个数组zeta,它有负的和正的浮点值,它在每个循环中都会发生变化,并且每次都要计算fbernoulli函数1(zeta)。在
有更好的解决办法吗?在
问题的基本结构是:
看起来你的多项式表达式根本可以求值
zeta
,但它是一个“例外”,即当zeta
太接近0时的回退计算。在如果这两个函数都可以对
^{pr2}$zeta
求值,则可以在以下位置使用:这是精简版:
另一个选择是将一个函数应用于所有值,另一个仅应用于“异常”。在
当然还有反面-
result = func1(zeta); result[nI]=func2[zeta]
。在在我简短的时间测试中,
func1
,func2
所用的时间差不多相同。在masked_condition_zero
也需要这段时间,但是更简单的np.abs(array) < tolerance
(它是~J
)可以将这个时间减半。在让我们比较一下分配策略
对于
zeta[J]
是完整zeta
的10%的示例,某些采样时间为:第二种情况最快,因为在较少的值上运行
fun1
可以补偿索引zeta[J]
所增加的成本。在索引成本和功能评估成本之间有一个权衡。像这样的布尔索引比切片更昂贵。对于其他混合值,计时可能会朝另一个方向发展。在这看起来是一个问题,你可以削减时间,但我看不到任何突破的战略,将削减时间一个数量级。在
您可以使用numba[1](如果您使用anaconda或类似的python发行版,则应安装numba[1]),这是一个旨在使用numpy的jit编译器。在
注意:如果您使用新版本的numba,您可以将这两者合并到同一个函数中。在
在我的机器上:
^{pr2}$速度快4倍,易读。在
where命令在索引到一个数组中比较慢。这可能更快。在
编辑: 我意识到我的原始版本是复制同一个数组的两个副本。这个新版本应该快一点。在
相关问题 更多 >
编程相关推荐