在增强训练期间,不能在Keras iterator.py中的断点处停止

2024-10-02 22:26:41 发布

您现在位置:Python中文网/ 问答频道 /正文

我创建了数据生成器类的两个实例,从keras Sequence类扩展而来,一个用于培训,一个用于验证数据。然而,在我的源代码级别,我只能看到验证生成器在每个历元之间重新迭代。我看不到训练生成器。因此,我无法验证培训数据的增加是否符合我的意图。在这些代码片段中,aug是一组参数,这些参数在序列的myDataGen扩展中传递给keras ImageDataGenerator实例。我通常不会增加验证数据,但我就是这样偶然发现这个难题的:

    aug = dict(fill_mode='nearest',
                        rotation_range=10,
                        zoom_range=0.3,
                        width_shift_range=0.1,
                        height_shift_range=0.1
                        )
    training_datagen = myDataGen(Xdata_train,ydata_train,**aug)
    validation_datagen = myDataGen(Xdata_test,ydata_test,**aug)

    history = model.fit(training_datagen,
                                validation_data=validation_datagen,
                                validation_batch_size=16,
                                epochs=50,
                                shuffle=False,
                                )

一切都正常,我得到了很好的结果,但我只是想确定一下增强的效果。因此,通过浏览keras中的各种函数,我可以收集到我编写的数据生成器填充了一个较低级别的tensorflow数据集,然后每个历元进行迭代。我只是看不出tensorflow数据集是如何在每个历元中增加的

现在,我还意外地发现,虽然fit方法不支持验证数据的生成器,但它确实可以工作,并且具有我希望用于训练生成器的有趣功能,即从磁盘重新读取数据,以便在我自己的源代码级别重新扩充

总之,我可以看到tensorflow Dataset.cache()方法可能在第一个历元之后将我的训练数据集存储在内存中。我是否可以以某种方式取消缓存()以强制重新读取和重新扩充,或者有人可以告诉我tensorflow数据集在迭代时如何调用扩充方法

嗯。这个线程TF Dataset API for Image augmentation清楚地表明,直接在tensorflow数据集API中编写增强方法很容易,但是参与者在注释中写道,您不能在tf.data.Dataset上使用keras.ImageDataGenerator。但我可以在keras模块中清楚地看到,我的keras数据集正在被“改编”为底层tf.data.dataset。如果这句话是真的,它将解释为什么我似乎无法突破ImageDataGenerator对训练数据的扩充。但这怎么可能是真的呢


Tags: 数据实例方法data源代码tensorflowrangedataset
1条回答
网友
1楼 · 发布于 2024-10-02 22:26:41

我的困惑来自一个初学者的错误,他们忽略了一个事实,即在keras源代码被编译到gpu上之后,当然不能在其级别上突破。但有趣的是,从这种混乱中产生的是,你可以为验证数据编写一个keras生成器,并在每个历元中中断它,因为它显然没有编译到gpu上。。。因为keras不支持验证数据的生成器!只是生成器处理得很好,没有运行时错误。一个不太清楚的发现,但希望它能帮助别人

相关问题 更多 >