加速Python中的集成函数

2024-10-02 20:36:12 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个函数,它是某个更大问题的内环。所以它将被称为百万时间。我试着优化它。但由于这是我的第一个数值项目,我想知道是否有其他方法可以提高速度。在

cython似乎帮不上忙。也许numpy已经接近c了。 或者我不能高效地编写cython代码。在

import numpy as np
import math
import numexpr as ne


par_mu_rho = 0.8
par_alpha_rho = 0.7
# ' the first two are mean of mus and the '
# ' last two are the mean of alphas.'
cov_epsilon = [[1, par_mu_rho], [par_mu_rho, 1]]
cov_nu = [[1, par_alpha_rho], [par_alpha_rho, 1]]
nrows = 10000 
np.random.seed(123)
epsilon_sim = np.random.multivariate_normal([0, 0], cov_epsilon, nrows)
nu_sim = np.random.multivariate_normal([0, 0], cov_nu, nrows)
errors = np.concatenate((epsilon_sim, nu_sim), axis=1)
errors = np.exp(errors)


### the function to be optimized

def mktout(mean_mu_alpha, errors, par_gamma):
    mu10 = errors[:, 0] * math.exp(mean_mu_alpha[0])
    mu11 = math.exp(par_gamma) * mu10  # mu with gamma
    mu20 = errors[:, 1] * math.exp(mean_mu_alpha[1])
    mu21 = math.exp(par_gamma) * mu20
    alpha1 = errors[:, 2] * math.exp(mean_mu_alpha[2])
    alpha2 = errors[:, 3] * math.exp(mean_mu_alpha[3])

    j_is_larger = (mu10 > mu20)
    #     useneither1 = (mu10 < 1/168)
    threshold2 = (1 + mu10 * alpha1) / (168 + alpha1)
    #     useboth1 = (mu21 >= threshold2)
    j_is_smaller = ~j_is_larger
    #     useneither2 = (mu20 < 1/168)
    threshold3 = (1 + mu20 * alpha2) / (168 + alpha2)
    #     useboth2 = (mu11 >= threshold3)
    case1 = j_is_larger * (mu10 < 1 / 168)
    case2 = j_is_larger * (mu21 >= threshold2)
    #     case3 = j_is_larger * (1 - (useneither1 | useboth1))
    case3 = j_is_larger ^ (case1 | case2)
    case4 = j_is_smaller * (mu20 < 1 / 168)
    case5 = j_is_smaller * (mu11 >= threshold3)
    #     case6 = j_is_smaller * (1 - (useneither2 | useboth2))
    case6 = j_is_smaller ^ (case4 | case5)
    t0 = ne.evaluate(
        "case1*168+case2 * (168 + alpha1 + alpha2) / (1 + mu11 * alpha1 + mu21 * alpha2) +case3 / threshold2 +case4 * 168 +case5 * (168 + alpha1 + alpha2) / (1 + mu11 * alpha1 + mu21 * alpha2) + case6 / threshold3")
    # for some cases, t1 would be 0 anyway, so they are omitted here.
    t1 = ne.evaluate(
        "case2 * (t0 * alpha1 * mu11 - alpha1) +case3 * (t0 * alpha1 * mu10 - alpha1) +case5 * (t0 * alpha1 * mu11 - alpha1)")
    # t2 = (j_is_larger*useboth1*(t0*alpha2*mu21- alpha2) +
    #       j_is_smaller*useboth2*(t0*alpha2*mu21- alpha2) +
    #       j_is_smaller*(1- (useneither2|useboth2))*(t0*alpha2*mu20 - alpha2)
    #       )
    t2 = 168 - t0 - t1
    p12 = case2 + case5
    p1 = case3 + p12
    p2 = case6 + p12
    return t1.sum()/10000, t2.sum()/10000, p1.sum()/10000, p2.sum()/10000

timeit mktout([-6,-6,-1,-1], errors, -0.7)

在我的旧mac电脑上用2.2GHz i7。该函数的运行速度约为200微秒

