BetaVariational自动编码器无法解开

2024-06-14 02:59:49 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在处理一个生成心跳的虚拟示例,希望首先使用VAE对心跳进行编码,然后使用一个简单的分类器

问题是,当我将beta值增加到0.01以上时,重建将变得毫无意义(请参见第一幅图)。 当beta值较低时,我得到一个正常的自动编码器输出,没有解纠缠(第二幅图像)。 Beta=0.1Beta=0.01

我相信问题可能在我的KL散度或VAE损失函数中,但我似乎找不到它。 在我的编码器中,我进行重新参数化:

enc = self.encoder(x,batch_size, x_lenghts)
mu = self.enc2mean(enc)
logv = self.enc2logv(enc)
std = torch.exp(0.5*logv)
z = torch.randn([batch_size,1, self.encoder_hidden_sizes[-1] * (int(self.bidirectional)+1)]).to(self.device)
z = z * std + mu

我将VAE损失定义为:

def VAE_loss(x, reconstruction, mu, logvar, batch_size, latent_dim, beta=0):
    mse = F.mse_loss(x, reconstruction)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    KLD /= (batch_size * latent_dim)
    return mse + beta*KLD

复制结果的完整独立代码是here

任何见解都将不胜感激


Tags: selfencodersizebatchtorch编码器beta损失