生成和使用多对象数据集的工具
multiobject的Python项目详细描述
多对象数据集
生成和使用多对象数据集的工具。 数据集由图像和标签字典组成,其中每个图像位于 标记为1)其中的对象数和2)每个对象的属性。在
使用数据集只需要numpy
,因为数据集是.npz
。
生成精灵需要scikit-image
。使用的工具
提供了PyTorch中的数据集,并提供了使用示例。在
基本用法(pip包)
- 或者下载
generated/
中的datasets, 或generate a new one。在 - 将
.npz
数据集放入/path/to/data/
。在 pip install multiobject
- Pythorch中的用法:
frommultiobject.pytorchimportMultiObjectDataLoader,MultiObjectDatasetdataset_path='/path/to/data/some_dataset.npz'train_set=MultiObjectDataset(dataset_path,train=True)test_set=MultiObjectDataset(dataset_path,train=False)train_loader=MultiObjectDataLoader(train_set,batch_size=batch_size,shuffle=True)test_loader=MultiObjectDataLoader(test_set,batch_size=test_batch_size)
在
运行演示
^{pr2}$可用数据集
数据集在./generated/
中以.npz
文件的形式提供。在
dSprites1
黑色画布上带有单色dSprites的64x64 RGB二进制图像。 精灵是18x18和7种不同的颜色,它们可以重叠(求和和和剪切)。在
- 100k图像,每张图像1个精灵[10.6MB]
- 10万张图像,每张图像有1个精灵,较大的精灵(最大28x28)[12.4 MB]
- 每幅图像包含0、1或2个(统一)精灵的100k图像[11 MB]
二进制MNIST
一个单通道644位数字的黑色画布。 数字被重新缩放到18x18并进行二值化,它们可以重叠(求和和和剪裁)。 仅使用MNIST训练集中的数字(60k)。在
- 100k张图像,每张图像1位数[4.5 MB]
- 100k张图像,每张图像有0、1或2位(统一)位数[4.8 MB]
生成新数据集
- 在
克隆此回购。在
在 - 在
请参见requirements,或设置虚拟环境,如下所示:
在conda create --name multiobject python=3.7 conda activate multiobject pip install -r requirements.txt
- 在
可选:生成新类型的精灵:
- 创建一个包含函数的文件
sprites/xyz.py
generate_xyz()
,其中“xyz”表示新的sprite类型 - 在
generate_dataset.py
中,添加对generate_xyz()
的调用以生成 更正精灵,并将'xyz'
添加到支持的精灵列表中
- 创建一个包含函数的文件
- 在
使用所需的sprite类型作为
--type
参数调用generate_dataset.py
。 示例:
在python generate_dataset.py --type dsprites
从集合生成数据集时,精灵属性将自动管理 具有每个精灵标签的精灵。但是,由于它们是特定于数据集的, 它们必须在创建精灵时定义。在
Note.目前,必须直接在generate_dataset.py
中自定义以下参数:
- 目标数的概率分布
- 图像大小
- 精灵大小
- 数据集大小
- 精灵是否可以重叠
要求
要生成数据集:
numpy==1.18.1
matplotlib==3.1.2
scikit_image==0.16.2
tqdm==4.41.1
pillow==7.0.0
要运行示例或使用pytorch工具:
torch==1.4.0
torchvision==0.5.0
脚注
这实际上是原始dSprites的扩展 数据集到许多对象和彩色图像。↩
- 项目
标签: