try:
while not sv.should_stop():
total_loss, should_stop = train_step_fn(sess, train_op, global_step,
train_step_kwargs)
if should_stop:
logging.info('Stopping Training.')
sv.request_stop()
break
except errors.OutOfRangeError as e:
# OutOfRangeError is thrown when epoch limit per
# tf.compat.v1.train.limit_epochs is reached.
logging.info('Caught OutOfRangeError. Stopping Training. %s', e)
这可能只是一个黑客,但我找到了解决我问题的办法。 对象检测器需要安装
tf_slim
包。在tf_slim
包中,有一个名为learning.py
的模块。 到这一点的完整路径可能如下所示:/usr/local/lib/python3.6/site-packages/tf_slim/learning.py
这里,在learning.py
的起始行764中,代码如下所示:我编写了一个小的}。如下所示:
if
语句来检查total_loss
最后五个值的最大值,如果低于某个阈值(在本例中为3),则生成should_stop
{如果损失值连续五步低于3,则训练停止。这样做的缺点是
tf_slim
的包分布必须改变。每次处理新的目标检测问题时,阈值损失值都会改变。更好的方法是使用配置文件,在其中提供阈值损失值。但我现在就到此为止。 如果其他人有更好的解决方案,请分享。 我希望这对某人有帮助。谢谢大家!相关问题 更多 >
编程相关推荐