如何在keras自定义回调中访问tf.data.Dataset?

2024-06-28 20:18:01 发布

您现在位置:Python中文网/ 问答频道 /正文

我编写了一个自定义keras回调来检查来自生成器的增强数据。(请参阅this answer了解完整代码。)但是,当我尝试对tf.data.Dataset使用相同的回调时,它给了我一个错误:

  File "/path/to/tensorflow_image_callback.py", line 16, in on_batch_end
imgs = self.train[batch][images_or_labels]
TypeError: 'PrefetchDataset' object is not subscriptable

keras回调通常只与生成器一起工作,还是与我编写回调的方式有关?有没有办法修改回调或数据集以使其工作

我认为这个难题有三个部分。我对任何和所有的改变都持开放态度。首先,自定义回调类中的init函数:

class TensorBoardImage(tf.keras.callbacks.Callback):
    def __init__(self, logdir, train, validation=None):
        super(TensorBoardImage, self).__init__()
        self.logdir = logdir
        self.file_writer = tf.summary.create_file_writer(logdir)
        self.train = train
        self.validation = validation

第二,同一类中的on_batch_end函数

def on_batch_end(self, batch, logs):
    images_or_labels = 0 #0=images, 1=labels
    imgs = self.train[batch][images_or_labels]

第三,实例化回调

import tensorflow_image_callback
tensorboard_image_callback = tensorflow_image_callback.TensorBoardImage(logdir=tensorboard_log_dir, train=train_dataset, validation=valid_dataset)
model.fit(train_dataset,
          epochs=n_epochs,
          validation_data=valid_dataset, 
          callbacks=[
                    tensorboard_callback,
                    tensorboard_image_callback
                    ])

一些相关的线索还没有给我答案:

Accessing validation data within a custom callback

Create keras callback to save model predictions and targets for each batch during training


Tags: imageselfdatalabelstftensorflowbatchcallback
1条回答
网友
1楼 · 发布于 2024-06-28 20:18:01

使用^{}对我起作用的是以下内容:

__init__函数:

def __init__(self, logdir, train, validation=None):
    super(TensorBoardImage, self).__init__()
    self.logdir = logdir
    self.file_writer = tf.summary.create_file_writer(logdir)
    # #from keras generator
    # self.train = train
    # self.validation = validation
    #from tf.Data
    my_data = tfds.as_numpy(train)
    imgs = my_data['image']

然后on_batch_end

def on_batch_end(self, batch, logs):
    images_or_labels = 0 #0=images, 1=labels
    imgs = self.train[batch][images_or_labels]

    #calculate epoch
    n_batches_per_epoch = self.train.samples / self.train.batch_size
    epoch = math.floor(self.train.total_batches_seen / n_batches_per_epoch)

    #since the training data is shuffled each epoch, we need to use the index_array to find something which uniquely 
    #identifies the image and is constant throughout training
    first_index_in_batch = batch * self.train.batch_size
    last_index_in_batch = first_index_in_batch + self.train.batch_size
    last_index_in_batch = min(last_index_in_batch, len(self.train.index_array))
    img_indices = self.train.index_array[first_index_in_batch : last_index_in_batch]

    with self.file_writer.as_default():
        for ix,img in enumerate(imgs):
            #only post 1 out of every 1000 images to tensorboard
            if (img_indices[ix] % 1000) == 0:
                #instead of img_filename, I could just use str(img_indices[ix]) as a unique identifier
                #but this way makes it easier to find the unaugmented image
                img_filename = self.train.filenames[img_indices[ix]]

                #convert float to uint8, shift range to 0-255
                img -= tf.reduce_min(img)
                img *= 255 / tf.reduce_max(img)
                img = tf.cast(img, tf.uint8)
                img_tensor = tf.expand_dims(img, 0) #tf.summary needs a 4D tensor
                
                tf.summary.image(img_filename, img_tensor, step=epoch)

我不需要对实例化进行任何更改

我建议只在调试时使用它,否则它会将数据集中的每个第n个图像保存到tensorboard的每个历元中。最终可能会占用大量磁盘空间

相关问题 更多 >