如何使用pytorch将数据集拆分为自定义训练集和自定义验证集?

2024-10-01 19:22:28 发布

您现在位置:Python中文网/ 问答频道 /正文

我使用的是一个非torchvision数据集,我用ImageFolder方法提取了它。我试图将数据集分成20%的验证集和80%的训练集。我只能从PyTorch库中找到这个方法(random_split),它允许分割数据集。然而,这每次都是随机的。我想知道有没有一种方法可以在PyTorch库中以特定数量分割数据集

这是我提取数据集并随机拆分它的代码

transformations = transforms.Compose([
    transforms.Resize(255),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

TrafficSignSet = datasets.ImageFolder(root='./train/', transform=transformations)

####### split data
train_size = int(0.8 * len(TrafficSignSet))
test_size = len(TrafficSignSet) - train_size
train_dataset_split, test_dataset_split = torch.utils.data.random_split(TrafficSignSet, [train_size, test_size])

#######put into a Dataloader
train_dataset = torch.utils.data.DataLoader(train_dataset_split, batch_size=32, shuffle=True)
test_dataset = torch.utils.data.DataLoader(test_dataset_split, batch_size=32, shuffle=True)

Tags: 数据方法testdatasizetrainutilsrandom
1条回答
网友
1楼 · 发布于 2024-10-01 19:22:28

如果您查看^{}的“引擎盖下”,您将看到它使用^{}进行实际拆分。您可以使用固定索引自己执行此操作:

import random

indices = list(range(len(TrafficSignSet))
random.seed(310)  # fix the seed so the shuffle will be the same everytime
random.shuffle(indices)
train_dataset_split = torch.utils.data.Subset(TrafficSignSet, indices[:train_size])
val_dataset_split = torch.utils.data.Subset(TrafficSignSet, indices[train_size:])

相关问题 更多 >

    热门问题