TensorFlow mixed-precision second-order gradient scaling


I am trying to implement a WGAN-GP model in TensorFlow using mixed precision. As explained in the TensorFlow mixed-precision guide, the loss must be scaled to avoid gradient underflow/overflow during backpropagation. But for the gradient penalty, we have to compute gradients (of the critic's output with respect to the interpolated images) inside the gradient tape that is used to compute the loss, so don't we also have to manage the scaling of these inner gradients?
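
For context, the loss-scaling pattern from the guide looks like this (a minimal toy sketch, not my actual model; the tiny Dense network and the random data are just placeholders):

import tensorflow as tf

# Run layer computations in float16 while keeping variables in float32.
tf.keras.mixed_precision.set_global_policy("mixed_float16")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(1, dtype="float32"),  # keep the output in float32
])

# Wrapping the optimizer enables dynamic loss scaling.
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.Adam(1e-4))

x = tf.random.normal((4, 8))
y = tf.random.normal((4, 1))

with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(model(x, training=True) - y))
    scaled_loss = optimizer.get_scaled_loss(loss)        # multiply by the loss scale
scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
grads = optimizer.get_unscaled_gradients(scaled_grads)   # divide the scale back out
optimizer.apply_gradients(zip(grads, model.trainable_variables))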

I have tried implementing my own gradient scaling for the gradient penalty, following this process, while keeping TensorFlow's LossScaleOptimizer for the loss scaling (as described here), but it leads to numerical instability and poor results.
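
The inner scaling I tried boils down to this pattern (a minimal sketch with a fixed scale factor S; my actual code below adjusts the factor dynamically, and penalty_grads is just an illustrative name):

import tensorflow as tf

S = tf.constant(2.0 ** 15)  # hand-picked scale factor for the inner gradients

def penalty_grads(critic, inter_images):
    """Gradients of the critic w.r.t. its inputs, scaled to survive float16."""
    with tf.GradientTape() as tape:
        tape.watch(inter_images)
        # Scale the tape target up so the float16 backward pass does not underflow.
        scaled_critics = critic(inter_images, training=True) * S
    # Divide the scale back out to recover the true gradients.
    return tape.gradient(scaled_critics, inter_images) / S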

Do you have any suggestions for computing the gradient penalty with mixed precision in TensorFlow?

Here is the critic (discriminator) architecture:

from tensorflow import random_normal_initializer
from tensorflow.keras import Sequential, Input
from tensorflow.keras.layers import ZeroPadding3D, Conv3D, LeakyReLU, Lambda
from tensorflow.keras.backend import mean as k_mean
# InstanceNormalization is my own custom layer (not shown); its mean and
# mean-absolute-deviation statistics can optionally be computed in float32.

def Discriminator(lrelu_slope=0.2, kernel_initializer=random_normal_initializer(0, 0.02), name=None):

    def Downsample(n_filters, inst_norm=True, instNorm_mean_f32=False, instNorm_mean_abs_dev_f32=False):
        # 4x4x4 strided convolution block that halves each spatial dimension
        model = Sequential([
            ZeroPadding3D(1),
            Conv3D(filters=n_filters, kernel_size=4, strides=2, padding="valid", use_bias=False, kernel_initializer=kernel_initializer)
        ])
        if inst_norm:
            model.add(InstanceNormalization(mean_f32=instNorm_mean_f32, mean_abs_dev_f32=instNorm_mean_abs_dev_f32))
        model.add(LeakyReLU(lrelu_slope))
        return model

    return Sequential([
        Input(shape=(182, 218, 182, 1)),
        Downsample(64, inst_norm=False),
        Downsample(128, instNorm_mean_f32=True, instNorm_mean_abs_dev_f32=True),
        Downsample(256, instNorm_mean_f32=True, instNorm_mean_abs_dev_f32=True),
        Sequential([
            Conv3D(filters=512, kernel_size=3, strides=1, padding="same", use_bias=False, kernel_initializer=kernel_initializer),
            InstanceNormalization(name=f"{name}__conv_block__instNorm", mean_f32=True, mean_abs_dev_f32=True),
            LeakyReLU(lrelu_slope),
        ]),
        Conv3D(filters=1, kernel_size=3, strides=1, padding="same", use_bias=True, kernel_initializer=kernel_initializer),
        # Average over the spatial dimensions to get one score per sample
        Lambda(lambda x: k_mean(x, axis=(1, 2, 3)), dtype="float32")
    ])
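
Note that the final Lambda layer is forced to float32, following the mixed-precision guide's recommendation to keep model outputs in float32; under the mixed_float16 policy, everything before it computes in float16.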

And here is the critic training class:

from tensorflow import Variable, TensorSpec, GradientTape, reduce_all, sqrt, square
from tensorflow import function as tf_function
from tensorflow.random import uniform
from tensorflow.math import is_finite
from tensorflow.keras.backend import mean as k_mean, sum as k_sum
# BATCH_SIZE and IMAGE_SHAPE are defined elsewhere in my code.

class Critic_trainer:

    def __init__(self, critic, generator, lbda_gp, optimizer, init_scale_factor=pow(2, 15), moving_factor=2.0, N=200):
        self.critic = critic
        self.generator = generator
        self.lbda_gp = lbda_gp      # gradient-penalty coefficient
        self.optimizer = optimizer  # a LossScaleOptimizer wrapping the base optimizer
        # Dynamic scale factor for the inner (gradient-penalty) gradients,
        # managed like TensorFlow's dynamic loss scaling:
        self.s = Variable(init_scale_factor, dtype="float32", trainable=False)
        self.m = moving_factor      # factor used to grow/shrink the scale
        self.N = N                  # consecutive finite steps before growing the scale
        self.n = Variable(0, trainable=False)  # counter of consecutive finite steps

    @tf_function(input_signature=(TensorSpec(shape=[BATCH_SIZE]+IMAGE_SHAPE, dtype="float32"),
                                  TensorSpec(shape=[BATCH_SIZE]+IMAGE_SHAPE, dtype="float32")))
    def train(self, images1, images2):
        """
        images1 : batch of images from the critic's domain
        images2 : batch of images from the other domain, which will be translated to the critic's domain
        """
        fake_images = self.generator(images2, training=False)
        with GradientTape() as tape:
            gp = self.gradient_penalty(images1, fake_images)
            if gp != -1.0:  # -1.0 is the sentinel for non-finite inner gradients
                mean_critic_real = k_mean(self.critic(images1, training=True))
                mean_critic_fakes = k_mean(self.critic(fake_images, training=True))
                critic_loss = mean_critic_fakes - mean_critic_real + self.lbda_gp*gp
                # Outer loss scaling, handled by the LossScaleOptimizer
                critic_loss = self.optimizer.get_scaled_loss(critic_loss)
            else:
                # Skip the step: a zero loss produces no meaningful update
                mean_critic_real = 0.0
                mean_critic_fakes = 0.0
                critic_loss = 0.0
        grads = tape.gradient(critic_loss, self.critic.trainable_variables)
        grads = self.optimizer.get_unscaled_gradients(grads)
        self.optimizer.apply_gradients(zip(grads, self.critic.trainable_variables))
        return mean_critic_real, mean_critic_fakes, gp

    @tf_function(input_signature=(TensorSpec(shape=[BATCH_SIZE]+IMAGE_SHAPE, dtype="float32"),
                                  TensorSpec(shape=[BATCH_SIZE]+IMAGE_SHAPE, dtype="float32")))
    def gradient_penalty(self, real_images, fake_images):
        # Random interpolation between real and fake images
        alpha = uniform((BATCH_SIZE, 1, 1, 1, 1), 0, 1)
        inter_images = fake_images + alpha*(real_images-fake_images)

        # Critic scores of the interpolated images, scaled up by self.s so the
        # float16 backward pass of this inner tape does not underflow
        with GradientTape() as gp_tape:
            gp_tape.watch(inter_images)
            critics = self.critic(inter_images, training=True) * self.s
        grads = gp_tape.gradient(critics, inter_images) / self.s  # unscale
        # Adjust the scale dynamically, mimicking dynamic loss scaling: grow it
        # after N consecutive finite steps, shrink it whenever gradients overflow
        if reduce_all(is_finite(grads)):
            self.n.assign_add(1)
            if self.n == self.N:
                self.n.assign(0)
                self.s.assign(self.s * self.m)
        else:
            self.n.assign(0)
            self.s.assign(self.s / self.m)
            return -1.0  # sentinel: this step's gradients were not finite

        norms = sqrt(k_sum(square(grads), axis=(1, 2, 3, 4)))
        gps = square(norms-1)
        return k_mean(gps)
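
For completeness, this is roughly how the trainer is set up (a sketch; make_generator, real_batch and translated_batch are placeholders, and the Adam hyperparameters are just illustrative):

from tensorflow.keras.mixed_precision import set_global_policy, LossScaleOptimizer
from tensorflow.keras.optimizers import Adam

set_global_policy("mixed_float16")

critic = Discriminator(name="critic_A")
generator = make_generator()  # placeholder for the actual generator model

# The outer loss scaling is still handled by the LossScaleOptimizer; the
# trainer's self.s only scales the inner gradient-penalty gradients.
optimizer = LossScaleOptimizer(Adam(learning_rate=1e-4, beta_1=0.5, beta_2=0.9))
trainer = Critic_trainer(critic, generator, lbda_gp=10.0, optimizer=optimizer)

mean_real, mean_fake, gp = trainer.train(real_batch, translated_batch)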
