这个cython代码可以优化吗?

2024-10-02 22:34:04 发布

您现在位置:Python中文网/ 问答频道 /正文

我第一次使用cython来获得函数的速度。函数接受一个平方矩阵A浮点数并输出一个浮点数。它正在计算的函数是permanent of a matrix

enter image description here

当A是30×30时,我的代码目前在我的电脑上需要大约60秒

在下面的代码中,我从wiki页面实现了permanent的Balasubramanian Bax/franklinglynn公式。我把矩阵叫做M

代码的一个复杂部分是数组f,它用于保存数组d中下一个要翻转的位置的索引。数组d保存的值是+-1。在循环中操作f和j只是快速更新Gray代码的一种聪明方法。在

from __future__ import division
import numpy as np
cimport numpy as np
cimport cython


DTYPE_int = np.int
ctypedef np.int_t DTYPE_int_t
DTYPE_float = np.float64
ctypedef np.float64_t DTYPE_float_t

@cython.boundscheck(False) # turn off bounds-checking for entire function
@cython.wraparound(False)  # turn off negative index wrapping for entire function
def permfunc(np.ndarray [DTYPE_float_t, ndim =2, mode='c'] M):
    cdef int n = M.shape[0]
    cdef np.ndarray[DTYPE_float_t, ndim =1, mode='c' ] d = np.ones(n, dtype=DTYPE_float)
    cdef int j =  0
    cdef int s = 1
    cdef np.ndarray [DTYPE_int_t, ndim =1, mode='c'] f = np.arange(n, dtype=DTYPE_int)
    cdef np.ndarray [DTYPE_float_t, ndim =1, mode='c'] v = M.sum(axis=0)
    cdef DTYPE_float_t p = 1
    cdef int i
    cdef DTYPE_float_t prod
    for i in range(n):
        p *= v[i]
    while (j < n-1):
        for i in range(n):
            v[i] -= 2*d[j]*M[j, i]
        d[j] = -d[j]
        s = -s
        prod = 1
        for i in range(n):
            prod *= v[i]
        p += s*prod
        f[0] = 0
        f[j] = f[j+1]
        f[j+1] = j+1
        j = f[0]
    return p/2**(n-1)   

我已经使用了在cython教程中找到的所有简单优化。有些方面我不得不承认我不完全理解。例如,如果我将数组d设为int,因为值只有+-1,代码的运行速度会慢10%,所以我将其保留为float64

Is there anything else I can do to speed up the code?


这是cython-a的结果。如您所见,循环中的所有内容都被编译为C语言,因此基本的优化已经起作用。在

Result of cython -a

这是numpy中相同的函数,比我当前的cython版本慢100多倍。在

^{pr2}$

更新计时

下面是我的cython版本、numpy版本和romeric对cython代码的改进(使用ipython)。我已经为可重复性设定了种子。在

from scipy.stats import ortho_group
import pyximport; pyximport.install()
import permlib # This loads in the functions from permlib.pyx
import numpy as np; np.random.seed(7)
M = ortho_group.rvs(23) #Creates a random orthogonal matrix 
%timeit permlib.npperm(M) # The numpy version
1 loop, best of 3: 44.5 s per loop
%timeit permlib.permfunc(M) # The cython version
1 loop, best of 3: 273 ms per loop
%timeit permlib.permfunc_modified(M) #romeric's improvement
10 loops, best of 3: 198 ms per loop
M = ortho_group.rvs(28)
%timeit permlib.permfunc(M) # The cython version run on a 28x28 matrix
1 loop, best of 3: 15.8 s per loop
%timeit permlib.permfunc_modified(M) # romeric's improvement run on a 28x28 matrix
1 loop, best of 3: 12.4 s per loop

Can the cython code be sped up at all?

我使用的是gcc和CPU是amdfx8350。在


Tags: of代码importnumpyloopfornpfloat
3条回答

这个答案基于之前发布的@romeric代码。我更正了代码并简化了它,并添加了cdivision编译器指令。在

@cython.boundscheck(False) 
@cython.wraparound(False)
@cython.cdivision(True)
def permfunc_modified_2(np.ndarray [double, ndim =2, mode='c'] M):
    cdef:
        int n = M.shape[0], s=1, i, j
        int *f = <int*>malloc(n*sizeof(int))
        double *d = <double*>malloc(n*sizeof(double))
        double *v = <double*>malloc(n*sizeof(double))
        double p = 1, prod

    for i in range(n):
        v[i] = 0.
        for j in range(n):
            v[i] += M[j,i]
        p *= v[i]
        f[i] = i
        d[i] = 1
    j = 0
    while (j < n-1):
        prod = 1.
        for i in range(n):
            v[i] -= 2.*d[j]*M[j, i]
            prod *= v[i]
        d[j] = -d[j]
        s = -s            
        p += s*prod
        f[0] = 0
        f[j] = f[j+1]
        f[j+1] = j+1
        j = f[0]

    free(d)
    free(f)
    free(v)
    return p/pow(2.,(n-1))

