使用td.cond在训练期间会导致减少吞吐量。

2024-10-01 13:35:24 发布

您现在位置:Python中文网/ 问答频道 /正文

在使用resnet50进行imagenet训练的过程中,我们使用LAR更新学习率并计算训练每个步骤的LR。培训的吞吐量约为5500。为此,我们打算每隔几步优化和计算LR操作,以提高吞吐量。在原始代码中,我们每一步都执行compute_lr计算

我修改了代码,如下所示:

  • Global_step是用来观察训练哪一步的张量
  • 2是一个常数,表示每两步计算一次lr

守则:

def compute_lr()
    coumpte_lr 
       ...
    stored_lr
       ...
    return lr
def get_larsvalue()
    get_stored_lr
       ...
    return lr

tf.cond(tf.cast(tf.equal(tf.mod(gg,2),0),tf.bool),lambda:self.compute_lr(),lambda: self.get_larsvalue())

但是在我修改代码之后,吞吐量下降了。经过分析,我认为这是因为tf.cond不是一个懒惰的操作,它将执行两个分支,这显然不是我想要的。我现在不知道如何编写代码来完成我的想法,请大家帮忙


Tags: lambda代码selfgetreturntfdef吞吐量