在Pytorch中加载数据,用于在同一文件夹中包含所有类的数据集

2024-09-24 02:23:57 发布

您现在位置:Python中文网/ 问答频道 /正文

我是新的深入学习和学习。我有6000个图像的数据集,在一个文件夹中有所有四个类。我使用以下代码片段上传数据

torchvision.datasets.ImageFolder(root='/content/drive/My Drive/DFU/base_dir/train_dir', transform=None)

我了解到,对于ImageFolder,应该根据类标签将图像组织到子文件夹中。但是,我的数据集在一个文件夹中包含所有四个类图像。我有一个.csv文件,其中包含每个图像的一个热编码类标签。如何将我的数据集加载到Pytorch? .CSV FILE


Tags: 数据代码图像文件夹mydirroot标签
1条回答
网友
1楼 · 发布于 2024-09-24 02:23:57

最简单的解决方案是根据csv文件将图像重新组织到类子文件夹中,并按照ImageFolder的预期加载:

import pandas as pd
from pathlib import Path

root = '/content/drive/My Drive/DFU/base_dir/train_dir'
my_csv_file = ...

# Loading csv as {image:class,...} format
df = pd.read_csv(my_csv_file).set_index('images')
class_dict = df.idxmax(axis="columns").to_dict()

# Moving files to class-named subfolders
for path in Path(root).iterdir():
    if path.is_file() and path.name in class_dict.keys():
        path.rename(Path(path.parent, class_dict[path.name], path.name)

# Loading dataset
dataset = torchvision.datasets.ImageFolder(root=root, transform=None)

相关问题 更多 >