Pytorch数据加载程序,用于读取大型拼花地板/csv文件

2024-10-06 12:27:02 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图让Pytorch训练单个拼花地板文件的记录,而不必一次在内存中读取整个文件,因为它不适合内存。由于该文件是远程存储的,我宁愿将其作为单个文件保存,因为对许多文件使用IO进行培训非常昂贵。当我想指定DataLoader中的批数时,如何在培训期间使用Pytorch的IterableDatasetDataset读取文件中较小的块?我知道映射样式Dataset在这种情况下不起作用,因为我需要一个文件中的所有内容,而不是读取每个文件的索引

我设法用tfio.IODatasettf.data.Dataset在Tensorflow中实现了这一点,但我找不到在Pytorch中实现它的等效方法


Tags: 文件内存io内容远程记录情况样式
1条回答
网友
1楼 · 发布于 2024-10-06 12:27:02

我发现了一个使用torch.utils.data.Dataset的解决方法,但是数据必须事先使用dask进行处理,这样每个分区都是一个用户,存储为自己的拼花文件,但以后只能读取一次。在下面的代码中,标签和数据分别存储用于多变量时间序列分类问题(但也可以很容易地适应其他任务):

import dask.dataframe as dd
import pandas as pd
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader, IterableDataset, Dataset

# Breakdown file
raw_ddf = dd.read_parquet("data.parquet") # Read huge file using dask
raw_ddf = raw_ddf.set_index("userid") # set the userid as index
userids = raw_ddf.index.unique().compute().values.tolist() # get a list of indices
new_ddf = raw_ddf.repartition(divisions = userids) # repartition by userids
new_ddf.to_parquet("my_folder") # this will save each user as its own parquet file within "my_folder"

# Dask to read the partitions
train_ddf = dd.read_parquet("my_folder/*.parquet") # read all files

# Read labels file
labels_df = pd.read("label.csv")
y_labels = np.array(labels_df["class"])

# Define the Dataset class
class UsersDataset(Dataset):
    def __init__(self, dask_df, labels):
        self.dask_df = dask_df
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx): 
        X_df = self.dask_df.get_partition(idx).compute()
        X = np.row_stack([X_df])
        X_tensor = torch.tensor(X, dtype=torch.float32)
        y = self.labels[idx]
        y_tensor = torch.tensor(y, dtype=torch.long)
        sample = (X_tensor, y_tensor) 
        return sample

# Create a Dataset object
user_dataset = UsersDataset(dask_df=ddf_train, labels = y_train) 

# Create a DataLoader object
dataloader = DataLoader(user_dataset, batch_size=4, shuffle=True, num_workers=0)

# Print output of the first batch to ensure it works
for i_batch, sample_batched in enumerate(dataloader): 
    print("Batch number ", i_batch)
    print(sample_batched[0]) # print X
    print(sample_batched[1]) # print y

    # stop after first batch.
    if i_batch == 0:
        break

我想知道在使用>;=2名工人读取数据,无重复条目。非常感谢您对这方面的任何见解

相关问题 更多 >