我试图以混合精度的方式在TensorFlow中实现一个WGAN-GP模型。正如 TensorFlow 混合精度指南中所解释的,我们需要缩放损失以避免反向传播期间的下溢/溢出。但是对于梯度惩罚,我们必须在用于计算损失的梯度带(GradientTape)内部再计算一次梯度(即梯度惩罚本身),所以我们也必须管理这些内部梯度的缩放,不是吗?
我已经按照类似的流程尝试为梯度惩罚实现我自己的梯度缩放,并将 TensorFlow 的 LossScaleOptimizer 保留用于损失缩放(如官方文档所述),但它会导致数值不稳定,结果很糟糕。
对于在TensorFlow中以混合精度计算梯度惩罚,您有什么建议吗?
以下是批评家(鉴别器)架构:
def Discriminator(lrelu_slope=0.2, kernel_initializer=random_normal_initializer(0,0.02), name=None):
    """Build the WGAN-GP critic network.

    Architecture: three strided 3D-conv downsampling stages, a 3x3x3 conv
    block, a 1-channel conv head, and a spatial mean producing one scalar
    score per sample. The final Lambda is forced to float32 so the critic
    output stays numerically safe under mixed precision.

    Parameters
    ----------
    lrelu_slope : negative slope of every LeakyReLU activation.
    kernel_initializer : initializer shared by all conv kernels.
    name : prefix used to name the conv block's instance-norm layer.
    """
    def downsample_block(n_filters, inst_norm=True, instNorm_mean_f32=False, instNorm_mean_abs_dev_f32=False):
        # 4x4x4 stride-2 conv (with symmetric zero padding), optional
        # instance normalization, then LeakyReLU.
        layers = [
            ZeroPadding3D(1),
            Conv3D(filters=n_filters, kernel_size=4, strides=2, padding="valid",
                   use_bias=False, kernel_initializer=kernel_initializer),
        ]
        if inst_norm:
            layers.append(InstanceNormalization(mean_f32=instNorm_mean_f32,
                                                mean_abs_dev_f32=instNorm_mean_abs_dev_f32))
        layers.append(LeakyReLU(lrelu_slope))
        return Sequential(layers)

    conv_block = Sequential([
        Conv3D(filters=512, kernel_size=3, strides=1, padding="same",
               use_bias=False, kernel_initializer=kernel_initializer),
        InstanceNormalization(name=f"{name}__conv_block__instNorm", mean_f32=True, mean_abs_dev_f32=True),
        LeakyReLU(lrelu_slope),
    ])

    return Sequential([
        Input(shape=(182,218,182,1)),
        downsample_block(64, inst_norm=False),
        downsample_block(128, instNorm_mean_f32=True, instNorm_mean_abs_dev_f32=True),
        downsample_block(256, instNorm_mean_f32=True, instNorm_mean_abs_dev_f32=True),
        conv_block,
        Conv3D(filters=1, kernel_size=3, strides=1, padding="same",
               use_bias=True, kernel_initializer=kernel_initializer),
        # Per-sample scalar score: spatial mean computed in float32.
        Lambda(lambda x: k_mean(x, axis=(1,2,3)), dtype="float32"),
    ])
