NumPy中的多处理

2024-10-01 02:19:57 发布

您现在位置:Python中文网/ 问答频道 /正文

我在网上寻找解决问题的方法,但没有找到任何对我有帮助的东西。 我的问题是我希望通过实现多处理来加速我的程序。函数getSVJJPrice相当快。但是,K的大小大约是1000,这使得整个代码相当慢。因此,我想知道是否有任何方法可以并行化for循环?代码如下所示。在

def func2min(x,S,expiry,K,r,prices,curr):
    bid = prices[:,0]
    ask = prices[:,1]

    C_omega = [0]*len(K)
    w = [0]*len(K)

    for ind, k in enumerate(K):
        w[ind] = 1/np.abs(bid[ind] - ask[ind])
        C_omega[ind] = getSVJJPrice(x[0],(x[1] + x[0]**2)/(2*x[2]),
        x[2],x[3],x[4],x[5],x[6],x[7],x[8],x[9],S[ind],k,r[ind],expiry[ind],
                curr[ind])  

    right = np.sum(w * (np.subtract(C_omega, np.mean(prices,axis=1)))**2)

    print right
    #if right < 10:    
    #    print '\n \n func = ', right 

    if math.isnan(right):
        right = 1e12

    return right

万分感谢所有调查此事的人!在

谨致问候

维克托


Tags: 方法代码rightforlennpaskprices
1条回答
网友
1楼 · 发布于 2024-10-01 02:19:57

似乎multiprocessing.Pool可能适合您的情况,因为您正在K中的每个元素循环,K似乎只是代码中的一个一维数组。在

基本上,你首先要写一个执行循环的函数,在我的例子parallel_loop中,然后你必须把你的问题分成几个独立的块,在这个例子中,你只需要把K分割成一个整型的nprocs。在

然后可以使用pool.map对每个块并行执行循环,结果将按块的顺序收集回来,这些块的顺序与原始的K相同,因为我们没有重新排列任何内容,只是执行了切片。然后你只需要把所有的部分放回w和{}。在

import numpy as np
from multiprocessing import Pool

def parallel_loop(K_chunk):
    C_omega_chunk = np.empty(len(K_chunk)
    w_chunk = np.empty(len(K_chunk))

    for ind, k in enumerate(K_chunk)
        w_chunk[ind] = 1/np.abs(bid[ind] - ask[ind])
        C_omega_chunk[ind] = getSVJJPrice(x[0],(x[1] + x[0]**2)/(2*x[2]),
        x[2],x[3],x[4],x[5],x[6],x[7],x[8],x[9],S[ind],k,r[ind],expiry[ind],
                curr[ind])  

    return (w_chunk, C_omega_chunk)

def func2min(x,S,expiry,K,r,prices,curr,nprocs):
    bid = prices[:,0]
    ask = prices[:,1]

    K = np.array(K)

    K_chunks = [K[n * len(K) // nprocs : (n + 1) * len(K) // nprocs] for n in range(nprocs)]
    pool = Pool(processes=nprocs)  
    outputs = pool.map(parallel_loop, K_chunks)

    w, C_omega = (np.concatenate(var) for var in zip(*outputs))

    right = np.sum(w * (np.subtract(C_omega, np.mean(prices,axis=1)))**2)

    print right
    #if right < 10:    
    #    print '\n \n func = ', right 

    if math.isnan(right):
        right = 1e12

    return right

因为我没有一个示例数据集,所以我不能确定上面的示例是否能按原样工作,但我认为它应该能让您大致了解它是如何工作的。在

相关问题 更多 >