多gpu tensorflow代码中批量规范化参数的更新?

2024-09-09 13:06:23 发布

您现在位置:Python中文网/ 问答频道 /正文

我写了一个多gpu的Cnn代码

this link 他们在第249行评论道

# Retain the Batch Normalization updates operations only from the # final tower. Ideally, we should grab the updates from all towers # but these stats accumulate extremely fast so we can ignore the # other stats from the other towers without significant detriment.

但更新批次定额的代码(第253行):

with tf.device('/gpu:%d' % i):
.
.
.
  batchnorm_updates = 
  tf.get_collection(slim.ops.UPDATE_OPS_COLLECTION,scope)

对所有塔(GPU)完成 那么这个更新的正确位置是什么呢?在


Tags: the代码fromgputfstats评论link
1条回答
网友
1楼 · 发布于 2024-09-09 13:06:23

我想你对代码的理解不正确。在

作为code中的for loop

for i in range(FLAGS.num_gpus):
  with tf.device('/gpu:%d' % i):
    with tf.name_scope('%s_%d' % (inception.TOWER_NAME, i)) as scope:

      with slim.arg_scope([slim.variables.variable], device='/cpu:0'):

      ......

      batchnorm_updates = tf.get_collection(slim.ops.UPDATE_OPS_COLLECTION,
                                            scope)

在每个for loop之后,batchnorm_updates将被替换,因此当for loop完成时,它只包含最后一个塔的批处理规范化更新操作。在

相关问题 更多 >