我正在使用Tensorflow DCGAN实现指南中提供的代码编写一个自定义训练循环。我想在训练循环中添加回调。在Keras中,我知道我们将它们作为一个参数传递给'fit'方法,但是找不到关于如何在定制训练循环中使用这些回调的资源。我添加了来自Tensorflow文档的定制训练循环的代码:
# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def train_step(images):
noise = tf.random.normal([BATCH_SIZE, noise_dim])
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
generated_images = generator(noise, training=True)
real_output = discriminator(images, training=True)
fake_output = discriminator(generated_images, training=True)
gen_loss = generator_loss(fake_output)
disc_loss = discriminator_loss(real_output, fake_output)
gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
def train(dataset, epochs):
for epoch in range(epochs):
start = time.time()
for image_batch in dataset:
train_step(image_batch)
# Produce images for the GIF as we go
display.clear_output(wait=True)
generate_and_save_images(generator,
epoch + 1,
seed)
# Save the model every 15 epochs
if (epoch + 1) % 15 == 0:
checkpoint.save(file_prefix = checkpoint_prefix)
print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
# Generate after the final epoch
display.clear_output(wait=True)
generate_and_save_images(generator,
epochs,
seed)
这没道理。在
没有什么有意义的方法可以像解释}那样解释鉴别器/生成器丢失。两个磁盘/发电机损耗都与另一个相关,因此没有明确的停止标准。理想情况下,损失将达到Nash Equilibrium,但这在现实中不太可能。这不是对堆栈溢出的讨论,而是对https://stats.stackexchange.com的讨论。在
log loss
或{相关问题 更多 >
编程相关推荐