下面是批评家的训练类(Critic_trainer):
class Critic_trainer:
    """One WGAN-GP critic training step under mixed precision.

    Two independent loss scales are in play: the optimizer's
    LossScaleOptimizer scale for the overall critic loss, and a
    hand-rolled dynamic scale ``self.s`` for the inner (gradient-penalty)
    gradient computation, updated with the same grow/shrink scheme as
    TF's dynamic loss scaling.
    """
    def __init__(self, critic, generator, lbda_gp, optimizer, init_scale_factor=pow(2,15), moving_factor=2.0, N=200):
        # critic / generator: Keras models. optimizer is expected to be a
        # mixed-precision LossScaleOptimizer (get_scaled_loss /
        # get_unscaled_gradients are called in train()).
        self.critic = critic
        self.generator = generator
        # Weight of the gradient-penalty term in the critic loss.
        self.lbda_gp = lbda_gp
        self.optimizer = optimizer
        # Dynamic scale applied to the critic output inside
        # gradient_penalty() to keep the fp16 backward pass from
        # underflowing/overflowing.
        self.s = Variable(init_scale_factor, dtype="float32", trainable=False)
        # Multiplicative factor used to grow/shrink the scale.
        self.m = moving_factor
        # Grow the scale after N consecutive steps with finite gradients.
        self.N = N
        # Counter of consecutive finite-gradient steps.
        self.n = Variable(0, trainable=False)
    @tf_function(input_signature=(TensorSpec(shape=[BATCH_SIZE]+IMAGE_SHAPE, dtype="float32"),
                                  TensorSpec(shape=[BATCH_SIZE]+IMAGE_SHAPE, dtype="float32")))
    def train(self, images1, images2):
        """
        images1 : batch of images from critic's domain
        images2 : batch of images from the other domain which will be translated to the critic's domain

        Returns (mean critic score on real, mean critic score on fakes,
        gradient penalty or the -1.0 skip sentinel).
        """
        # Generator runs in inference mode; its output is treated as data.
        fake_images = self.generator(images2, training=False)
        with GradientTape() as tape:
            # gradient_penalty() returns -1.0 as a sentinel when its inner
            # gradients were non-finite (its scale just overflowed).
            gp = self.gradient_penalty(images1, fake_images)
            # NOTE(review): `gp` is a tensor, so this `if` relies on
            # autograph converting it to tf.cond inside tf.function —
            # confirm both branches trace correctly.
            if gp != -1.0:
                mean_critic_real = k_mean(self.critic(images1, training=True))
                mean_critic_fakes = k_mean(self.critic(fake_images, training=True))
                # Wasserstein critic loss plus weighted gradient penalty.
                critic_loss = mean_critic_fakes - mean_critic_real + self.lbda_gp*gp
                # Scale the loss for the LossScaleOptimizer backward pass.
                critic_loss = self.optimizer.get_scaled_loss(critic_loss)
            else:
                # Skip this step's learning signal.
                # NOTE(review): differentiating the constant 0.0 w.r.t. the
                # critic variables yields None/zero gradients; verify that
                # get_unscaled_gradients and apply_gradients below accept
                # them rather than erroring on this branch.
                mean_critic_real = 0.0
                mean_critic_fakes = 0.0
                critic_loss = 0.0
        grads = tape.gradient(critic_loss, self.critic.trainable_variables)
        # Undo the optimizer's loss scaling before applying the update.
        grads = self.optimizer.get_unscaled_gradients(grads)
        self.optimizer.apply_gradients(zip(grads, self.critic.trainable_variables))
        return mean_critic_real, mean_critic_fakes, gp
    @tf_function(input_signature=(TensorSpec(shape=[BATCH_SIZE]+IMAGE_SHAPE, dtype="float32"),
                                  TensorSpec(shape=[BATCH_SIZE]+IMAGE_SHAPE, dtype="float32")))
    def gradient_penalty(self, real_images, fake_images):
        """Two-sided WGAN-GP gradient penalty with manual loss scaling.

        Returns the mean squared deviation of the interpolated-sample
        gradient norm from 1, or the sentinel -1.0 when the scaled inner
        gradients were non-finite (caller should skip the step).
        """
        # Computing interpolated images: a random point on the segment
        # between each fake/real pair.
        alpha = uniform((BATCH_SIZE, 1, 1, 1, 1),0,1)
        inter_images = fake_images + alpha*(real_images-fake_images)
        # Critics of interpolated images, scaled by self.s so the fp16
        # backward pass through the critic does not underflow.
        with GradientTape() as gp_tape:
            gp_tape.watch(inter_images)
            critics = self.critic(inter_images, training=True) * self.s
        # Unscale: d(s*f)/dx = s * df/dx, so divide the gradients by s.
        grads = gp_tape.gradient(critics, inter_images) / self.s
        if reduce_all(is_finite(grads)):
            # Finite gradients: after N consecutive good steps, grow the
            # scale by the moving factor and restart the count.
            self.n.assign_add(1)
            if self.n==self.N:
                self.n.assign(0)
                self.s.assign(self.s * self.m)
        else:
            # Overflow: shrink the scale, reset the counter, and tell the
            # caller to skip this training step.
            self.n.assign(0)
            self.s.assign(self.s / self.m)
            return -1.0
        # Per-sample L2 norm of the input gradients over all spatial/channel
        # axes, penalizing its deviation from 1.
        norms = sqrt(k_sum(square(grads), axis=(1,2,3,4)))
        gps = square(norms-1)
        return k_mean(gps)
目前没有回答
相关问题 更多 >
编程相关推荐