我想从3个不同的文件夹中获得3批图像。我已经用pytorch编写了自定义数据加载器。但它返回的列表包含所有批次,而不是一次包含一个批次
#custom data loader
class set(Dataset):
def __init__(self, dataset_input, dataset_expertA, dataset_expertB):
self.dataset1 = dataset_input
self.dataset2 = dataset_expertA
self.dataset3 = dataset_expertB
def __getitem__(self, index):
x1 = self.dataset1[index]
x2 = self.dataset2[index]
x3 = self.dataset3[index]
return x1, x2, x3
def __len__(self):
return len(self.dataset1)
input_path = "/content/gdrive/My Drive/project/input/"
dataset = datasets.ImageFolder(root= input_path, transform=transforms.Compose([
transforms.Resize([64,64]),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]))
expertA_path = "/content/gdrive/My Drive/project/expertA/"
datasetA = datasets.ImageFolder(root= expertA_path, transform=transforms.Compose([
transforms.Resize([64,64]),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]))
expertB_path = "/content/gdrive/My Drive/project/expertB/"
datasetB = datasets.ImageFolder(root= expertB_path, transform=transforms.Compose([
transforms.Resize([64,64]),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]))
data = set(dataset, datasetA, datasetB)
dataloader = torch.utils.data.DataLoader(data, batch_size=64,
shuffle=True, num_workers=2)
for i, (inp, expA, expB) in enumerate(dataloader):
print(inp.shape)
break
这会打印错误,即inp是列表,当我打印(inp[0].shape)时,我得到了正确的形状,我认为inp包含所有批次,即inp[0],inp[1]
我在数据加载程序代码中犯了什么错误
^{} 返回(图像,标签)的元组,因此
inp
也是元组,其中inp[0]
是图像,而inp[1]
是它们对应的标签。这同样适用于expA
和expB
如果您只想要没有标签的图像,可以忽略标签,在访问自定义数据集中的数据时只返回图像:
相关问题 更多 >
编程相关推荐