Applying callbacks in a TensorFlow 2.0 custom training loop

Posted 2024-06-15 07:41:37


I am writing a custom training loop using the code provided in the TensorFlow DCGAN tutorial, and I would like to add callbacks to it. In Keras I know we pass them as an argument to the `fit` method, but I cannot find any resources on how to use these callbacks in a custom training loop. Below is the custom training loop code from the TensorFlow documentation, followed by a sketch of the kind of wiring I have in mind:

# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      generated_images = generator(noise, training=True)

      real_output = discriminator(images, training=True)
      fake_output = discriminator(generated_images, training=True)

      gen_loss = generator_loss(fake_output)
      disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

def train(dataset, epochs):
  for epoch in range(epochs):
    start = time.time()

    for image_batch in dataset:
      train_step(image_batch)

    # Produce images for the GIF as we go
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epoch + 1,
                             seed)

    # Save the model every 15 epochs
    if (epoch + 1) % 15 == 0:
      checkpoint.save(file_prefix=checkpoint_prefix)

    print('Time for epoch {} is {} sec'.format(epoch + 1, time.time() - start))

  # Generate after the final epoch
  display.clear_output(wait=True)
  generate_and_save_images(generator,
                           epochs,
                           seed)
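
The pattern I am imagining is roughly the sketch below: keep a plain list of `tf.keras.callbacks.Callback` objects and call their hooks by hand at the matching points in the loop. This is just my guess at a workable wiring, not something from the docs, and it assumes `train_step` is modified to return `(gen_loss, disc_loss)` so there is something to put into `logs`:

callbacks = [
    # Any standard Keras callbacks; TensorBoard is just an example.
    tf.keras.callbacks.TensorBoard(log_dir='./logs')
]

def train(dataset, epochs):
  # There is no `fit` call to attach a model, so do it explicitly for
  # callbacks that need one (TensorBoard, ModelCheckpoint, ...).
  for cb in callbacks:
    cb.set_model(generator)
    cb.on_train_begin()

  for epoch in range(epochs):
    for cb in callbacks:
      cb.on_epoch_begin(epoch)

    for image_batch in dataset:
      # Assumes train_step is changed to return both losses.
      gen_loss, disc_loss = train_step(image_batch)

    # Keras normally assembles `logs` itself; here it just carries
    # the losses from the last batch of the epoch.
    logs = {'gen_loss': float(gen_loss), 'disc_loss': float(disc_loss)}
    for cb in callbacks:
      cb.on_epoch_end(epoch, logs=logs)

  for cb in callbacks:
    cb.on_train_end()

Callbacks such as `EarlyStopping` only set `model.stop_training` on the model they were given, so a loop like this would also have to check that flag after each epoch and break out of training itself.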

1 Answer

This doesn't really make sense.

There is no meaningful way to interpret the discriminator/generator losses the way you would interpret log loss or a similar metric. The two losses are each defined relative to the other, so there is no clear stopping criterion. Ideally the losses would reach a Nash equilibrium, but in practice that is unlikely. This is less a question for Stack Overflow than for https://stats.stackexchange.com.
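
If the goal is bookkeeping rather than a stopping rule, one practical compromise is to skip Keras callbacks entirely and write both curves with `tf.summary`, then judge progress by eye in TensorBoard alongside the generated samples. A minimal sketch, with `./logs/gan` as a made-up log directory:

writer = tf.summary.create_file_writer('./logs/gan')

def log_losses(step, gen_loss, disc_loss):
  # Record both curves: with a GAN you watch their interplay over time,
  # since neither value on its own signals convergence.
  with writer.as_default():
    tf.summary.scalar('gen_loss', gen_loss, step=step)
    tf.summary.scalar('disc_loss', disc_loss, step=step)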
