多输入多输出的变分自动编码器

2024-10-01 00:21:02 发布

您现在位置:Python中文网/ 问答频道 /正文

我已经在Keras中建立了一个自动编码器,它接受多个输入和相同数量的输出,我想转换成一个可变的自动编码器。我很难把输入输出差的损失和变分部分的损失结合起来。在

我想要达到的目标:

自动编码器应用于包含数字和分类数据的数据集。为此,我规范化数字列并对分类列进行1-hot-encode。由于得到的分类向量和数值向量需要不同的损失函数(数值的平均平方误差,分类列的分类交叉熵),与小数值列相比,非常大的1热编码向量将主导损失,我决定把每一列作为自己的输入向量。因此,我的自动编码器接受一组输入向量,生成相同数量和形状的输出向量。在

到目前为止我所做的:

这是两个数字输入和两个分类输入的设置,具有20和30宽的1热编码:

encWidth = 3
## Encoder
x = Concatenate(axis=1)([ Input(1,),Input(1,),Input(20,),Input(30,) ]) #<-configurable
x = Dense( 32, activation="relu")(x)
layEncOut = Dense( encWidth, activation="linear")(x)

layDecIn = Input( encWidth, name="In_Encoder" )
x = Dense( 32, activation="relu")(layDecIn)
layDecOut = [ outLayer(x) for outLayer in C.layOutputs ]

encoder = Model(C.layInputs, layEncOut, name="encoder")
decoder = Model( layDecIn, layDecOut, name="decoder" )

AE = Model(C.layInputs, decoder(encoder(C.layInputs)), name="autoencoder")
AE.compile(optimizer="adam", 
           loss=['mean_squared_error', 'mean_squared_error',
                 'categorical_crossentropy', 'categorical_crossentropy',], #<-configurable
                 loss_weights=[1.0, 1.0, 1.0, 1.0] #<-configurable
          )

这个例子是静态的,但是在我的实现中,数字和分类字段是可配置的,因此输入、损失函数的类型和损失权重应该可以从存储数据集中原始列的对象中配置。在

^{pr2}$

这里C是一个类的实例,它有输入层和丢失函数/权重,这取决于我想在自动编码器中有哪些列。在

我的问题是:

我已经把设置扩展到一个可变的自动编码器,有一个平均值和标准差的潜在层。在

encWidth = 2

## Encoder
x = Concatenate(axis=1)(C.layInputs)
x = Dense( 32, activation="relu")(x)

### variational part
z_mean = Dense(encWidth, name='z_mean', activation=lrelu)(x)
z_log_var = Dense(encWidth, name='z_log_var', activation=lrelu)(x)
z = Lambda(sampling, name='z')([z_mean, z_log_var])

## Decoder
layDecodeInput = Input( encWidth, name="In_Encoder" )
x = Dense( 32, activation="relu")(layDecodeInput)
layOutDecoder = [ outLayer(x) for outLayer in C.layOutputs ]

### build the encoder model
vEncoder = Model(C.layInputs, [z_mean, z_log_var, z], name='v_encoder')

### build the decoder model
vDecoder = Model( layDecodeInput, layOutDecoder, name="v_decoder" )

## Autoencoder
vAE = Model(C.layInputs, vDecoder(vEncoder(C.layInputs)[2]))
vae_loss = variational_loss(z_mean, z_log_var)
vAE.compile(optimizer="adam",
            loss=vae_loss)

现在,我需要一个自定义的误差函数,它将输入和输出之间的差异的损失(如前一个例子中所示)与方差部分的损失相结合;这就是我目前为止所想到的:

def variational_loss(z_mean, z_log_var, varLossWeight=1.):    
    def lossFct(yTrue, yPred):       

        var_loss = -0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var))

        lossFunctions = [getattr(losses, "mean_squared_error") for losses in  C.losses]
        ac_loss = [
          lossFkt(yTrue, yPred) * lossWeigt for
          yt, yp, lossFkt, lossWeigt in zip(yTrue, yPred, lossFunctions, C.lossWeights) ]

        loss =  K.mean( ac_loss + [ kl_loss * varLossWeight ] )
        return loss
    return lossFct

所以这是一个生成函数,它返回一个接受yTrue和yPredicted的函数,但在变分部分工作。for循环应该循环所有输入和相应的输出,并使用适当的损失函数(数值的均方误差或分类特征的分类交叉熵)对它们进行比较

但是显然for循环不允许循环遍历输入向量集并将它们与输出向量集进行比较;我得到了一个错误

Tensor objects are only iterable when eager execution is enabled. To iterate over this tensor use tf.map_fn.

{duni>如何在不同的输入函数和不同的损失函数中使用不同的函数?在


Tags: 函数nameloginputvar分类编码器mean
1条回答
网友
1楼 · 发布于 2024-10-01 00:21:02

我认为在网络中添加aKL发散层会更简单,它可以处理VAE损耗。你可以这样做,(β是vae损失的重量):

import keras.backend as K
from keras.layers import Layer

class KLDivergenceLayer(Layer):

    """ Identity transform layer that adds KL divergence
    to the final model loss.
    """

    def __init__(self, beta=.5, *args, **kwargs):
        self.is_placeholder = True
        self.beta = beta
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def call(self, inputs):

        mu, log_var = inputs

        kl_batch = - self.beta * K.sum(1 + log_var -
                                K.square(mu) -
                                K.exp(log_var), axis=-1)

        self.add_loss(K.mean(kl_batch), inputs=inputs)

        return inputs

然后,在计算平均值和log var之后,可以在代码中添加以下行:

^{pr2}$

这一层是一个身份层,将KL损耗加到最终损耗中。那么你最后的损失可能就是你在上面用过的那个。在

我在路易斯·C·蒂奥的帖子里找到的:https://tiao.io/post/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/

相关问题 更多 >