@romeric的原始代码没有初始化v,因此有时会得到不同的结果。另外,我分别组合了while之前的两个循环和{}内部的两个循环。在

最后,比较

^{pr2}$

免责声明:我是下面提到的工具的核心开发人员。

作为Cython的替代方案,您可以尝试一下Pythran。 原始NumPy代码的单个注释:

#pythran export npperm(float[:, :])
import numpy as np
def npperm(M):
    n = M.shape[0]
    d = np.ones(n)
    j =  0
    s = 1
    f = np.arange(n)
    v = M.sum(axis=0)
    p = np.prod(v)
    while j < n-1:
        v -= 2*d[j]*M[j]
        d[j] = -d[j]
        s = -s
        prod = np.prod(v)
        p += s*prod
        f[0] = 0
        f[j] = f[j+1]
        f[j+1] = j+1
        j = f[0]
    return p/2**(n-1)

编制依据:

^{pr2}$

产生与Cython类似的加速:

> # numpy version
> python -mtimeit -r3 -n1 -s 'from scipy.stats import ortho_group; from perm import npperm; import numpy as np; np.random.seed(7); M = ortho_group.rvs(23)' 'npperm(M)'
1 loops, best of 3: 21.7 sec per loop
> # pythran version
> pythran perm.py
> python -mtimeit -r3 -n1 -s 'from scipy.stats import ortho_group; from perm import npperm; import numpy as np; np.random.seed(7); M = ortho_group.rvs(23)' 'npperm(M)' 
1 loops, best of 3: 171 msec per loop

不需要重新实现sum_axis(Pythran负责)。在

更有趣的是,Pythran能够通过一个选项标志识别几个可矢量化(在生成SSE/AVX内部函数的意义上)模式:

> pythran perm.py -DUSE_BOOST_SIMD -march=native
>  python -mtimeit -r3 -n10 -s 'from scipy.stats import ortho_group; from perm import npperm; import numpy as np; np.random.seed(7); M = ortho_group.rvs(23)' 'npperm(M)' 
10 loops, best of 3: 93.2 msec per loop

这使得一个最终的x232加速相对于NumPy版本,一个类似于展开的Cython版本,没有太多的手动调整。在

对于您的cython函数,您无能为力,因为它已经被很好地优化了。但是,通过完全避免调用numpy,您仍然可以获得适度的加速。在

import numpy as np
cimport numpy as np
cimport cython
from libc.stdlib cimport malloc, free
from libc.math cimport pow

cdef inline double sum_axis(double *v, double *M, int n):
    cdef:
        int i, j
    for i in range(n):
        for j in range(n):
            v[i] += M[j*n+i]


@cython.boundscheck(False) 
@cython.wraparound(False)
def permfunc_modified(np.ndarray [double, ndim =2, mode='c'] M):
    cdef:
        int n = M.shape[0], j=0, s=1, i
        int *f = <int*>malloc(n*sizeof(int))
        double *d = <double*>malloc(n*sizeof(double))
        double *v = <double*>malloc(n*sizeof(double))
        double p = 1, prod

    sum_axis(v,&M[0,0],n)

    for i in range(n):
        p *= v[i]
        f[i] = i
        d[i] = 1

    while (j < n-1):
        for i in range(n):
            v[i] -= 2.*d[j]*M[j, i]
        d[j] = -d[j]
        s = -s
        prod = 1
        for i in range(n):
            prod *= v[i]
        p += s*prod
        f[0] = 0
        f[j] = f[j+1]
        f[j+1] = j+1
        j = f[0]

    free(d)
    free(f)
    free(v)
    return p/pow(2.,(n-1)) 

以下是必要的检查和计时:

^{pr2}$

编辑 让我们通过展开内部prod循环来执行一些基本的SSE矢量化,也就是说,将上面代码中的循环更改为以下内容

# define t1, t2 and t3 earlier as doubles
t1,t2,t3=1.,1.,1.
for i in range(0,n-1,2):
    t1 *= v[i]
    t2 *= v[i+1]
# define k earlier as int
for k in range(i+2,n):
    t3 *= v[k]
p += s*(t1*t2*t3) 

现在是时机

In [8]: %timeit permfunc_modified_vec(M) # vectorised
1 loop, best of 3: 14.0 s per loop

所以比原来优化过的cython代码快了2倍,还不错。在

相关问题 更多 >