更新:

根据@coderegram和@GZ0的建议和代码,我决定使用以下代码

^{pr2}$

我的原始代码运行速度为650微秒。 mktoutmktout_if由代码外科医生在大约220µs和120µs下运行。 上述mktout_full的运行速度约为68µs。 我在mktout_full中所做的是优化mktout_if中的if-else逻辑。 也许令人惊讶的是,代码外科医生将out_loop与{}中的if-else逻辑相结合的并行化要慢得多(121ms)。在


Tags: alphaismathmeanerrorsparalpha2mu
1条回答
网友
1楼 · 发布于 2024-10-02 20:36:12

简单地看一下代码并尝试对其进行同步化,只需将ndarray类型添加到所有参数和变量中并不会对性能产生有意义的更改。如果您在这个紧凑的内部循环中为这个函数节省微秒,我将考虑进行以下修改:

  1. 这段代码之所以如此难以cythonize是因为您的代码是矢量化的。所有操作都要经过numpynumexpr虽然这些操作中的每一个都是高效的,但它们都会增加一些python开销(如果您查看cython可以生成的带注释的.html文件,可以看到这一点)。在
  2. 如果您多次调用此函数(根据您的注释显示),可以将mktout改为cdef函数来节省一些时间。Python函数调用有很大的开销。
  3. 次要的,但是您可以尝试避免python的math模块中的任何函数。您可以用from libc cimport math as cmath替换它,并使用cmath.exp。在
  4. 我看到您的mktout函数接受一个python列表mean_mu_alpha。您可以考虑使用一个cdef class对象来替换这个参数,然后键入这个。如果您选择将mktout改为cdef函数,那么它可以变成一个结构或double *数组。无论哪种方式,索引到python列表(可以包含任意python对象,这些对象需要解包到相应的c类型中)的速度都会很慢。在
  5. 这可能是最重要的部分。对于对mktout的每次调用,您都在为许多数组分配内存(对于每个mualphathresholdcaset-和{}数组)。然后,在函数末尾(通过python的gc)释放所有这些内存,只可能在下一次调用时再次使用所有这些空间。如果您可以更改mktout的签名,则可以将所有这些数组作为参数传入,以便内存可以在函数调用之间重用和覆盖。另一个更适合这种情况的方法是遍历数组,一次只计算一个元素。
  6. 您可以使用cython的prange函数对代码进行多线程处理。在您完成以上所有更改之后,我将进行此操作,并在mktout函数本身之外执行多线程处理。也就是说,将多线程调用mktout而不是多线程mktout本身。在

进行上述更改需要大量的工作,而且您可能需要自己重新实现numpy和numexpr提供的许多函数,以避免与每次修改相关的python开销。如果有不清楚的地方请告诉我。在


更新#1:实现点#1、#3和#5,我得到了一个11倍的加速。下面是这个代码的样子。我确信,如果您放弃def函数、list mean_mu_alpha输入和tuple输出,它会更快。注意:与原始代码相比,我得到的最后一个小数位的结果略有不同,可能是因为我不理解某些浮点规则。

from libc cimport math as cmath
from libc.stdint cimport *
from libc.stdlib cimport *

