Encoder input shape differs from decoder output shape

Posted 2024-05-13 12:02:56


Hi everyone, I am using code from machinecurve.

The encoder part has the following architecture. The input is a 28x28 image:

i       = Input(shape=input_shape, name='encoder_input')
cx      = Conv2D(filters=128, kernel_size=5, strides=2, padding='same', activation='relu')(i)
cx      = BatchNormalization()(cx)
cx      = Conv2D(filters=256, kernel_size=5, strides=2, padding='same', activation='relu')(cx)
cx      = BatchNormalization()(cx)
cx      = Conv2D(filters=512, kernel_size=5, strides=2, padding='same', activation='relu')(cx)
cx      = BatchNormalization()(cx)
cx      = Conv2D(filters=1024, kernel_size=5, strides=2, padding='same', activation='relu')(cx)
cx      = BatchNormalization()(cx)
x       = Flatten()(cx)
x       = Dense(20, activation='relu')(x)
x       = BatchNormalization()(x)
mu      = Dense(latent_dim, name='latent_mu')(x)
sigma   = Dense(latent_dim, name='latent_sigma')(x)
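A minimal sketch of the spatial arithmetic for the encoder above, assuming Keras' rule that a strided convolution with `padding='same'` produces `ceil(n / stride)` outputs per spatial dimension. The helper name `same_conv_length` is hypothetical and only used for illustration.

```python
import math

def same_conv_length(n, stride):
    """Output length of a Conv2D with padding='same' (odd sizes round up)."""
    return math.ceil(n / stride)

# Trace the four stride-2 Conv2D layers of the encoder, starting at 28x28.
sizes = [28]
for _ in range(4):
    sizes.append(same_conv_length(sizes[-1], 2))

print(sizes)  # [28, 14, 7, 4, 2]

# Flattened vector fed into Dense(20): 2 x 2 spatial positions x 1024 channels.
flat = sizes[-1] * sizes[-1] * 1024
print(flat)  # 4096
```

Note the step from 7 to 4: the odd size is rounded up, which matches the `(None, 4, 4, 512)` shape of `conv2d_12` in the encoder summary.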

The decoder part is as follows; it tries to reverse the layers defined in the encoder part:

d_i   = Input(shape=(latent_dim, ), name='decoder_input')
x     = Dense(conv_shape[1] * conv_shape[2] * conv_shape[3], activation='relu')(d_i)
x     = BatchNormalization()(x)
x     = Reshape((conv_shape[1], conv_shape[2], conv_shape[3]))(x)
cx    = Conv2DTranspose(filters=1024, kernel_size=5, strides=2, padding='same', activation='relu')(x)
cx    = BatchNormalization()(cx)
cx    = Conv2DTranspose(filters=512, kernel_size=5, strides=2, padding='same', activation='relu')(cx)
cx    = BatchNormalization()(cx)
cx    = Conv2DTranspose(filters=256, kernel_size=5, strides=2, padding='same', activation='relu')(cx)
cx    = BatchNormalization()(cx)
cx    = Conv2DTranspose(filters=128, kernel_size=5, strides=2, padding='same', activation='relu')(cx)
cx    = BatchNormalization()(cx)
o     = Conv2DTranspose(filters=num_channels, kernel_size=3, activation='sigmoid', padding='same', name='decoder_output')(cx)
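The decoder's shapes can be traced the same way. This sketch assumes Keras' rule that `Conv2DTranspose` with `padding='same'` and no `output_padding` multiplies each spatial dimension by the stride; `same_transpose_length` is a hypothetical helper name.

```python
def same_transpose_length(n, stride):
    """Output length of a Conv2DTranspose with padding='same' and no output_padding."""
    return n * stride

# Start from the reshaped 2x2x1024 tensor and apply four stride-2 transposed convs.
sizes = [2]
for _ in range(4):
    sizes.append(same_transpose_length(sizes[-1], 2))

# The final kernel_size=3 layer uses the default strides=1 with padding='same',
# so it leaves the spatial size unchanged.
sizes.append(same_transpose_length(sizes[-1], 1))

print(sizes)  # [2, 4, 8, 16, 32, 32]
```

Doubling cannot undo the 7 to 4 rounding in the encoder, so the mirrored decoder ends at 32x32 instead of 28x28.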

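The encoder input must have the same shape as the decoder output, but as the summary below shows, it does not (28x28 in, 32x32 out):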

Model: "vae"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
encoder_input (InputLayer)   (None, 28, 28, 1)         0         
_________________________________________________________________
encoder (Model)              [(None, 2), (None, 2), (N 17298104  
_________________________________________________________________
decoder (Model)              (None, 32, 32, 1)         43457025  
=================================================================
Total params: 60,755,129
Trainable params: 60,739,217
Non-trainable params: 15,912

Then, when the model is trained, I get the following error:

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-46-44d4cd644e8f> in <module>
      3 
      4 # Train autoencoder
----> 5 vae.fit(input_train, input_train, epochs = no_epochs, batch_size = batch_size, validation_split = validation_split)

~\.conda\envs\keypoints\lib\site-packages\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
   1237                                         steps_per_epoch=steps_per_epoch,
   1238                                         validation_steps=validation_steps,
-> 1239                                         validation_freq=validation_freq)
   1240 
   1241     def evaluate(self,

~\.conda\envs\keypoints\lib\site-packages\keras\engine\training_arrays.py in fit_loop(model, fit_function, fit_inputs, out_labels, batch_size, epochs, verbose, callbacks, val_function, val_inputs, shuffle, initial_epoch, steps_per_epoch, validation_steps, validation_freq)
    194                     ins_batch[i] = ins_batch[i].toarray()
    195 
--> 196                 outs = fit_function(ins_batch)
    197                 outs = to_list(outs)
    198                 for l, o in zip(out_labels, outs):

~\.conda\envs\keypoints\lib\site-packages\tensorflow_core\python\keras\backend.py in __call__(self, inputs)
   3738         value = math_ops.cast(value, tensor.dtype)
   3739       converted_inputs.append(value)
-> 3740     outputs = self._graph_fn(*converted_inputs)
   3741 
   3742     # EagerTensor.numpy() will often make a copy to ensure memory safety.

~\.conda\envs\keypoints\lib\site-packages\tensorflow_core\python\eager\function.py in __call__(self, *args, **kwargs)
   1079       TypeError: For invalid positional/keyword argument combinations.
   1080     """
-> 1081     return self._call_impl(args, kwargs)
   1082 
   1083   def _call_impl(self, args, kwargs, cancellation_manager=None):

~\.conda\envs\keypoints\lib\site-packages\tensorflow_core\python\eager\function.py in _call_impl(self, args, kwargs, cancellation_manager)
   1119       raise TypeError("Keyword arguments {} unknown. Expected {}.".format(
   1120           list(kwargs.keys()), list(self._arg_keywords)))
-> 1121     return self._call_flat(args, self.captured_inputs, cancellation_manager)
   1122 
   1123   def _filtered_call(self, args, kwargs):

~\.conda\envs\keypoints\lib\site-packages\tensorflow_core\python\eager\function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1222     if executing_eagerly:
   1223       flat_outputs = forward_function.call(
-> 1224           ctx, args, cancellation_manager=cancellation_manager)
   1225     else:
   1226       gradient_name = self._delayed_rewrite_functions.register()

~\.conda\envs\keypoints\lib\site-packages\tensorflow_core\python\eager\function.py in call(self, ctx, args, cancellation_manager)
    509               inputs=args,
    510               attrs=("executor_type", executor_type, "config_proto", config),
--> 511               ctx=ctx)
    512         else:
    513           outputs = execute.execute_with_cancellation(

~\.conda\envs\keypoints\lib\site-packages\tensorflow_core\python\eager\execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     65     else:
     66       message = e.message
---> 67     six.raise_from(core._status_to_exception(e.code, message), None)
     68   except TypeError as e:
     69     keras_symbolic_tensors = [

~\.conda\envs\keypoints\lib\site-packages\six.py in raise_from(value, from_value)

InvalidArgumentError:  Incompatible shapes: [100352] vs. [131072]
     [[node gradients/loss/decoder_loss/kl_reconstruction_loss/mul_1_grad/BroadcastGradientArgs (defined at C:\Users\XXXXX\.conda\envs\keypoints\lib\site-packages\tensorflow_core\python\framework\ops.py:1751) ]] [Op:__inference_keras_scratch_graph_22124]

Function call stack:
keras_scratch_graph
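The two numbers in the error are flattened element counts for one batch: the target (the 28x28x1 input) versus the prediction (the 32x32x1 decoder output). A quick check, noting that the question does not show `batch_size`; 128 is simply the value inferred from the message, since it divides both numbers evenly:

```python
# Elements per image for the target (encoder input) and the prediction (decoder output).
target_elems = 28 * 28 * 1   # 784
output_elems = 32 * 32 * 1   # 1024

# Infer the batch size from the two sizes reported in the InvalidArgumentError.
batch_from_target = 100352 // target_elems
batch_from_output = 131072 // output_elems
print(batch_from_target, batch_from_output)  # 128 128
```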

Do you have any ideas on how to solve this problem?

I also define:

# Define sampling with reparameterization trick
def sample_z(args):
    mu, sigma = args
    batch     = K.shape(mu)[0]
    dim       = K.int_shape(mu)[1]
    eps       = K.random_normal(shape=(batch, dim))
    return mu + K.exp(sigma / 2) * eps

# Use reparameterization trick to ensure correct gradient
z       = Lambda(sample_z, output_shape=(latent_dim, ), name='z')([mu, sigma])
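`sample_z` implements the reparameterization trick: `sigma` is effectively used as a log-variance, so `K.exp(sigma / 2)` is the standard deviation. The following pure-Python sketch (no Keras; `sample_z_scalar` is a hypothetical name) draws many samples for a single latent dimension and checks that their mean and standard deviation land near `mu` and `exp(sigma / 2)`.

```python
import math
import random
import statistics

def sample_z_scalar(mu, log_var, rng):
    """Reparameterized sample: mu + exp(log_var / 2) * eps, with eps ~ N(0, 1)."""
    eps = rng.gauss(0.0, 1.0)
    return mu + math.exp(log_var / 2) * eps

rng = random.Random(0)
mu, log_var = 1.0, math.log(4.0)   # variance 4, so standard deviation 2
samples = [sample_z_scalar(mu, log_var, rng) for _ in range(100_000)]

print(statistics.fmean(samples), statistics.pstdev(samples))
```

Because all the randomness lives in `eps`, the sample is a deterministic function of `mu` and `sigma`, so gradients can flow back into the encoder; that is why the `Lambda` layer is written this way.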

The encoder is defined as:

encoder = Model(i, [mu, sigma, z], name='encoder')

Its architecture is:

Model: "encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
encoder_input (InputLayer)      (None, 28, 28, 1)    0                                            
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 14, 14, 128)  3328        encoder_input[0][0]              
__________________________________________________________________________________________________
batch_normalization_25 (BatchNo (None, 14, 14, 128)  512         conv2d_10[0][0]                  
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 7, 7, 256)    819456      batch_normalization_25[0][0]     
__________________________________________________________________________________________________
batch_normalization_26 (BatchNo (None, 7, 7, 256)    1024        conv2d_11[0][0]                  
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 4, 4, 512)    3277312     batch_normalization_26[0][0]     
__________________________________________________________________________________________________
batch_normalization_27 (BatchNo (None, 4, 4, 512)    2048        conv2d_12[0][0]                  
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 2, 2, 1024)   13108224    batch_normalization_27[0][0]     
__________________________________________________________________________________________________
batch_normalization_28 (BatchNo (None, 2, 2, 1024)   4096        conv2d_13[0][0]                  
__________________________________________________________________________________________________
flatten_4 (Flatten)             (None, 4096)         0           batch_normalization_28[0][0]     
__________________________________________________________________________________________________
dense_7 (Dense)                 (None, 20)           81940       flatten_4[0][0]                  
__________________________________________________________________________________________________
batch_normalization_29 (BatchNo (None, 20)           80          dense_7[0][0]                    
__________________________________________________________________________________________________
latent_mu (Dense)               (None, 2)            42          batch_normalization_29[0][0]     
__________________________________________________________________________________________________
latent_sigma (Dense)            (None, 2)            42          batch_normalization_29[0][0]     
__________________________________________________________________________________________________
z (Lambda)                      (None, 2)            0           latent_mu[0][0]                  
                                                                 latent_sigma[0][0]               
==================================================================================================
Total params: 17,298,104

Similarly, the decoder is defined as:

decoder = Model(d_i, o, name='decoder')

The decoder architecture is:

Model: "decoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
decoder_input (InputLayer)   (None, 2)                 0         
_________________________________________________________________
dense_8 (Dense)              (None, 4096)              12288     
_________________________________________________________________
batch_normalization_30 (Batc (None, 4096)              16384     
_________________________________________________________________
reshape_4 (Reshape)          (None, 2, 2, 1024)        0         
_________________________________________________________________
conv2d_transpose_10 (Conv2DT (None, 4, 4, 1024)        26215424  
_________________________________________________________________
batch_normalization_31 (Batc (None, 4, 4, 1024)        4096      
_________________________________________________________________
conv2d_transpose_11 (Conv2DT (None, 8, 8, 512)         13107712  
_________________________________________________________________
batch_normalization_32 (Batc (None, 8, 8, 512)         2048      
_________________________________________________________________
conv2d_transpose_12 (Conv2DT (None, 16, 16, 256)       3277056   
_________________________________________________________________
batch_normalization_33 (Batc (None, 16, 16, 256)       1024      
_________________________________________________________________
conv2d_transpose_13 (Conv2DT (None, 32, 32, 128)       819328    
_________________________________________________________________
batch_normalization_34 (Batc (None, 32, 32, 128)       512       
_________________________________________________________________
decoder_output (Conv2DTransp (None, 32, 32, 1)         1153      
=================================================================
Total params: 43,457,025
Trainable params: 43,444,993
Non-trainable params: 12,032

Finally, we put them together:

# =================
# VAE as a whole
# =================
# Instantiate VAE
vae_outputs = decoder(encoder(i)[2])
vae         = Model(i, vae_outputs, name='vae')

1 answer

Answer #1 · posted 2024-05-13 12:02:56

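The problem is caused by the decoder's output shape. You can fix it simply by replacing the last layer of the decoder with a Conv2D that uses the default padding='valid', which trims the 32x32 output back to 28x28 (32 - 5 + 1 = 28):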

Conv2D(filters=num_channels, kernel_size=5, activation='sigmoid', name='decoder_output')
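Why this works: with the default `padding='valid'` and `strides=1`, a convolution with kernel size k shrinks each spatial dimension from n to n - k + 1. A minimal check (the helper name `valid_conv_length` is hypothetical):

```python
def valid_conv_length(n, kernel_size, stride=1):
    """Output length of a Conv2D with padding='valid'."""
    return (n - kernel_size) // stride + 1

# The last Conv2DTranspose produces 32x32; a 5x5 'valid' conv trims it.
print(valid_conv_length(32, 5))  # 28

# With stride 1, only kernel_size=5 maps 32 back to 28.
print([k for k in range(1, 10) if valid_conv_length(32, k) == 28])  # [5]
```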

Here is the complete code:

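import numpy as np
from keras import backend as K
from keras.layers import (Input, Conv2D, Conv2DTranspose, BatchNormalization,
                          Flatten, Dense, Reshape, Lambda)
from keras.models import Model

num_channels = 1
latent_dim = 2
input_shape = (28, 28, 1)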

i       = Input(shape=input_shape, name='encoder_input')
cx      = Conv2D(filters=128, kernel_size=5, strides=2, padding='same', activation='relu')(i)
cx      = BatchNormalization()(cx)
cx      = Conv2D(filters=256, kernel_size=5, strides=2, padding='same', activation='relu')(cx)
cx      = BatchNormalization()(cx)
cx      = Conv2D(filters=512, kernel_size=5, strides=2, padding='same', activation='relu')(cx)
cx      = BatchNormalization()(cx)
cx      = Conv2D(filters=1024, kernel_size=5, strides=2, padding='same', activation='relu')(cx)
cx      = BatchNormalization()(cx)
x       = Flatten()(cx)
x       = Dense(20, activation='relu')(x)
x       = BatchNormalization()(x)

mu      = Dense(latent_dim, name='latent_mu')(x)
sigma   = Dense(latent_dim, name='latent_sigma')(x)

conv_shape = K.int_shape(cx)
d_i   = Input(shape=(latent_dim, ), name='decoder_input')
x     = Dense(np.prod(conv_shape[1:]), activation='relu')(d_i)
x     = BatchNormalization()(x)
x     = Reshape(conv_shape[1:])(x)
cx    = Conv2DTranspose(filters=1024, kernel_size=5, strides=2, padding='same', activation='relu')(x)
cx    = BatchNormalization()(cx)
cx    = Conv2DTranspose(filters=512, kernel_size=5, strides=2, padding='same', activation='relu')(cx)
cx    = BatchNormalization()(cx)
cx    = Conv2DTranspose(filters=256, kernel_size=5, strides=2, padding='same', activation='relu')(cx)
cx    = BatchNormalization()(cx)
cx    = Conv2DTranspose(filters=128, kernel_size=5, strides=2, padding='same', activation='relu')(cx)
cx    = BatchNormalization()(cx)
o     = Conv2D(filters=num_channels, kernel_size=5, activation='sigmoid', name='decoder_output')(cx)

The sampling layer:

def sample_z(args):
    mu, sigma = args
    batch     = K.shape(mu)[0]
    dim       = K.int_shape(mu)[1]
    eps       = K.random_normal(shape=(batch, dim))
    return mu + K.exp(sigma / 2) * eps

# Use reparameterization trick to ensure correct gradient
z       = Lambda(sample_z, output_shape=(latent_dim, ), name='z')([mu, sigma])

The final VAE:

encoder = Model(i, [mu, sigma, z], name='encoder')
decoder = Model(d_i, o, name='decoder')
vae_outputs = decoder(encoder(i)[2])
vae         = Model(i, vae_outputs, name='vae')

Model summary:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
encoder_input (InputLayer)   [(None, 28, 28, 1)]       0         
_________________________________________________________________
encoder (Model)              [(None, 2), (None, 2), (N 17298104  
_________________________________________________________________
decoder (Model)              (None, 28, 28, 1)         43459073  
=================================================================

As you can see, the input and output shapes now match.
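The decoder's parameter count also changes from 43,457,025 to 43,459,073, a difference of exactly 2,048, which comes entirely from swapping the output layer. A quick arithmetic check, using the standard count kernel_h x kernel_w x in_channels x filters + filters (biases) for both Conv2D and Conv2DTranspose:

```python
def conv_params(kernel_size, in_channels, filters):
    """Weights plus biases of a square-kernel Conv2D / Conv2DTranspose layer."""
    return kernel_size * kernel_size * in_channels * filters + filters

old_output = conv_params(3, 128, 1)   # original 3x3 Conv2DTranspose output layer
new_output = conv_params(5, 128, 1)   # replacement 5x5 Conv2D output layer
print(old_output, new_output)  # 1153 3201
print(43457025 - old_output + new_output)  # 43459073
```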
