我正在处理一个生成心跳的虚拟示例,希望首先使用VAE对心跳进行编码,然后使用一个简单的分类器
问题是,当我将beta值增加到0.01以上时,重建将变得毫无意义(请参见第一幅图)。 当beta值较低时,我得到一个正常的自动编码器输出,没有解纠缠(第二幅图像)。
我相信问题可能在我的KL散度或VAE损失函数中,但我似乎找不到它。 在我的编码器中,我进行重新参数化:
enc = self.encoder(x,batch_size, x_lenghts)
mu = self.enc2mean(enc)
logv = self.enc2logv(enc)
std = torch.exp(0.5*logv)
z = torch.randn([batch_size,1, self.encoder_hidden_sizes[-1] * (int(self.bidirectional)+1)]).to(self.device)
z = z * std + mu
我将VAE损失定义为:
def VAE_loss(x, reconstruction, mu, logvar, batch_size, latent_dim, beta=0):
mse = F.mse_loss(x, reconstruction)
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
KLD /= (batch_size * latent_dim)
return mse + beta*KLD
复制结果的完整独立代码是here
任何见解都将不胜感激
目前没有回答
相关问题 更多 >
编程相关推荐