更新模型一部分的权重(nn.模块)

2024-10-01 15:47:23 发布

您现在位置:Python中文网/ 问答频道 /正文

我在构建一个松散地基于CycleGAN体系结构的网络时遇到了一个问题

我将它的所有组件都放在一个nn.Module

from torch import nn

from classes.EncoderDecoder import EncoderDecoder
from classes.Discriminator import Discriminator

class CycleGAN(nn.Module):
    def __init__(self):
        super(CycleGAN, self).__init__()
        self.encdec1 = EncoderDecoder(encoder_in_channels=3)
        self.encdec2 = EncoderDecoder(encoder_in_channels=3)
        self.disc = Discriminator()
        

    def forward(self, images, images_bw):

        disc_color = self.disc(images) # I want the Discriminator to be trained here
        disc_bw = self.disc(images_bw) # I want the Discriminator to be trained here

        decoded1 = self.encdec1(images_bw) # EncoderDecoder forward pass
        decoded2 = self.encdec2(decoded1)

        decoded_disc = self.disc(decoded1)  # I don't want to train the Discriminator here, 
                                            # only the EncoderDecoder should be trained based
                                            # on this Discriminator's result

        return [disc_color, disc_bw, decoded1, decoded2, decoded_disc]

这就是我初始化这个模块、丢失函数和优化器的方式

c_gan = CycleGAN().to('cuda', dtype=float32, non_blocking=True)

l2_loss = MSELoss().to('cuda', dtype=float32).train()
bce_loss = BCELoss().to('cuda', dtype=float32).train()

optimizer_gan = Adam(c_gan.parameters(), lr=0.00001)

这就是我在训练循环中训练网络的方式

c_gan.zero_grad()
optimizer_gan.zero_grad()

disc_color, disc_bw, decoded1, decoded2, decoded_disc = c_gan(images, images_bw)

loss_true = bce_loss(disc_color, label_true)
loss_false = bce_loss(disc_bw, label_false)
disc_loss = loss_true + loss_false
disc_loss.backward()

decoded_loss = l2_loss(decoded2, images_bw)
decoded_disc_loss = bce_loss(decoded_disc, label_true) # This is where the loss for that Discriminator forward pass is calculated
both_decoded_losses = decoded_loss + decoded_disc_loss
both_decoded_losses.backward()
optimizer_gan.step()

问题

我不想基于EncoderDecoder -> Discriminator向前传递来训练Discriminator模块。然而,我确实希望根据images -> Discriminatorimages_bw -> Discriminator向前传球来训练它

  • 对于我的CycleGAN模块,是否可以只使用一个优化器来实现这一点?
  • 我可以在优化器的.step()期间冻结Discriminator吗?

我将感谢任何帮助


Tags: thetoselfcolordiscbwimagesdecoded

热门问题