def mktout(list mean_mu_alpha, double[:, ::1] errors, double par_gamma):
    cdef:
        size_t i, n
        double[4] exp
        double exp_par_gamma
        double mu10, mu11, mu20, mu21
        double alpha1, alpha2
        bint j_is_larger, j_is_smaller
        double threshold2, threshold3
        bint case1, case2, case3, case4, case5, case6
        double t0, t1, t2
        double p12, p1, p2
        double t1_sum, t2_sum, p1_sum, p2_sum
        double c

    #compute the exp outside of the loop
    n = errors.shape[0]
    exp[0] = cmath.exp(<double>mean_mu_alpha[0])
    exp[1] = cmath.exp(<double>mean_mu_alpha[1])
    exp[2] = cmath.exp(<double>mean_mu_alpha[2])
    exp[3] = cmath.exp(<double>mean_mu_alpha[3])
    exp_par_gamma = cmath.exp(par_gamma)
    c = 168.0

    t1_sum = 0.0
    t2_sum = 0.0
    p1_sum = 0.0
    p2_sum = 0.0

    for i in range(n):
        mu10 = errors[i, 0] * exp[0]
        mu11 = exp_par_gamma * mu10
        mu20 = errors[i, 1] * exp[1]
        mu21 = exp_par_gamma * mu20
        alpha1 = errors[i, 2] * exp[2]
        alpha2 = errors[i, 3] * exp[3]

        j_is_larger = mu10 > mu20
        j_is_smaller = not j_is_larger
        threshold2 = (1 + mu10 * alpha1) / (c + alpha1)
        threshold3 = (1 + mu20 * alpha2) / (c + alpha2)

        case1 = j_is_larger * (mu10 < 1 / c)
        case2 = j_is_larger * (mu21 >= threshold2)
        case3 = j_is_larger ^ (case1 | case2)
        case4 = j_is_smaller * (mu20 < 1 / c)
        case5 = j_is_smaller * (mu11 >= threshold3)
        case6 = j_is_smaller ^ (case4 | case5)

        t0 = case1*c+case2 * (c + alpha1 + alpha2) / (1 + mu11 * alpha1 + mu21 * alpha2) +case3 / threshold2 +case4 * c +case5 * (c + alpha1 + alpha2) / (1 + mu11 * alpha1 + mu21 * alpha2) + case6 / threshold3
        t1 = case2 * (t0 * alpha1 * mu11 - alpha1) +case3 * (t0 * alpha1 * mu10 - alpha1) +case5 * (t0 * alpha1 * mu11 - alpha1)
        t2 = c - t0 - t1

        p12 = case2 + case5
        p1 = case3 + p12
        p2 = case6 + p12

        t1_sum += t1
        t2_sum += t2
        p1_sum += p1
        p2_sum += p2

    return t1_sum/n, t2_sum/n, p1_sum/n, p2_sum/n

