Mutating a shared array with multiprocessing in Python

Posted 2024-09-30 02:21:08


I have 218k+ 33-channel images, and I need to find the mean and variance of each channel. I tried using multiprocessing, but it is unbearably slow. Here is a short code sample:

from functools import partial
from multiprocessing import Manager, Pool
from tifffile import imread  # assumed source of imread; could also be skimage.io

def work(aggregates, genput):
    # received (channel, image) from the generator
    channel, image = genput
    for row in image:
        for pixel in row:
            # use Welford's algorithm to update the per-channel "aggregates",
            # which are later finalized into means and variances
            aggregates[channel] = update(aggregates[channel], pixel)

def data_stream(df, data_root):
    '''Generator that yields the channel and image for each tif file'''
    for index, sample in df.iterrows():
        curr_img_path = data_root

        # read the image with all channels
        tif = imread(curr_img_path)  # 33x64x64 array
        for channel, image in enumerate(tif):
            yield (channel, image)

# Pass over each image, compute mean/variance for each channel for each image
def preprocess_mv(df, data_root, channels=33, multiprocessing=True):
    '''Calculates mean and variance on the whole image set for use in deep_learn'''
    manager = Manager()
    aggregates = manager.list()

    for _ in range(channels):
        aggregates.append([0, 0, 0])

    proxy = partial(work, aggregates)

    pool = Pool(processes=8) 
    pool.imap(proxy, data_stream(df, data_root), chunksize=5000)
    pool.close()
    pool.join()

    # finalize data below
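
For reference, here is a minimal sketch of what the `update` and finalize steps assumed above might look like; this is my reconstruction of a standard Welford aggregate `(count, mean, M2)`, not the asker's actual code:

```python
def update(aggregate, new_value):
    """Welford's online update: aggregate is (count, mean, M2)."""
    count, mean, M2 = aggregate
    count += 1
    delta = new_value - mean
    mean += delta / count
    delta2 = new_value - mean
    M2 += delta * delta2
    return (count, mean, M2)

def finalize(aggregate):
    """Return (mean, population variance) from a Welford aggregate."""
    count, mean, M2 = aggregate
    return (mean, M2 / count)

# quick check on a tiny stream of pixels
agg = (0, 0.0, 0.0)
for x in [1.0, 2.0, 3.0, 4.0]:
    agg = update(agg, x)
print(finalize(agg))  # (2.5, 1.25)
```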

My suspicion is that the time spent pickling the aggregates list and shuttling it back and forth between the parent and child processes is the main bottleneck. As far as I can tell, this overhead completely wipes out the advantage of multiprocessing, since each child has to wait on the others to pickle and unpickle the data. I've read that this is a limitation of the multiprocessing library, and from other posts I've read here I gather this may be the best I can do. That said, does anyone have suggestions on how to improve this?

Also, I'm wondering whether there is a better library/tool for this task. A friend actually recommended Scala, and I've been looking into that option as well. I'm just much more comfortable with Python and would like to stay in that ecosystem if possible.
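
One alternative worth noting, which avoids shared state (and the pickling of `aggregates`) entirely: each worker computes its own per-image Welford aggregate and only returns that small tuple, and the parent merges results with the parallel-variance combine step. A sketch of just the merge, assuming `(count, mean, M2)` aggregates:

```python
def combine(a, b):
    """Merge two Welford aggregates (count, mean, M2) into one
    (Chan et al. parallel variance formula)."""
    n_a, mean_a, M2_a = a
    n_b, mean_b, M2_b = b
    n = n_a + n_b
    if n == 0:
        return (0, 0.0, 0.0)
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    M2 = M2_a + M2_b + delta * delta * n_a * n_b / n
    return (n, mean, M2)

def channel_aggregate(values):
    """Two-pass aggregate of one chunk of pixel values into (count, mean, M2)."""
    n = len(values)
    mean = sum(values) / n
    M2 = sum((v - mean) ** 2 for v in values)
    return (n, mean, M2)

# merging two halves matches aggregating the whole stream
left = channel_aggregate([1.0, 2.0])
right = channel_aggregate([3.0, 4.0])
print(combine(left, right))  # (4, 2.5, 5.0)
```

With this shape, `pool.imap` can simply yield each worker's aggregates to the parent, which reduces them with `combine`; no `Manager` proxy and no shared writes are needed.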


1 Answer

I found a solution by digging deeper into multiprocessing.Array. I had to work out how to flatten my 2D array into 1D while still being able to index it, but that turned out to be simple arithmetic. I can now process 1000 samples in about 2 minutes instead of 4 hours, which I think is pretty good. I also had to write a custom function to print the array, but that was straightforward. This implementation is not guaranteed to be free of race conditions, but for my purposes it works well enough. A lock could easily be added by including it in init and passing it the same way as the array (via a global).

from multiprocessing import Array, Pool
import numpy as np

def init(arr):
    # runs once in each worker; stores the shared array as a module global
    global aggregates
    aggregates = arr

def work(genput):
    # received (sample, channel, image) from the generator
    sample_no, channel, image = genput
    # read this channel's Welford aggregate out of the flat array
    currAgg = (aggregates[3*channel], aggregates[3*channel+1],
               aggregates[3*channel+2])
    for row in image:
        for pixel in row:
            # use Welford's algorithm to compute the updated aggregate
            currAgg = update(currAgg, pixel)
    # write back with 1D indexing (the flat array is "shaped" as 33x3)
    aggregates[3*channel] = currAgg[0]
    aggregates[(3*channel)+1] = currAgg[1]
    aggregates[(3*channel)+2] = currAgg[2]

def data_stream(df, data_root):
    '''Generator that returns the channel and image for each tif file'''
    ...
    yield (index, channel, image)


if __name__ == '__main__':

    aggs = Array('d', np.zeros(99))  # 99 = 33 channels x 3 Welford values

    pool = Pool(initializer=init, initargs=(aggs,), processes=8)
    pool.imap(work, data_stream(df, data_root), chunksize=10)
    pool.close()
    pool.join()

#     -finalize aggregates below
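
The finalize step hinted at by the comment could look something like this (a hypothetical helper, assuming the same `(count, mean, M2)` ordering used by `work` above):

```python
def finalize_aggregates(aggs, channels=33):
    """Convert the flat (channels x 3) array of Welford aggregates
    back into a list of per-channel (mean, population variance)."""
    results = []
    for channel in range(channels):
        count = aggs[3 * channel]
        mean = aggs[3 * channel + 1]
        M2 = aggs[3 * channel + 2]
        results.append((mean, M2 / count))
    return results

# works on the multiprocessing.Array directly, or any flat sequence
print(finalize_aggregates([4.0, 2.5, 5.0], channels=1))  # [(2.5, 1.25)]
```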
