从多个不同大小的数据集加载Pythorch数据

2024-10-01 11:39:06 发布

您现在位置:Python中文网/ 问答频道 /正文

我有多个数据集,每个都有不同数量的图像(和不同的图像维度)。在训练循环中,我希望从所有的数据集中随机加载一批图像,但这样每个批只包含来自单个数据集的图像。例如,我有数据集A、B、C、D,每个都有images 01.jpg、02.jpg、…n.jpg(其中n取决于数据集),假设批大小为3。例如,在第一个加载的批处理中,我可能会在下一个批处理[D/01.jpg,D/05.jpg,D/12.jpg]中获取图像[B/02.jpg,B/06.jpg,B/12.jpg],等等

到目前为止,我考虑了以下几点:

  1. 为每个数据集使用不同的数据加载器,例如dataloaderA、dataloaderB等,然后在每个训练循环中随机选择一个DataLoader并从中获取一个批。然而,这将需要一个for循环,对于大量的数据集,它将非常慢,因为它不能在工作人员之间进行并行操作。在
  2. 使用单个数据加载器将所有数据集中的所有图像放在一起,但使用自定义的collate_fn,它将仅使用来自同一数据集的图像创建批处理。(我不知道该怎么做。)
  3. 我已经看过Concatdataset类,但是从它的源代码来看,如果我使用它并尝试获取一个新的批处理,其中的图像将从我不想要的不同数据集中混合。在

最好的办法是什么?谢谢!在


Tags: 数据图像for数量源代码fnjpgimages
1条回答
网友
1楼 · 发布于 2024-10-01 11:39:06

您可以使用^{},并向^{}提供一个batch_sampler。在

concat_dataset = ConcatDataset((dataset1, dataset2))

ConcatDataset.comulative_sizes将为您提供每个数据集之间的边界:

^{pr2}$

现在,您可以使用ds_indices来创建批处理采样器。请参阅the source for ^{}以获取参考。您的批处理采样器只需返回一个包含N个随机索引的列表,该列表将遵循ds_indices边界。这将保证批处理将包含来自同一数据集的元素。在

相关问题 更多 >