我正在尝试实现一个Tensorflow数据集来解决一些快照学习问题。我想要实现的是:
- Sample k images each from n classes.
- Give each of the k images a label for the class (only unique for this batch so they can just be 1,2,...,n).
- Shuffle the images/labels in this batch.
Tensorflow版本为:2.0.0。目前,我使用的omniglot数据集具有一个文件结构:.\alphabet-dir\letter-dir\*.png
,其中每个字母(所有字母表)都被视为一个类
因为我在网上找不到任何指南,所以我只是尝试使用tensorflow的数据API已经提供的功能
我正在做的是:
初始化:
(1)加载类的所有文件夹路径
(2)然后为每个文件夹创建一个数据集(预处理将路径指定为标签)
(3)将路径用作键,并将所有数据集放入字典中:
self.class_names = Path(img_dir).glob("*/*")
...
# create dictionary with one dataset for each class
datasets = [(class_name, tf.data.Dataset.list_files(str(Path(class_name)/'*'))
.map(self.preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)
.prefetch(50))
for class_name
in self.class_names]
self.datasets = dict(datasets)
批处理:
(1)字典中的样本n类
(2)为每个数据集获取k样本,并为类映射一个标签(0到n之间的数字作为一个热向量)
(3)连接所有数据集并洗牌连接的集:
selected_classes = random.sample(self.class_names, self.num_classes)
datasets = [self.datasets.get(classname)
.take(self.k)
.map(self.map_with_one_hot_labels_for_class(selected_classes, classname))
for classname
in selected_classes]
concatenated = datasets[0]
for i in range(1, self.num_classes):
concatenated = concatenated.concatenate(datasets[i])
return concatenated.shuffle(self.num_classes*self.k)
首先,我不知道这是否是一个好方法。如果有人能提出更好的解决方案,我很乐意接手
我遇到的问题是,这会消耗越来越多的内存。我试图排除prefetch(50)
,但这并没有改变任何事情。我怀疑连接的数据集从未被处理过
一旦不再使用数据集(显然我需要保留代表整个类的底层数据集…),是否有办法删除该数据集?或者我是否使用了错误的数据集
谢谢
目前没有回答
相关问题 更多 >
编程相关推荐