我编写了一个自定义keras回调来检查来自生成器的增强数据。(请参阅this answer了解完整代码。)但是,当我尝试对tf.data.Dataset
使用相同的回调时,它给了我一个错误:
File "/path/to/tensorflow_image_callback.py", line 16, in on_batch_end
imgs = self.train[batch][images_or_labels]
TypeError: 'PrefetchDataset' object is not subscriptable
keras回调通常只与生成器一起工作,还是与我编写回调的方式有关?有没有办法修改回调或数据集以使其工作
我认为这个难题有三个部分。我对任何和所有的改变都持开放态度。首先,自定义回调类中的init函数:
class TensorBoardImage(tf.keras.callbacks.Callback):
def __init__(self, logdir, train, validation=None):
super(TensorBoardImage, self).__init__()
self.logdir = logdir
self.file_writer = tf.summary.create_file_writer(logdir)
self.train = train
self.validation = validation
第二,同一类中的on_batch_end
函数
def on_batch_end(self, batch, logs):
images_or_labels = 0 #0=images, 1=labels
imgs = self.train[batch][images_or_labels]
第三,实例化回调
import tensorflow_image_callback
tensorboard_image_callback = tensorflow_image_callback.TensorBoardImage(logdir=tensorboard_log_dir, train=train_dataset, validation=valid_dataset)
model.fit(train_dataset,
epochs=n_epochs,
validation_data=valid_dataset,
callbacks=[
tensorboard_callback,
tensorboard_image_callback
])
一些相关的线索还没有给我答案:
Accessing validation data within a custom callback
Create keras callback to save model predictions and targets for each batch during training
使用^{} 对我起作用的是以下内容:
__init__
函数:然后
on_batch_end
:我不需要对实例化进行任何更改
我建议只在调试时使用它,否则它会将数据集中的每个第n个图像保存到tensorboard的每个历元中。最终可能会占用大量磁盘空间
相关问题 更多 >
编程相关推荐