在PyTorch的DataLoader中,getitem的idx是如何工作的?

2024-09-29 22:21:56 发布

您现在位置:Python中文网/ 问答频道 /正文

我目前正在尝试使用PyTorch的DataLoader来处理数据,并将其输入到我的深度学习模型中,但是遇到了一些困难。在

我需要的数据的形状是(minibatch_size=32, rows=100, columns=41)。在我编写的自定义dataset类中的__getitem__coad如下所示:

def __getitem__(self, idx):
    x = np.array(self.train.iloc[idx:100, :])
    return x

我这样写的原因是因为我希望数据加载器一次处理形状(100, 41)的输入实例,我们有32个这样的单个实例。在

但是,我注意到,与我最初的想法相反,DataLoader传递给函数的idx参数不是连续的(这一点很重要,因为我的数据是时间序列数据)。例如,打印值给了我这样的结果:

^{pr2}$

这是正常行为吗?我发布这个问题是因为返回的数据批不是我最初想要的。在

在使用DataLoader之前,我用于预处理数据的原始逻辑是:

  1. txtcsv文件中读取数据。在
  2. 计算数据中有多少批,并相应地对数据进行切片。例如,由于一个输入实例的形状是(100, 41),其中32个是一个小批量,因此我们通常会得到大约100个批次,并相应地对数据进行整形。在
  3. 一个输入的形状是(32, 100, 41)。在

我不确定我应该如何处理DataLoader钩子方法。如有任何提示或建议,我们将不胜感激。提前谢谢。在


Tags: columns数据实例模型selfsizepytorchdataset
1条回答
网友
1楼 · 发布于 2024-09-29 22:21:56

定义idx的是sampler或{},如您所见{a1}(开源项目是您的朋友)。在这个code(和comment/docstring)中,您可以看到sampler和{}之间的区别。如果您查看here,您将看到如何选择索引:

def __next__(self):
    index = self._next_index()

# and _next_index is implemented on the base class (_BaseDataLoaderIter)
def _next_index(self):
    return next(self._sampler_iter)

# self._sampler_iter is defined in the __init__ like this:
self._sampler_iter = iter(self._index_sampler)

# and self._index_sampler is a property implemented like this (modified to one-liner for simplicity):
self._index_sampler = self.batch_sampler if self._auto_collation else self.sampler

请注意这是_SingleProcessDataLoaderIter实现;您可以找到_MultiProcessingDataLoaderIterhere(ofc,使用哪个取决于num_workers值,如您所见here)。回到采样器,假设您的数据集不是_DatasetKind.Iterable,并且您没有提供自定义采样器,这意味着您正在使用(dataloader.py#L212-L215):

^{pr2}$

让我们看看how the default BatchSampler builds a batch

def __iter__(self):
    batch = []
    for idx in self.sampler:
        batch.append(idx)
        if len(batch) == self.batch_size:
            yield batch
            batch = []
    if len(batch) > 0 and not self.drop_last:
        yield batch

非常简单:它从取样器获取索引,直到达到所需的批次大小。在

现在的问题是“在PyTorch的DataLoader中,getitem的idx是如何工作的?”可以通过查看每个默认采样器的工作方式来回答。在

  • SequentialSampler(这是完整的实现非常简单,不是吗?)公司名称:
class SequentialSampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)
def __iter__(self):
    n = len(self.data_source)
    if self.replacement:
        return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
    return iter(torch.randperm(n).tolist())

因此,由于您没有提供任何代码,我们只能假设:

  1. 您正在数据加载器中使用shuffle=True
  2. 您正在使用自定义采样器
  3. 您的数据集是_DatasetKind.Iterable

相关问题 更多 >

    热门问题