如何处理Unet体系结构中的奇数分辨率

2024-10-01 05:05:50 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在PyTorch中实现一个基于U-Net的体系结构。在火车时刻,我有大小为256x256的补丁,这不会引起任何问题。但是在测试时,我有全高清图像(1920x1080)。这会在跳过连接期间导致问题

1920x1080进行3次下采样240x135。如果我再向下采样一次,分辨率将变为120x68,当向上采样时,分辨率将为240x136。现在,我无法连接这两个要素贴图。我怎样才能解决这个问题

PS:我认为这是一个相当普遍的问题,但我没有得到任何解决方案,甚至没有在网上提到这个问题。我错过什么了吗


Tags: 图像net体系结构分辨率pytorch解决方案ps要素
1条回答
网友
1楼 · 发布于 2024-10-01 05:05:50

在分段网络中,这是一个非常常见的问题,在解码过程中经常涉及到跳转连接。网络通常(取决于实际的架构)需要输入大小,其边长为最大跨步(8、16、32等)的整数倍

主要有两种方式:

  1. 将输入大小调整为最接近的可行大小
  2. 将输入填充到下一个更大的可行大小

我更喜欢(2),因为(1)会导致所有像素的像素级别发生微小变化,导致不必要的模糊。请注意,在这两种方法中,我们通常需要在之后恢复原始形状

此任务我最喜欢的代码段(高度/宽度的对称填充):

import torch
import torch.nn.functional as F

def pad_to(x, stride):
    h, w = x.shape[-2:]

    if h % stride > 0:
        new_h = h + stride - h % stride
    else:
        new_h = h
    if w % stride > 0:
        new_w = w + stride - w % stride
    else:
        new_w = w
    lh, uh = int((new_h-h) / 2), int(new_h-h) - int((new_h-h) / 2)
    lw, uw = int((new_w-w) / 2), int(new_w-w) - int((new_w-w) / 2)
    pads = (lw, uw, lh, uh)

    # zero-padding by default.
    # See others at https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.pad
    out = F.pad(x, pads, "constant", 0)

    return out, pads

def unpad(x, pad):
    if pad[2]+pad[3] > 0:
        x = x[:,:,pad[2]:-pad[3],:]
    if pad[0]+pad[1] > 0:
        x = x[:,:,:,pad[0]:-pad[1]]
    return x

测试片段:

x = torch.zeros(4, 3, 1080, 1920) # Raw data
x_pad, pads = pad_to(x, 16) # Padded data, feed this to your network 
x_unpad = unpad(x_pad, pads) # Un-pad the network output to recover the original shape

print('Original: ', x.shape)
print('Padded: ', x_pad.shape)
print('Recovered: ', x_unpad.shape)

输出:

Original:  torch.Size([4, 3, 1080, 1920])
Padded:  torch.Size([4, 3, 1088, 1920])
Recovered:  torch.Size([4, 3, 1080, 1920])

参考:https://github.com/seoungwugoh/STM/blob/905f11492a6692dd0d0fa395881a8ec09b211a36/helpers.py#L33

相关问题 更多 >