我有218k+33通道图像,我需要找到每个通道的平均值和方差。我试过使用多重处理,但这似乎慢得难以忍受。下面是一个简短的代码示例:
def work(aggregates, genput):
# received (channel, image) from generator
channel = genput[0]
image = genput[1]
for row in image:
for pixel in row:
# use welford's to update a list of "aggregates" which will
# later be finalized as means and variances of each channel
aggregates[channel] = update(aggregates[channel], pixel)
def data_stream(df, data_root):
'''Generator that returns the channel and image for each tif file'''
for index, sample in df.iterrows():
curr_img_path = data_root
# read the image with all channels
tif = imread(curr_img_path) #33x64x64 array
for channel, image in enumerate(tif):
yield (channel, image)
# Pass over each image, compute mean/variance for each channel for each image
def preprocess_mv(df, data_root, channels=33, multiprocessing=True):
'''Calculates mean and variance on the whole image set for use in deep_learn'''
manager = Manager()
aggregates = manager.list()
[aggregates.append(([0,0,0])) for i in range(channels)]
proxy = partial(work, aggregates)
pool = Pool(processes=8)
pool.imap(proxy, data_stream(df, data_root), chunksize=5000)
pool.close()
pool.join()
# finalize data below
我的怀疑是,pickle aggregates
数组和从父进程到子进程来回传输所需的时间非常长,这是主要的瓶颈——我可以看到这个缺点完全消除了多进程的优势,因为每个子进程都必须等待其他子进程进行pickle以及解开数据。我读到这是多处理库的一个局限性,从我在这里读到的其他文章中,我意识到这可能是我能做的最好的了。也就是说,有人对如何改进这一点有什么建议吗?你知道吗
另外,我想知道是否有更好的库/工具来完成这个任务?一个朋友实际上推荐了Scala,我也一直在研究这个选项。我只是非常熟悉Python,如果可能的话,我想留在这个领域。你知道吗
我通过对
multiprocessing.Array
进行更深入的探索,找到了一个解决方案。我必须弄清楚如何将我的二维数组转换成一维数组,并且仍然能够建立索引,但这最终是一个非常简单的数学问题。我现在可以在2分钟内处理1000个样品,而不是4小时,所以我觉得这很好。我还必须编写一个自定义函数来打印数组,但这相当简单。这个实现并不能保证不受竞争条件的影响,但就我的目的而言,它工作得相当好。通过将锁包含在init中并以与数组相同的方式传递它(使用global
),可以很容易地添加锁。你知道吗相关问题 更多 >
编程相关推荐