我第一次使用cython来获得函数的速度。函数接受一个平方矩阵A
浮点数并输出一个浮点数。它正在计算的函数是permanent of a matrix
当A是30×30时,我的代码目前在我的电脑上需要大约60秒
在下面的代码中,我从wiki页面实现了permanent的Balasubramanian Bax/franklinglynn公式。我把矩阵叫做M
代码的一个复杂部分是数组f,它用于保存数组d中下一个要翻转的位置的索引。数组d保存的值是+-1。在循环中操作f和j只是快速更新Gray代码的一种聪明方法。在
from __future__ import division
import numpy as np
cimport numpy as np
cimport cython
DTYPE_int = np.int
ctypedef np.int_t DTYPE_int_t
DTYPE_float = np.float64
ctypedef np.float64_t DTYPE_float_t
@cython.boundscheck(False) # turn off bounds-checking for entire function
@cython.wraparound(False) # turn off negative index wrapping for entire function
def permfunc(np.ndarray [DTYPE_float_t, ndim =2, mode='c'] M):
cdef int n = M.shape[0]
cdef np.ndarray[DTYPE_float_t, ndim =1, mode='c' ] d = np.ones(n, dtype=DTYPE_float)
cdef int j = 0
cdef int s = 1
cdef np.ndarray [DTYPE_int_t, ndim =1, mode='c'] f = np.arange(n, dtype=DTYPE_int)
cdef np.ndarray [DTYPE_float_t, ndim =1, mode='c'] v = M.sum(axis=0)
cdef DTYPE_float_t p = 1
cdef int i
cdef DTYPE_float_t prod
for i in range(n):
p *= v[i]
while (j < n-1):
for i in range(n):
v[i] -= 2*d[j]*M[j, i]
d[j] = -d[j]
s = -s
prod = 1
for i in range(n):
prod *= v[i]
p += s*prod
f[0] = 0
f[j] = f[j+1]
f[j+1] = j+1
j = f[0]
return p/2**(n-1)
我已经使用了在cython教程中找到的所有简单优化。有些方面我不得不承认我不完全理解。例如,如果我将数组d
设为int,因为值只有+-1,代码的运行速度会慢10%,所以我将其保留为float64
Is there anything else I can do to speed up the code?
这是cython-a的结果。如您所见,循环中的所有内容都被编译为C语言,因此基本的优化已经起作用。在
这是numpy中相同的函数,比我当前的cython版本慢100多倍。在
^{pr2}$更新计时
下面是我的cython版本、numpy版本和romeric对cython代码的改进(使用ipython)。我已经为可重复性设定了种子。在
from scipy.stats import ortho_group
import pyximport; pyximport.install()
import permlib # This loads in the functions from permlib.pyx
import numpy as np; np.random.seed(7)
M = ortho_group.rvs(23) #Creates a random orthogonal matrix
%timeit permlib.npperm(M) # The numpy version
1 loop, best of 3: 44.5 s per loop
%timeit permlib.permfunc(M) # The cython version
1 loop, best of 3: 273 ms per loop
%timeit permlib.permfunc_modified(M) #romeric's improvement
10 loops, best of 3: 198 ms per loop
M = ortho_group.rvs(28)
%timeit permlib.permfunc(M) # The cython version run on a 28x28 matrix
1 loop, best of 3: 15.8 s per loop
%timeit permlib.permfunc_modified(M) # romeric's improvement run on a 28x28 matrix
1 loop, best of 3: 12.4 s per loop
Can the cython code be sped up at all?
我使用的是gcc和CPU是amdfx8350。在
这个答案基于之前发布的@romeric代码。我更正了代码并简化了它,并添加了
cdivision
编译器指令。在@romeric的原始代码没有初始化}内部的两个循环。在
v
,因此有时会得到不同的结果。另外,我分别组合了while
之前的两个循环和{最后,比较
^{pr2}$免责声明:我是下面提到的工具的核心开发人员。
作为Cython的替代方案,您可以尝试一下Pythran。 原始NumPy代码的单个注释:
编制依据:
^{pr2}$产生与Cython类似的加速:
不需要重新实现
sum_axis
(Pythran负责)。在更有趣的是,Pythran能够通过一个选项标志识别几个可矢量化(在生成SSE/AVX内部函数的意义上)模式:
这使得一个最终的x232加速相对于NumPy版本,一个类似于展开的Cython版本,没有太多的手动调整。在
对于您的
cython
函数,您无能为力,因为它已经被很好地优化了。但是,通过完全避免调用numpy
,您仍然可以获得适度的加速。在以下是必要的检查和计时:
^{pr2}$编辑 让我们通过展开内部
prod
循环来执行一些基本的SSE
矢量化,也就是说,将上面代码中的循环更改为以下内容现在是时机
所以比原来优化过的cython代码快了2倍,还不错。在
相关问题 更多 >
编程相关推荐