如何使用火炬视觉.transforms用于Pythorch中分段任务的数据扩充?

2024-09-26 17:50:40 发布

您现在位置:Python中文网/ 问答频道 /正文

我对Pythorch中执行的数据扩充有点困惑。在

因为我们处理的是分割任务,所以我们需要数据和掩码来进行相同的数据扩充,但是有些数据是随机的,比如随机旋转。在

Keras提供random seed保证数据和掩码执行相同的操作,如下面的代码所示:

    data_gen_args = dict(featurewise_center=True,
                         featurewise_std_normalization=True,
                         rotation_range=25,
                         horizontal_flip=True,
                         vertical_flip=True)


    image_datagen = ImageDataGenerator(**data_gen_args)
    mask_datagen = ImageDataGenerator(**data_gen_args)

    seed = 1
    image_generator = image_datagen.flow(train_data, seed=seed, batch_size=1)
    mask_generator = mask_datagen.flow(train_label, seed=seed, batch_size=1)

    train_generator = zip(image_generator, mask_generator)

我在Pythorch官方文档中没有找到类似的描述,所以我不知道如何确保数据和掩码可以同步处理。在

Pythorch确实提供了这样一个函数,但是我想将它应用到一个定制的数据加载器。在

例如:

^{pr2}$

在这种情况下,img和mask将分别进行变换,因为随机旋转等操作是随机的,因此掩模和图像之间的对应关系可能会改变。换句话说,图像可能已经旋转,但遮罩没有这样做。在

编辑1

我在augmentations.py中使用了该方法,但出现了一个错误:

Traceback (most recent call last):
  File "test_transform.py", line 87, in <module>
    for batch_idx, image, mask in enumerate(train_loader):
  File "/home/dirk/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 314, in __next__
    batch = self.collate_fn([self.dataset[i] for i in indices])
  File "/home/dirk/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 314, in <listcomp>
    batch = self.collate_fn([self.dataset[i] for i in indices])
  File "/home/dirk/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataset.py", line 103, in __getitem__
    return self.dataset[self.indices[idx]]
  File "/home/dirk/home/data/dirk/segmentation_unet_pytorch/data.py", line 164, in __getitem__
    img, mask = self.transforms(img, mask)
  File "/home/dirk/home/data/dirk/segmentation_unet_pytorch/augmentations.py", line 17, in __call__
    img, mask = a(img, mask)
TypeError: __call__() takes 2 positional arguments but 3 were given

这是我的__getitem__()的代码:

data_transforms = {
    'train': Compose([
        RandomHorizontallyFlip(),
        RandomRotate(degree=25),
        transforms.ToTensor()
    ]),
}

train_set = DatasetUnetForTestTransform(fold=args.fold, random_index=args.random_index,transforms=data_transforms['train'])

# __getitem__ in class DatasetUnetForTestTransform
def __getitem__(self, index):
    img = np.zeros((self.im_ht, self.im_wd, channel_size))
    mask = np.zeros((self.im_ht, self.im_wd, channel_size))
    temp_img = np.load(Label_path + '{:0>4}'.format(self.patient_index[index]) + '.npy')
    temp_label = np.load(Label_path + '{:0>4}'.format(self.patient_index[index]) + '.npy')
    temp_img, temp_label = crop_data_label_from_0(temp_img, temp_label)
    for i in range(channel_size):
        img[:,:,i] = temp_img[self.count[index] + i]
        mask[:,:,i] = temp_label[self.count[index] + i]

    if self.transforms:
        img = T.ToPILImage()(np.uint8(img))
        mask = T.ToPILImage()(np.uint8(mask))
        img, mask = self.transforms(img, mask)

    img = T.ToTensor()(img).copy()
    mask = T.ToTensor()(mask).copy()
    return img, mask

编辑2

我发现在ToTensor之后,相同标签之间的骰子变成了255而不是1,如何修复它?在

# Dice computation
def DSC_computation(label, pred):
    pred_sum = pred.sum()
    label_sum = label.sum()
    inter_sum = np.logical_and(pred, label).sum()
    return 2 * float(inter_sum) / (pred_sum + label_sum)

如果需要更多的代码来解释这个问题,请随时询问。在


Tags: inpyselfhomeimgdataindexnp
2条回答

torchvision还提供类似的函数[document]。在

这里有个简单的例子

import torchvision
from torchvision import transforms

trans = transforms.Compose([transforms.CenterCrop((178, 178)),
                                    transforms.Resize(128),
                                    transforms.RandomRotation(20),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
dset = torchvision.datasets.MNIST(data_root, transforms=trans)

编辑

自定义您自己的CelebA数据集时的一个简单示例。注意,要应用转换,需要调用transform中的transform列表。在

^{pr2}$

编辑2

我可能第一眼就漏掉了一些东西。问题的重点是如何对img和标签应用“相同”的数据预处理。据我所知,Pythorch没有内置函数。所以,我以前做的就是自己去实现扩充。在

^{3}$

注意输入应该是PIL格式。有关详细信息,请参见this。在

需要输入参数(如RandomCrop)的转换有一个get_param方法,该方法将返回该特定转换的参数。然后,可以使用变换的功能接口将其应用于图像和遮罩:

from torchvision import transforms
import torchvision.transforms.functional as F

i, j, h, w = transforms.RandomCrop.get_params(input, (100, 100))
input = F.crop(input, i, j, h, w)
target = F.crop(target, i, j, h, w)

此处提供样品: https://github.com/pytorch/vision/releases/tag/v0.2.0

此处提供VOC和COCO的完整示例: https://github.com/pytorch/vision/blob/master/references/segmentation/transforms.pyhttps://github.com/pytorch/vision/blob/master/references/segmentation/train.py

关于错误

未重写ToTensor()以处理其他掩码参数,因此它不能位于data_transforms中。此外,__getitem__在返回它们之前,img和{}都是{}。在

^{pr2}$

相关问题 更多 >

    热门问题