如何确保tf.control_依赖项()当我训练具有多个网络的GANlike图时?

2024-10-08 18:25:41 发布

您现在位置:Python中文网/ 问答频道 /正文

我想我可以把这个问题概括为,“当我有两个唯一的网络时,如何使用批处理规范化?”在

我训练的基本上是一个GAN,鉴别器和生成器都有批量范数层。这有点不同,因为这两个网络都有各自的损耗函数,完全独立于另一个,这与普通的GAN框架不同。第二个网络基本上只是用来测量生成器在任务中的“错误程度”,但它们都应该完全独立地更新。在

我的网络都是在单独的gpu上定义的,因为它们相当大。在

我将网络放置在每个GPU上,并在以下代码中分配依赖项:

with tf.device("/gpu:0"):
    pred = uNet2D(X, BETA[j], KERNEL_SIZE, is_training)
    cost = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.reshape(Y,[-1]),logits=tf.reshape(pred,[-1])))

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
            optimizer = tf.train.AdamOptimizer(learning_rate=LR[i]).minimize(W*cost)    



with tf.device("/gpu:1"):
    attention = attentionNetwork(X_ATTN, BETA[j], KERNEL_SIZE, is_training)
    cost_d = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(labels=Y_ATTN,logits=attention))

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optimizer_d = tf.train.AdamOptimizer(learning_rate=0.2*LR[i]).minimize(cost_d)

不过,我有点担心,因为我的张量板图图像表明,uNet(我的生成器)的输出是一个输入,梯度用于更新attentionenetwork(我的鉴别器)。在

有人能帮我决定怎样构造这些积木吗?我还担心的是,优化attentionNetwork需要包含uNet2D()和cost on中定义的占位符gpu:0。在

谢谢!我的张量板图表附在下面。在

enter image description here

编辑:当我在没有Batch Norm的情况下运行这个程序,因此没有control\u dependencies()时,我得到了一个看起来像这样的Tensorboard,我很确定这是我想要的。在

enter image description here


Tags: 网络gpu定义devicetfwithdependenciesupdate
2条回答

像所有人一样切换到Pythorch,https://discuss.pytorch.org/他们甚至有开发者回答问题的论坛。在

我认为您正在使用tf.layers.batch_normalization函数,因为您是将来自tf.GraphKeys.UPDATE_OPS的更新操作添加为依赖项。在

代码的问题在于,您使用整个tf.GraphKeys.UPDATE_OPS集合来定义包含批处理规范更新的依赖项。每当使用tf.layers.batch_normalization创建批处理规范层时,该层的更新操作将添加到tf.GraphKeys.UPDATE_OPS集合中。因此,在定义uNet2D的第一个代码块中,optimizer将只有uNet2D的批处理规范更新集合作为依赖项。但是,当您创建attentionNetwork时,tf.GraphKeys.UPDATE_OPS会添加更多的批量规范更新。因此,对于attentionNetwork优化器的依赖关系实际上包括所有批处理规范更新,包括uNet2D模型的更新。在

为了解决这个问题,您需要过滤每个模型的批量定额更新。如果使用作用域创建每个模型,例如:

with tf.variable_scope('unet2d'):
    # ... creation of the model uNet2D..

with tf.variable_scope('attention_network'):
    # ... creation of the model attentionNetwork..

可以使用范围筛选每个模型的批量规范更新:

^{pr2}$

一个附加说明:确保在优化器中使用了正确的变量。由于您没有将变量传递给最小化函数中的参数var_list,因此模型将实际收集tf.GraphKeys.TRAINABLE_VARIABLES集合中的所有变量。在

相关问题 更多 >

    热门问题