Bottleneck in a NumPy-based recursive function for undoing max pooling

Posted 2024-10-04 05:24:39

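I designed a recursive function to handle a specific problem in the deep learning community. In most cases it runs quickly and works well, but in other cases it takes about 20 minutes for no apparent reason. In the simplest case, the function reduces to a plain numpy "repeat" along two axes. Here is the code I use to test it: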

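import time

import numpy as np

# Upsample fMap one level at a time until it matches dims[0].
# For an odd target size, the last row/column is copied once instead of doubled.
def recursive_upsample(fMap, index, dims):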
    if index == 0:
        return fMap
    else:
        start = time.time()
        upscale = np.zeros((dims[index-1][0],dims[index-1][1],fMap.shape[-1]))
        if dims[index-1][0] % 2 == 1 and dims[index-1][1] % 2 == 1:
            crop = fMap[:fMap.shape[0]-1,:fMap.shape[1]-1]
            consX = fMap[-1,:][:-1]
            consY = fMap[:,-1][:-1]
            corner = fMap[-1,-1]
            crop = crop.repeat(2, axis=0).repeat(2, axis=1)
            upscale[:crop.shape[0],:crop.shape[1]] = crop
            upscale[-1,:][:-1] = consX.repeat(2,axis=0)
            upscale[:,-1][:-1] = consY.repeat(2,axis=0)
            upscale[-1,-1] = corner

        elif dims[index-1][0] % 2 == 1:
            crop = fMap[:fMap.shape[0]-1]
            consX = fMap[-1:,]
            crop = crop.repeat(2, axis=0).repeat(2, axis=1)
            upscale[:crop.shape[0]] = crop
            upscale[-1:,] = consX.repeat(2,axis=1)

        elif dims[index-1][1] % 2 == 1:
            crop = fMap[:,:fMap.shape[1]-1]
            consY = fMap[:,-1]
            crop = crop.repeat(2, axis=0).repeat(2, axis=1)
            upscale[:,:crop.shape[1]] = crop
            upscale[:,-1] = consY.repeat(2,axis=0)


        else:
            upscale = fMap.repeat(2, axis=0).repeat(2, axis=1)

        print('Upscaling from {} to {} took {} seconds'.format(fMap.shape,upscale.shape,time.time() - start))
        fMap = upscale

        return recursive_upsample(fMap,index-1,dims)

if __name__ == '__main__':
    dims = [(634,1020,64),(317,510,128),(159,255,256),(80,128,512),(40,64,512)]
    images = []
    for dim in dims:
        image = np.random.rand(dim[0],dim[1],dim[2])
        images.append(image)
    start = time.time()
    upsampled = []
    for index,image in enumerate(images):
        upsampled.append(recursive_upsample(image,index,dims))
    print('Upsampling took {} seconds'.format(time.time() - start))
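
For reference, the simplest case mentioned at the top (both target dimensions even) is just a repeat along the two spatial axes. A minimal sketch on a toy array; `upsample_2x` is a hypothetical helper name used only for illustration and is not part of the code above:

```python
import numpy as np

def upsample_2x(fmap):
    """Nearest-neighbour 2x upsampling of an (H, W, C) array via repeat."""
    return fmap.repeat(2, axis=0).repeat(2, axis=1)

# Toy 2 x 3 feature map with a single channel
small = np.arange(2 * 3).reshape(2, 3, 1)
big = upsample_2x(small)

print(big.shape)     # (4, 6, 1)
print(big[:, :, 0])  # every value of `small` now fills a 2 x 2 block
```

Each `repeat` call returns a new array, so the even branch creates one intermediate copy of size (2H, W, C) before the final (2H, 2W, C) result.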

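For some strange reason, the step in the recursion where the feature map that started with shape (40, 64, 512) is upsampled from (317, 510, 512) to (634, 1020, 512) took an astonishing 941 seconds! I have started rewriting this code in Theano, but are there potential problems in my code that I should look at first? My current reasoning is that this is simply hard to compute on a CPU, but I am not sure what could go wrong with such a simple function. Also, any tips on how to make this function faster would be greatly appreciated!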
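
One factor that may be worth checking before blaming the CPU is memory. The sketch below is only a back-of-the-envelope estimate, assuming float64 arrays (the default for `np.random.rand` and `np.zeros`); `nbytes_gib` is a hypothetical helper, and whether this actually explains the 941 seconds depends on how much RAM the machine has:

```python
import numpy as np

def nbytes_gib(shape, dtype=np.float64):
    """Size in GiB of a dense array with the given shape and dtype."""
    n_items = int(np.prod(shape, dtype=np.int64))
    return n_items * np.dtype(dtype).itemsize / 2**30

# Arrays involved in the slow step (all-even branch: two chained repeats)
result = nbytes_gib((634, 1020, 512))        # final array, about 2.47 GiB
intermediate = nbytes_gib((634, 510, 512))   # after repeat on axis 0, about 1.23 GiB

# The test script keeps all five results in `upsampled`, each 634 x 1020
channels = [64, 128, 256, 512, 512]
kept = sum(nbytes_gib((634, 1020, c)) for c in channels)  # about 7.09 GiB

print(f"result: {result:.2f} GiB, intermediate: {intermediate:.2f} GiB")
print(f"all results kept in the list: {kept:.2f} GiB")
```

If these sizes approach the available RAM, the system may start swapping, which could look like a sudden slowdown in one step. Note also that in the all-even branch the `np.zeros` buffer allocated at the top of the function is never used. Switching to float32 would halve every figure above.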

