擅长:python、mysql、java
<p>我认为您正在使用<a href="https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization" rel="nofollow noreferrer">tf.layers.batch_normalization</a>函数,因为您是将来自<code>tf.GraphKeys.UPDATE_OPS</code>的更新操作添加为依赖项。在</p>
<p>代码的问题在于,您使用整个<code>tf.GraphKeys.UPDATE_OPS</code>集合来定义包含批处理规范更新的依赖项。每当使用<code>tf.layers.batch_normalization</code>创建批处理规范层时,该层的更新操作将添加到<code>tf.GraphKeys.UPDATE_OPS</code>集合中。因此,在定义uNet2D的第一个代码块中,<code>optimizer</code>将只有uNet2D的批处理规范更新集合作为依赖项。但是,当您创建attentionNetwork时,<code>tf.GraphKeys.UPDATE_OPS</code>会添加更多的批量规范更新。因此,对于attentionNetwork优化器的依赖关系实际上包括所有批处理规范更新,包括uNet2D模型的更新。在</p>
<p>为了解决这个问题,您需要过滤每个模型的批量定额更新。如果使用作用域创建每个模型,例如:</p>
<pre><code>with tf.variable_scope('unet2d'):
# ... creation of the model uNet2D..
with tf.variable_scope('attention_network'):
# ... creation of the model attentionNetwork..
</code></pre>
<p>可以使用范围筛选每个模型的批量规范更新:</p>
^{pr2}$
<p>一个附加说明:确保在优化器中使用了正确的变量。由于您没有将变量传递给最小化函数中的参数<code>var_list</code>,因此模型将实际收集<code>tf.GraphKeys.TRAINABLE_VARIABLES</code>集合中的所有变量。在</p>