用于短时学习的Tensorflow 2数据集

2024-09-26 22:52:00 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试实现一个Tensorflow数据集来解决一些快照学习问题。我想要实现的是:

  1. Sample k images each from n classes.
  2. Give each of the k images a label for the class (only unique for this batch so they can just be 1,2,...,n).
  3. Shuffle the images/labels in this batch.

Tensorflow版本为:2.0.0。目前,我使用的omniglot数据集具有一个文件结构:.\alphabet-dir\letter-dir\*.png,其中每个字母(所有字母表)都被视为一个类

接近

因为我在网上找不到任何指南,所以我只是尝试使用tensorflow的数据API已经提供的功能

我正在做的是:

  1. 初始化:

    (1)加载类的所有文件夹路径

    (2)然后为每个文件夹创建一个数据集(预处理将路径指定为标签)

    (3)将路径用作键,并将所有数据集放入字典中:

    self.class_names = Path(img_dir).glob("*/*")
    ...
    
    # create dictionary with one dataset for each class
    datasets = [(class_name, tf.data.Dataset.list_files(str(Path(class_name)/'*'))
                 .map(self.preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)
                 .prefetch(50))
                for class_name
                in self.class_names]
    
    self.datasets = dict(datasets)  
    
  2. 批处理:

    (1)字典中的样本n

    (2)为每个数据集获取k样本,并为类映射一个标签(0到n之间的数字作为一个热向量)

    (3)连接所有数据集并洗牌连接的集:

    selected_classes = random.sample(self.class_names, self.num_classes)
    
    datasets = [self.datasets.get(classname)
                    .take(self.k)
                    .map(self.map_with_one_hot_labels_for_class(selected_classes, classname))
                for classname
                in selected_classes]
    
    concatenated = datasets[0]
    for i in range(1, self.num_classes):
        concatenated = concatenated.concatenate(datasets[i])
    
    return concatenated.shuffle(self.num_classes*self.k)
    

问题

首先,我不知道这是否是一个好方法。如果有人能提出更好的解决方案,我很乐意接手

我遇到的问题是,这会消耗越来越多的内存。我试图排除prefetch(50),但这并没有改变任何事情。我怀疑连接的数据集从未被处理过

一旦不再使用数据集(显然我需要保留代表整个类的底层数据集…),是否有办法删除该数据集?或者我是否使用了错误的数据集

谢谢


Tags: the数据inself路径fornamesdir

热门问题