使用预训练模型和配置文件时,如何停止基于丢失的训练?

2024-05-18 12:04:21 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在使用一个更快的RCNN模型来训练一个对象检测器,使用管道配置文件。我知道只需直接取消(ctrl+c)即可停止培训。 我希望根据损失值自动停止培训。如何做到这一点? 我知道keras回调可以在监视历代时使用。在使用配置文件和预先培训的模型(监控步骤)时,是否有此类选项。 谢谢


Tags: 对象模型管道配置文件选项步骤检测器keras
1条回答
网友
1楼 · 发布于 2024-05-18 12:04:21

这可能只是一个黑客,但我找到了解决我问题的办法。 对象检测器需要安装tf_slim包。在tf_slim包中,有一个名为learning.py的模块。 到这一点的完整路径可能如下所示:/usr/local/lib/python3.6/site-packages/tf_slim/learning.py 这里,在learning.py的起始行764中,代码如下所示:

try:
  while not sv.should_stop():
    total_loss, should_stop = train_step_fn(sess, train_op, global_step,
                                            train_step_kwargs)
    if should_stop:
      logging.info('Stopping Training.')
      sv.request_stop()
      break
except errors.OutOfRangeError as e:
# OutOfRangeError is thrown when epoch limit per
# tf.compat.v1.train.limit_epochs is reached.
logging.info('Caught OutOfRangeError. Stopping Training. %s', e)

我编写了一个小的if语句来检查total_loss最后五个值的最大值,如果低于某个阈值(在本例中为3),则生成should_stop{}。如下所示:

try:
  total_loss_list = []
  while not sv.should_stop():
    total_loss, should_stop = train_step_fn(sess, train_op, global_step,
                                            train_step_kwargs)
    total_loss_list.append(total_loss)
    if len(total_loss_list) > 5:
      if max(total_loss_list[-5:]) < 3:
        should_stop = True
    if should_stop:
      logging.info('Stopping Training.')
      sv.request_stop()
      break
except errors.OutOfRangeError as e:
  # OutOfRangeError is thrown when epoch limit per
  # tf.compat.v1.train.limit_epochs is reached.
  logging.info('Caught OutOfRangeError. Stopping Training. %s', e)

如果损失值连续五步低于3,则训练停止。这样做的缺点是tf_slim的包分布必须改变。每次处理新的目标检测问题时,阈值损失值都会改变。更好的方法是使用配置文件,在其中提供阈值损失值。但我现在就到此为止。 如果其他人有更好的解决方案,请分享。 我希望这对某人有帮助。谢谢大家!

相关问题 更多 >