如何快速检查哪些tensorflow变量在培训期间更新,哪些被冻结?

2024-09-30 01:33:42 发布

您现在位置:Python中文网/ 问答频道 /正文

我相信在很多情况下,我们需要冻结张量流图中的一些层,并保持其他层可训练。在

有没有一种方法可以快速检查网络是否按我们预期的那样进行了培训?例如,在训练期间,冻结层中的变量实际上不会更新。在

我使用以下方法冻结“ABC”范围内的所有变量:

    with slim.arg_scope(inception.inceptionb_v2_arg_scope()):
        with tf.variable_scope('ABC'):
          _, end_points = getattr(inception, 'inception_v2'(..., is_training = False))
                         ......
    trainables = [v for v in tf.trainable_variables() if 'ABC/' not in v.name]
    optimizer = tf.train.AdamOptimizer().minimize(loss, var_list= trainables)

有什么建议的方法可以快速确认这些变量在培训期间确实没有改变?在


Tags: 方法in网络tfwitharg情况v2
1条回答
网友
1楼 · 发布于 2024-09-30 01:33:42

您可以在几次迭代后检查它们:

frozen_variables = [v for v in tf.trainable_variables() if 'ABC/' in v.name]
tmp_frozen_variables_np = sess.run(frozen_variables)
# Training Code
assert np.allclose(tmp_frozen_variables_np, sess.run(frozen_variables))

但是,只要它们不在优化器的var列表中,就可以了。在

相关问题 更多 >

    热门问题