用python类构造Tensorflow中的GAN

2024-09-29 23:20:33 发布

您现在位置:Python中文网/ 问答频道 /正文

我习惯于用Keras来设计我的秋裤。不过,为了满足特定的需要,我想将我的代码调整为Tensorflow。大多数使用Tensorflow的GANs实现使用一个类来实现GAN,然后使用一个用于鉴别器生成器的函数

看起来像这样:

class MyGAN():
    def __init__(self):
        # various initialisation

    def generator(self, n_examples):
        ### do some business and return n_examples generated.

        return G_output

    def discrimintator(self, images):
        ### do some business with the images

        return D_Prob, D_logits

事实上,这是非常好的。但是,我更喜欢每个部分(MyGAN、Generator、Discriminator)都是一个完整独立的类的设计。只初始化主节点:MyGAN,其余的由它自己处理。它为我提供了一个更简单的代码组织和相对容易的代码阅读。在

然而,我在一些设计模式上遇到了困难,有了Keras我可以使用“输入”层,它允许我从数据集的真实数据和生成器生成的伪数据中切换到鉴别器。只需几行代码就可以用Keras伪代码来揭示这个想法:

^{pr2}$

我的问题很简单,我怎样才能用Tensorflow重现这种代码结构?我有一些想法,但我不相信这些:

我可以用tf.变量然后使用加载函数在执行过程中分配它。问题是:对于每一个训练步骤,我似乎需要执行两个sess.运行()每个网络(D和G)。这显然是低效的。。。在

  • 对于发电机:

    • 1: 使用sess.run()调用生成G数据
    • 2: 用sess.run()调用加载D中的数据
    • 3: 使用另一个sess.run()调用计算损失
    • 4: 最后用最后一个sess.run()反向传播G
  • 对于鉴别器:

    • 1: 使用sess.run()调用生成G数据
    • 2: 用sess.run()调用加载D中的数据
    • 3: 使用sess.run()调用计算假数据的损失
    • 4: 使用sess.run()调用计算实际数据的损失
    • 5: 最后用最后一个sess.run()反向传播D

在我看来,这显然是低效的,我没有更好的主意。当然,我可以使用占位符,这会“隐藏”with with feed\dict的加载操作,但不会真正影响性能(我尝试过)。在

我的目标是:

  • 直接将G连接到D,并且能够避免调用G,只需将G和D直接连接起来。

  • 能够在从G或从数据批中获取数据时“切换D”。这样可以避免从GPU/CPU传输数据=>节省时间


Tags: 数据函数run代码selfreturntensorflowdef
1条回答
网友
1楼 · 发布于 2024-09-29 23:20:33

通过使用纯功能方法和使用可变范围重新应用网络,可以实现所需的设计结构。例如,此代码片段设置网络的真实/虚假部分:

with variable_scope.variable_scope(generator_scope) as gen_scope:
  generated_data = generator_fn(generator_inputs)
with variable_scope.variable_scope(discriminator_scope) as dis_scope:
  discriminator_gen_outputs = discriminator_fn(generated_data,
                                               generator_inputs)
with variable_scope.variable_scope(dis_scope, reuse=True):
  discriminator_real_outputs = discriminator_fn(real_data, generator_inputs)

考虑使用TensorFlow's TFGAN来避免重新创建GAN基础设施。These examples演示如何使用TFGAN创建各种类型的gan(使用,以及使用内置特性)。在

相关问题 更多 >

    热门问题