Tensorflow:contrib-GANEstim的身份函数

2024-06-18 11:11:56 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试使用tf.contrib.gan.estimator.GANEstimator(https://www.tensorflow.org/api_docs/python/tf/contrib/gan/estimator/GANEstimator)训练一个Wasserstein生成性对抗网络。但我没有使用生成器网络将噪声映射到图像,而是使用细化器将“坏图像”映射到“好图像”。因此,输入和输出维度是相同的

我只想对批评家网络进行交叉检查,看看梯度惩罚损失和批评家损失是否收敛。生成器网络应该只传递图像,但我不知道如何为GANEstimator提供一个生成器函数来实现这一点

当返回输入时,我得到一个AssertionError: assert variables_to_train

将输入与1相乘时 一些非常奇怪的事情正在发生,比如梯度惩罚损失几乎为零,除了一些尖峰,损失是发散的:

def generator_network(self, images, mode):
    noop = tf.get_variable("noop", shape=[], initializer=tf.initializers.ones())
    return images * noop

Tags: https图像网络tftensorflowwwwcontrib梯度