我使用的是一个非torchvision数据集,我用ImageFolder方法提取了它。我试图将数据集分成20%的验证集和80%的训练集。我只能从PyTorch库中找到这个方法(random_split),它允许分割数据集。然而,这每次都是随机的。我想知道有没有一种方法可以在PyTorch库中以特定数量分割数据集
这是我提取数据集并随机拆分它的代码
transformations = transforms.Compose([
transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
TrafficSignSet = datasets.ImageFolder(root='./train/', transform=transformations)
####### split data
train_size = int(0.8 * len(TrafficSignSet))
test_size = len(TrafficSignSet) - train_size
train_dataset_split, test_dataset_split = torch.utils.data.random_split(TrafficSignSet, [train_size, test_size])
#######put into a Dataloader
train_dataset = torch.utils.data.DataLoader(train_dataset_split, batch_size=32, shuffle=True)
test_dataset = torch.utils.data.DataLoader(test_dataset_split, batch_size=32, shuffle=True)
如果您查看^{} 的“引擎盖下”,您将看到它使用^{} 进行实际拆分。您可以使用固定索引自己执行此操作:
相关问题 更多 >
编程相关推荐