Pytorch的数据加载程序什么时候洗牌?

2024-09-30 18:26:21 发布

您现在位置:Python中文网/ 问答频道 /正文

我已经多次使用pytorch dataloader的shuffle选项。但我想知道这种洗牌是什么时候发生的,以及它是否在迭代过程中动态执行。以以下代码为例:

namesDataset = NamesDataset()
namesTrainLoader = DataLoader(namesDataset, batch_size=16, shuffle=True)
for batch_data in namesTrainLoader:
    print(batch_data)

当我们定义“namesTrainLoader”时,这是否意味着洗牌已经完成,接下来的迭代将基于固定的数据顺序?在定义namesTrainLoader之后,for循环中是否存在任何随机性

我试图用一些特殊值替换“批次数据”的一半:

for batch_data in namesTrainLoader:
    batch_data[:8] = special_val
    pre = model(batch_data)

让我们假设将有无限多的epoche,“model”最终会看到“namesTrainLoader”中的所有数据吗?或者“namesTrainLoader”的一半数据实际上丢失到了“model”中


Tags: 数据infordatamodel定义过程选项
2条回答

您可以检查PyTorch的torch.utils.data.DataLoaderhere实现

如果指定shuffle=True,将使用^{}SequentialSampler否则)

当创建DataLoader的实例时,不会对任何内容进行洗牌,它只是实例化对象和其他类似设置的必要私有成员

当您在迭代过程中发出特殊的__iter__方法时,会返回一个名为_SingleProcessDataLoader(self)的特殊对象,它是一个数据生成器(可能是批处理、洗牌等,假设您不使用多处理)

要找到所有私有方法和帮助器相关的方法,有点像兔子洞,但它基本上是使用底层的sampler来获取用于从torch.utils.data.Dataset获取样本的索引

取样器一直运行到耗尽,过程重复(通常是一个历元)

Will there be any randomness in the for loop after namesTrainLoader was defined?

在每个周期开始时/epochRandomSampler洗牌索引,因此是的,它将在每个epoch之前随机化(当调用__iter__并返回新的_SingleProcessDataLoader(self)),这可以无限期地进行

[...] will "model" eventually see all the data in "namesTrainLoader"?

是的,它很可能最终会看到所有的数据点

当迭代器被创建时,会发生洗牌。在for循环的情况下,这发生在for循环开始之前

您可以使用以下工具手动创建迭代器:

# Iterator gets created, the data has been shuffled at this point.
data_iterator = iter(namesTrainLoader)

默认情况下,如果设置shuffle=True(不提供自己的采样器),则数据加载器使用^{}。它的实现非常简单,通过查看^{}方法,您可以看到在创建迭代器时数据被洗牌的位置:

def __iter__(self):
    n = len(self.data_source)
    if self.replacement:
        return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
    return iter(torch.randperm(n).tolist())

return语句是进行洗牌的重要部分。它只是创建索引的随机排列

这意味着您将在每次完全使用迭代器时看到整个数据集,只是每次的顺序不同。因此没有数据丢失(不包括drop_last=True的情况),您的模型将在每个历元看到所有数据

相关问题 更多 >