Keras在Callb中获得了批处理中使用的精确训练示例

2024-05-06 13:08:36 发布

您现在位置:Python中文网/ 问答频道 /正文

我在Keras训练神经网络有问题。每一个时期,损失都会逐渐减少,达到约1e-9,然后在纪元中期的某个地方(可能是任何地方),损失会上升到5e-5,最终在每个历元的最终损失相同的情况下稳定下来。我相信这是由于数据集中的一些脏数据导致模型无法训练到某个特定点,尽管我真的不确定。在

为了测试我的假设,我想创建一个自定义的Keras回调对象,它将确定批处理后的损失是否有足够大的跳跃,并指出是哪个批导致了跳跃。问题是提供给keras.callbacks.Callback.on_batch_endbatch参数只是批处理编号,而不是该批处理中使用的训练示例。此外,传入的logsdict也只包含loss和{}。在

这意味着我无法确定是哪些数据导致了损失的增加。有没有一种方法可以确定每一个时代引起跳跃的确切训练例子?有什么方法可以让我在回叫中访问它吗?在


Tags: 数据对象方法模型地方batch情况神经网络