如何加速下面的代码?无中心元素实现maxpool

2024-06-25 06:33:54 发布

您现在位置:Python中文网/ 问答频道 /正文

我知道maxpool,我在pytorch里用它。扩展参数的Maxpool如下: maxpool, dilated
现在我需要一种特殊形式的maxpool,在不使用中心元素的情况下执行maxpool。也就是说内核大小是3X3,但是应该删除中心元素。因此,结果应该来自其余8个元素。
现在我使用for循环,如何使用numpy或pytorch或其他任何东西来加速这个过程?你知道吗

import numpy as np
from timeit import default_timer as timer


def MaxPool_special(kh, kw, arr):
    """
    to do maxpool without central element
    :param kh:  should always be 3
    :param kw:  should always be 3
    :param arr:  the input array
    :return: arr_res: output array
    """
    h, w = arr.shape[:2]

    arr_res = np.array([[maxpool_ij(i, j, arr, kh, kw) for j in range(w)] for i in range(h)])

    return arr_res


def maxpool_ij(i, j, arr, dh, dw):
    """
    find the maximum value around point(i,j) with dilated parameter
    """
    Mmax = None
    imin, imax = i - dh, i + dh
    jmin, jmax = j - dw, j + dw
    if imin >= 0 and imax < h and jmin >= 0 and jmax < w:
        Mmax = np.max(
            arr[[imin, imin, imin, i, i, imax, imax, imax], [jmin, j, jmax, jmin, jmax, jmin, j, jmax]])
    elif imin < 0 and jmin < 0:
        Mmax = np.max(arr[[i, imax, imax], [jmax, j, jmax]])
    elif imin < 0 and jmax >= w:
        Mmax = np.max(arr[[i, imax, imax], [jmin, jmin, j]])
    elif imax >= h and jmin < 0:
        Mmax = np.max(arr[[imin, imin, i], [j, jmax, jmax]])
    elif imax >= h and jmax >= w:
        Mmax = np.max(arr[[imin, imin, i], [jmin, j, jmin]])
    elif imin < 0:
        Mmax = np.max(arr[[i, i, imax, imax, imax], [jmin, jmax, jmin, j, jmax]])
    elif imax >= h:
        Mmax = np.max(arr[[imin, imin, imin, i, i], [jmin, j, jmax, jmin, jmax]])
    elif jmin < 0:
        Mmax = np.max(arr[[imin, imin, i, imax, imax], [j, jmax, jmax, j, jmax]])
    elif jmax >= w:
        Mmax = np.max(arr[[imin, imin, i, imax, imax], [jmin, j, jmin, jmin, j]])

    assert Mmax, f'Wrong logic above!{imin, imax, jmin, jmax, h, w}'

    return Mmax

#  generate input array
h, w = 400, 500
arr = np.random.randint(0, 256, h * w).reshape(h, w)

tic = timer()
grayPool = MaxPool_special(3, 3, arr)
toc = timer()
print(f'time cost for for-loops: {toc - tic}')

请帮我加速这个代码,谢谢!你知道吗


Tags: and元素fornparraymaxtimerarr
1条回答
网友
1楼 · 发布于 2024-06-25 06:33:54

使用torch.nn.Unfold可以实现没有中心元素的Maxpool。示例如下:

h, w = 7, 10
x = torch.arange(0,h*w,dtype=torch.float).reshape(1,1,h,w)

"""
x.shape: torch.Size([1, 1, 7, 10])
x:
tensor([[[[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
          [10., 11., 12., 13., 14., 15., 16., 17., 18., 19.],
          [20., 21., 22., 23., 24., 25., 26., 27., 28., 29.],
          [30., 31., 32., 33., 34., 35., 36., 37., 38., 39.],
          [40., 41., 42., 43., 44., 45., 46., 47., 48., 49.],
          [50., 51., 52., 53., 54., 55., 56., 57., 58., 59.],
          [60., 61., 62., 63., 64., 65., 66., 67., 68., 69.]]]])
"""

Unfold = torch.nn.Unfold(kernel_size=(3,5), stride=(2,2))
xUfd = Unfold(x)
"""
xUfd.shape: torch.Size([1, 15, 9])
xUfd:
tensor([[[ 0.,  2.,  4., 20., 22., 24., 40., 42., 44.],
         [ 1.,  3.,  5., 21., 23., 25., 41., 43., 45.],
         [ 2.,  4.,  6., 22., 24., 26., 42., 44., 46.],
         [ 3.,  5.,  7., 23., 25., 27., 43., 45., 47.],
         [ 4.,  6.,  8., 24., 26., 28., 44., 46., 48.],
         [10., 12., 14., 30., 32., 34., 50., 52., 54.],
         [11., 13., 15., 31., 33., 35., 51., 53., 55.],
         [12., 14., 16., 32., 34., 36., 52., 54., 56.],
         [13., 15., 17., 33., 35., 37., 53., 55., 57.],
         [14., 16., 18., 34., 36., 38., 54., 56., 58.],
         [20., 22., 24., 40., 42., 44., 60., 62., 64.],
         [21., 23., 25., 41., 43., 45., 61., 63., 65.],
         [22., 24., 26., 42., 44., 46., 62., 64., 66.],
         [23., 25., 27., 43., 45., 47., 63., 65., 67.],
         [24., 26., 28., 44., 46., 48., 64., 66., 68.]]])
"""

xUfd = xUfd[:,:,[0,1,2,3,5,6,7,8]]
xUfd = torch.max(xUfd, 2).values.reshape(3,5)
"""
xUfd.shape: torch.Size([3, 5])
xUfd:
tensor([[44., 45., 46., 47., 48.],
        [54., 55., 56., 57., 58.],
        [64., 65., 66., 67., 68.]])
"""

相关问题 更多 >