用Cython编译Scipy函数

2024-10-16 20:46:16 发布

您现在位置:Python中文网/ 问答频道 /正文

我在Python3.4中运行一个模拟——它涉及稀疏数组(csr格式)和密集向量之间的大量点积。我用Scipy表示稀疏矩阵,numpy表示其他所有的矩阵。在

使用Cython给了我一个巨大的提升(~x6的速度提高),在确保我正确地使用cdef之后,并且最小化了Python的交互(bt检查Cython给我的html文件并修改了我的代码)。在

现在,我分析了代码,50%的模拟时间都花在了点积上。我想知道是否有可能以某种方式加速这条线,比如在Cython中编译这个单点函数?在

我知道我可以为(csr-sprase 2d matrix)dot(稠密向量)编写自己的实现,但我正试图避免这种情况。在

编辑:我包含了一个最小的代码示例。对不起,我看不出怎么把它缩小。这是统计力学的教科书练习。把大理石放在罐子里直到其中一个罐子超过容量。然后,启动一个级联,它根据一个(这里是稀疏的)矩阵传播。我用的是分批抽样。在

请把注意力集中在队伍的尽头。在

from __future__ import division
import  numpy           as np
import  cython
cimport numpy           as np
cimport cpython.array
import  scipy.sparse    as sps


@cython.cdivision(True)
@cython.nonecheck(False)
@cython.boundscheck(False)
@cython.wraparound(False)
def simulate(long[:] capacity_vec,
             int random_array_size,
             long n,
             int seed,
             int[:] A_col,
             int[:] A_row,
             long[:] A_data):


    #### Initialise ####################################################################################################

    # Initialise states
    cdef int[:] damage  = np.random.randint(0, int(np.min(capacity_vec)/2), n).astype(np.int32)
    cdef int[:] dr_list = np.random.choice(n, 1000).astype(np.int32)
    cdef int[:] states  = np.zeros(n).astype(np.int32)
    cdef int[:] states_ = np.zeros(n).astype(np.int32)
    cdef int[:] change  = np.zeros(n).astype(np.int32)

    # Initialise counters
    cdef int k, violations, violations_, counter= 0, dr_id=0, increment_index = 0


    # Build Sparse Adjecency Matrix
    cA_sps = sps.csr_matrix( (A_data, (A_row, A_col) ), shape=(n,n) ).astype(np.int32)


    while counter < 1000:

        #### Place damage until a cascade starts #######################################################################
        while damage[increment_index] <= capacity_vec[increment_index]:# Check for violations

            increment_index         = dr_list[dr_id]                   # Where do we place the marble?

            damage[increment_index] = damage[increment_index] + 1      # place the marble

            dr_id                   = dr_id + 1                        # another random number used

            if dr_id == random_array_size - 1:                         # Check if we run out of random numbers

                dr_list = np.random.choice(n, random_array_size).astype(np.int32) # if so, pick new increment_index

                dr_id   = 0                                            # and reset the counter


        #### Initialise cascade ########################################################################################
        violations, violations_  = 1, 0
        states[increment_index]  = 1


        #### Propagate cascade #########################################################################################
        while violations > violations_:                                # check for fixed point, propagate cascade
            for k in range(n): change[k] = states[k] - states_[k]
            ### THIS LINE IS THE PROBLEM. It takes up half of all simulation time.
            damage      = damage + cA_sps.dot(change).astype(np.int32) # spread violations  

            states_     = states.copy()                                # store previous states

            # Determine previous and current violations
            violations, violations_ = 0 , violations

            for k in range(n):

                states_[k]  = 0

                if damage[k] > capacity_vec[k]:

                    violations = violations + 1

                    states[k]  =  1                                    # deactivate any node that has a violation


        for k in range(n): damage[k] = 0
        counter  = counter + 1                                         # progress cascade id after storing

Tags: idindexnpcounterrandomcythonintdamage
1条回答
网友
1楼 · 发布于 2024-10-16 20:46:16

我不鼓励你自己写矩阵乘法。SciPy是由知道自己在做什么的聪明人来完成的,除非你对数值计算有信心,否则不要这么做。在

但是,您可能会看到sparse.csr_matrix.dot的代码。直接进入定义here然后here,您将看到在Scipy中很少进行检查。如果你知道你想要什么样的格式,你可以写你自己的方法(修改你的SciPy拷贝)并直接计算你的产品。不过,不知道这会有多大帮助。在

若您想自己构建Scipy,只需从GitHug签出整个项目,然后运行即可

python setup.py build
python setup.py install

有关更直接的说明,请检查build documentation。在

相关问题 更多 >