“Subset”与“SubsetRandomSampler”的区别

2024-10-02 12:31:45 发布

您现在位置:Python中文网/ 问答频道 /正文

最近,我尝试使用SubsetSubsetRandomSampler方法来解决K-折叠交叉验证问题

当我使用Subset方法时,CIFAR10数据集的第一个历元精度是88%,然而,当我使用SubsetRandomSampler方法时,CIFAR10数据集的第一个历元精度只有16%。这真的让我困惑,我也不知道。有人知道吗?非常感谢

方法的代码:

 for fold,(trainLoader,valLoader) in enumerate(kf.split(trainSet)):

      trainSetBasic = torch.utils.data.Subset(trainSet, trainLoader)
      valSetBasic = torch.utils.data.Subset(trainSet, valLoader)

      dataloaders = {
  'train': DataLoader(trainSetBasic, batch_size=BATCH_SIZE, shuffle=True, num_workers=2),
  'val': DataLoader(valSetBasic, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
  }

方法的代码:

for fold,(trainLoader,valLoader) in enumerate(kf.split(trainSet)):

      train_subsampler = torch.utils.data.SubsetRandomSampler(trainLoader)
      val_subsampler = torch.utils.data.SubsetRandomSampler(valLoader)

      dataloaders = {
  'train': DataLoader(trainSet, batch_size=BATCH_SIZE, sampler=train_subsampler, shuffle=False, num_workers=2),
  'val': DataLoader(trainSet, batch_size=BATCH_SIZE, sampler=val_subsampler, shuffle=False, num_workers=2)
  }

其他部分的代码都是一样的


Tags: 方法datasizebatchtrainutilstorchsubset
1条回答
网友
1楼 · 发布于 2024-10-02 12:31:45

在我看来,这两种代码之间的区别在于手动洗牌索引

数据集是如何定义的?类是否在内部排序?如果是,则使用SubsetRandomSampler只将第一个类传递给train_subsampler,将最后一个类传递给val_subsampler。这就解释了错误的准确性

相关问题 更多 >

    热门问题