更新2:实现了cdef(#2)、python对象消除(#4)和多线程(#6)思想。#单独使用2和4的好处微乎其微,但对6是必要的,因为GIL不能在OpenMP prange循环中访问。有了多线程,我的四核笔记本电脑的速度提高了2.5倍,相当于代码比原来快了27.5倍。虽然我的outer_loop函数并不完全准确,因为它只是反复地重新计算相同的结果,但是对于一个测试用例来说,它应该足够了。完整代码如下:

^{pr2}$

我使用的setup.py文件如下(包含所有优化和OpenMP标志):

from distutils.core import setup
from Cython.Build import cythonize
from distutils.core import Extension
import numpy as np
import os
import shutil
import platform

libraries = {
    "Linux": [],
    "Windows": [],
}
language = "c"
args = ["-w", "-std=c11", "-O3", "-ffast-math", "-march=native", "-fopenmp"]
link_args = ["-std=c11", "-fopenmp"]

annotate = True
directives = {
    "binding": True,
    "boundscheck": False,
    "wraparound": False,
    "initializedcheck": False,
    "cdivision": True,
    "nonecheck": False,
    "language_level": "3",
    #"c_string_type": "unicode",
    #"c_string_encoding": "utf-8",
}

if __name__ == "__main__":
    system = platform.system()
    libs = libraries[system]
    extensions = []
    ext_modules = []

    #create extensions
    for path, dirs, file_names in os.walk("."):
        for file_name in file_names:
            if file_name.endswith("pyx"):
                ext_path = "{0}/{1}".format(path, file_name)
                ext_name = ext_path \
                    .replace("./", "") \
                    .replace("/", ".") \
                    .replace(".pyx", "")
                ext = Extension(
                    name=ext_name, 
                    sources=[ext_path], 
                    libraries=libs,
                    language=language,
                    extra_compile_args=args,
                    extra_link_args=link_args,
                    include_dirs = [np.get_include()],
                )
                extensions.append(ext)

    #setup all extensions
    ext_modules = cythonize(
        extensions, 
        annotate=annotate, 
        compiler_directives=directives,
    )
    setup(ext_modules=ext_modules)

    """
    #immediately remove build directory
    build_dir = "./build"
    if os.path.exists(build_dir):
        shutil.rmtree(build_dir)
    """
<小时/>

Update#3:根据@GZ0的建议,有很多情况下代码中的表达式的计算结果将为零,并且会浪费计算。我尝试用以下代码消除这些区域(在修复了case3case6语句之后):

cdef void cy_mktout_if(Vec4 *out, Vec4 *mean_mu_alpha, double[:, ::1] errors, double par_gamma) nogil:
    cdef:
        size_t i, n
        double[4] exp
        double exp_par_gamma
        double mu10, mu11, mu20, mu21
        double alpha1, alpha2
        bint j_is_larger
        double threshold2, threshold3
        bint case1, case2, case3, case4, case5, case6
        double t0, t1, t2
        double p12, p1, p2
        double t1_sum, t2_sum, p1_sum, p2_sum
        double c

    #compute the exp outside of the loop
    n = errors.shape[0]
    exp[0] = cmath.exp(mean_mu_alpha.a)
    exp[1] = cmath.exp(mean_mu_alpha.b)
    exp[2] = cmath.exp(mean_mu_alpha.c)
    exp[3] = cmath.exp(mean_mu_alpha.d)
    exp_par_gamma = cmath.exp(par_gamma)
    c = 168.0

    t1_sum = 0.0
    t2_sum = 0.0
    p1_sum = 0.0
    p2_sum = 0.0

    for i in range(n):
        mu10 = errors[i, 0] * exp[0]
        mu11 = exp_par_gamma * mu10
        mu20 = errors[i, 1] * exp[1]
        mu21 = exp_par_gamma * mu20
        alpha1 = errors[i, 2] * exp[2]
        alpha2 = errors[i, 3] * exp[3]

        j_is_larger = mu10 > mu20
        j_is_smaller = not j_is_larger
        threshold2 = (1 + mu10 * alpha1) / (c + alpha1)
        threshold3 = (1 + mu20 * alpha2) / (c + alpha2)

        if j_is_larger:
            case1 = mu10 < 1 / c
            case2 = mu21 >= threshold2
            case3 = not (case1 | case2)

            t0 = case1*c + case2 * (c + alpha1 + alpha2) / (1 + mu11 * alpha1 + mu21 * alpha2) + case3 / threshold2
            t1 = case2 * (t0 * alpha1 * mu11 - alpha1) + case3 * (t0 * alpha1 * mu10 - alpha1)
            t2 = c - t0 - t1

            t1_sum += t1
            t2_sum += t2
            p1_sum += case2 + case3
            p2_sum += case2

        else:
            case4 = mu20 < 1 / c
            case5 = mu11 >= threshold3
            case6 = not (case4 | case5)

            t0 = case4 * c + case5 * (c + alpha1 + alpha2) / (1 + mu11 * alpha1 + mu21 * alpha2) + case6 / threshold3
            t1 = case5 * (t0 * alpha1 * mu11 - alpha1)
            t2 = c - t0 - t1

            t1_sum += t1
            t2_sum += t2
            p1_sum += case5
            p2_sum += case5 + case6

    out.a = t1_sum/n
    out.b = t2_sum/n
    out.c = p1_sum/n
    out.d = p2_sum/n

对于10000次迭代,当前代码执行如下所示:

outer_loop: 0.5116949229995953 seconds
outer_loop_if: 0.617649456995423 seconds
mktout: 0.9221872320049442 seconds
mktout_if: 1.430276553001022 seconds
python: 10.116664300003322 seconds

我认为条件和分支预测失误的代价使函数的运行速度慢得惊人,但我希望任何人能帮助我澄清这一点。在

相关问题 更多 >