Transfer learning with partially loaded weights

Posted 2024-06-23 19:37:17


Hi, I have a pretrained network, model 1, shown below:

Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 3000, 200, 1) 0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 3000, 200, 1) 26          input_1[0][0]                    
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 1500, 200, 1) 0           conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 1500, 200, 1) 26          max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 750, 200, 1)  0           conv2d_2[0][0]                   
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 750, 200, 1)  26          max_pooling2d_2[0][0]            
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)  (None, 375, 200, 1)  0           conv2d_3[0][0]                   
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 375, 200, 1)  26          max_pooling2d_3[0][0]            
__________________________________________________________________________________________________
max_pooling2d_4 (MaxPooling2D)  (None, 187, 199, 1)  0           conv2d_4[0][0]                   
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 128, 128, 1)  4321        max_pooling2d_4[0][0]            
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 128, 128, 16) 160         conv2d_5[0][0]                   
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 128, 128, 16) 2320        conv2d_6[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_5 (MaxPooling2D)  (None, 64, 64, 16)   0           conv2d_7[0][0]                   
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 64, 64, 32)   4640        max_pooling2d_5[0][0]            
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 64, 64, 32)   9248        conv2d_8[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_6 (MaxPooling2D)  (None, 32, 32, 32)   0           conv2d_9[0][0]                   
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 32, 32, 64)   18496       max_pooling2d_6[0][0]            
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 32, 32, 64)   36928       conv2d_10[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_7 (MaxPooling2D)  (None, 16, 16, 64)   0           conv2d_11[0][0]                  
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 16, 16, 128)  73856       max_pooling2d_7[0][0]            
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 16, 16, 128)  147584      conv2d_12[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_8 (MaxPooling2D)  (None, 8, 8, 128)    0           conv2d_13[0][0]                  
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 8, 8, 256)    295168      max_pooling2d_8[0][0]            
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 8, 8, 256)    590080      conv2d_14[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D)  (None, 16, 16, 256)  0           conv2d_15[0][0]                  
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 16, 16, 384)  0           up_sampling2d_1[0][0]            
                                                                 conv2d_13[0][0]                  
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 16, 16, 128)  442496      concatenate_1[0][0]              
__________________________________________________________________________________________________
conv2d_17 (Conv2D)              (None, 16, 16, 128)  147584      conv2d_16[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_2 (UpSampling2D)  (None, 32, 32, 128)  0           conv2d_17[0][0]                  
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 32, 32, 192)  0           up_sampling2d_2[0][0]            
                                                                 conv2d_11[0][0]                  
__________________________________________________________________________________________________
conv2d_18 (Conv2D)              (None, 32, 32, 64)   110656      concatenate_2[0][0]              
__________________________________________________________________________________________________
conv2d_19 (Conv2D)              (None, 32, 32, 64)   36928       conv2d_18[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_3 (UpSampling2D)  (None, 64, 64, 64)   0           conv2d_19[0][0]                  
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 64, 64, 96)   0           up_sampling2d_3[0][0]            
                                                                 conv2d_9[0][0]                   
__________________________________________________________________________________________________
conv2d_20 (Conv2D)              (None, 64, 64, 32)   27680       concatenate_3[0][0]              
__________________________________________________________________________________________________
conv2d_21 (Conv2D)              (None, 64, 64, 32)   9248        conv2d_20[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_4 (UpSampling2D)  (None, 128, 128, 32) 0           conv2d_21[0][0]                  
__________________________________________________________________________________________________
concatenate_4 (Concatenate)     (None, 128, 128, 48) 0           up_sampling2d_4[0][0]            
                                                                 conv2d_7[0][0]                   
__________________________________________________________________________________________________
conv2d_22 (Conv2D)              (None, 128, 128, 16) 6928        concatenate_4[0][0]              
__________________________________________________________________________________________________
conv2d_23 (Conv2D)              (None, 128, 128, 16) 2320        conv2d_22[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTrans (None, 261, 261, 8)  6280        conv2d_23[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTrans (None, 267, 267, 4)  1572        conv2d_transpose_1[0][0]         
__________________________________________________________________________________________________
conv2d_transpose_3 (Conv2DTrans (None, 300, 300, 2)  9250        conv2d_transpose_2[0][0]         
__________________________________________________________________________________________________
conv2d_24 (Conv2D)              (None, 300, 300, 1)  3           conv2d_transpose_3[0][0]         
==================================================================================================
Total params: 1,983,850
Trainable params: 1,983,850
Non-trainable params: 0

I want to take the weights of model 1 from conv2d_6 through the last layer and load them into model 2 starting at conv2d_3 (model 2 is shown below):

Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 502, 200, 3)  0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 502, 200, 3)  30          input_1[0][0]                    
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 251, 200, 3)  0           conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 251, 200, 3)  48          max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 125, 100, 3)  0           conv2d_2[0][0]                   
__________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTrans (None, 126, 109, 3)  183         max_pooling2d_2[0][0]            
__________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTrans (None, 127, 118, 3)  183         conv2d_transpose_1[0][0]         
__________________________________________________________________________________________________
conv2d_transpose_3 (Conv2DTrans (None, 128, 128, 3)  201         conv2d_transpose_2[0][0]         
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 128, 128, 16) 448         conv2d_transpose_3[0][0]         
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 128, 128, 16) 2320        conv2d_3[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)  (None, 64, 64, 16)   0           conv2d_4[0][0]                   
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 64, 64, 32)   4640        max_pooling2d_3[0][0]            
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 64, 64, 32)   9248        conv2d_5[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_4 (MaxPooling2D)  (None, 32, 32, 32)   0           conv2d_6[0][0]                   
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 32, 32, 64)   18496       max_pooling2d_4[0][0]            
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 32, 32, 64)   36928       conv2d_7[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_5 (MaxPooling2D)  (None, 16, 16, 64)   0           conv2d_8[0][0]                   
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 16, 16, 128)  73856       max_pooling2d_5[0][0]            
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 16, 16, 128)  147584      conv2d_9[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_6 (MaxPooling2D)  (None, 8, 8, 128)    0           conv2d_10[0][0]                  
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 8, 8, 256)    295168      max_pooling2d_6[0][0]            
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 8, 8, 256)    590080      conv2d_11[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D)  (None, 16, 16, 256)  0           conv2d_12[0][0]                  
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 16, 16, 384)  0           up_sampling2d_1[0][0]            
                                                                 conv2d_10[0][0]                  
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 16, 16, 128)  442496      concatenate_1[0][0]              
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 16, 16, 128)  147584      conv2d_13[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_2 (UpSampling2D)  (None, 32, 32, 128)  0           conv2d_14[0][0]                  
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 32, 32, 192)  0           up_sampling2d_2[0][0]            
                                                                 conv2d_8[0][0]                   
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 32, 32, 64)   110656      concatenate_2[0][0]              
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 32, 32, 64)   36928       conv2d_15[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_3 (UpSampling2D)  (None, 64, 64, 64)   0           conv2d_16[0][0]                  
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 64, 64, 96)   0           up_sampling2d_3[0][0]            
                                                                 conv2d_6[0][0]                   
__________________________________________________________________________________________________
conv2d_17 (Conv2D)              (None, 64, 64, 32)   27680       concatenate_3[0][0]              
__________________________________________________________________________________________________
conv2d_18 (Conv2D)              (None, 64, 64, 32)   9248        conv2d_17[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_4 (UpSampling2D)  (None, 128, 128, 32) 0           conv2d_18[0][0]                  
__________________________________________________________________________________________________
concatenate_4 (Concatenate)     (None, 128, 128, 48) 0           up_sampling2d_4[0][0]            
                                                                 conv2d_4[0][0]                   
__________________________________________________________________________________________________
conv2d_19 (Conv2D)              (None, 128, 128, 16) 6928        concatenate_4[0][0]              
__________________________________________________________________________________________________
conv2d_20 (Conv2D)              (None, 128, 128, 16) 2320        conv2d_19[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose_4 (Conv2DTrans (None, 261, 261, 8)  6280        conv2d_20[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose_5 (Conv2DTrans (None, 267, 267, 4)  1572        conv2d_transpose_4[0][0]         
__________________________________________________________________________________________________
conv2d_transpose_6 (Conv2DTrans (None, 300, 300, 2)  9250        conv2d_transpose_5[0][0]         
__________________________________________________________________________________________________
conv2d_21 (Conv2D)              (None, 300, 300, 1)  3           conv2d_transpose_6[0][0]         
==================================================================================================
Total params: 1,980,358
Trainable params: 1,980,358
Non-trainable params: 0

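I took weights_list = model1.get_weights() from the pretrained model, and the list has length 54. I just don't understand how the entries in that list are indexed by layer. Model1 has 43 layers, so the indexing confuses me. Is there a general way to understand the indexing that also applies to other models? In future models I will need to pick specific layer weights and move them between models.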


1 answer
Answer 1 · posted 2024-06-23 19:37:17

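You need to extract the weights from the specific layers. Calling model.get_weights() on its own returns the weights of the entire model as one flat list.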

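The following pattern should work, provided the two layers have identical weight shapes (set_weights rejects mismatched shapes):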

model1_weights = model1.get_layer('conv2d_6').get_weights()
model2.get_layer('conv2d_3').set_weights(model1_weights)
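On the indexing question: as far as I know, for a plain functional model like the ones above (no shared layers), the flat list from model.get_weights() is just each layer's own get_weights() list, concatenated in model.layers order. Layers without parameters (input, pooling, upsampling, concatenate) contribute nothing. Here is a minimal sketch of that bookkeeping. The helper name weight_index_map is mine, not part of Keras, and it relies only on model.layers, layer.name and layer.get_weights().

```python
def weight_index_map(model):
    """Map each weight-bearing layer name to its (start, stop) slice
    in the flat list returned by model.get_weights().

    Assumption: the flat list follows model.layers order and no layer
    is shared, as in the model summaries above.
    """
    index_map = {}
    start = 0
    for layer in model.layers:
        # Number of weight arrays this layer owns (e.g. kernel + bias = 2).
        n_arrays = len(layer.get_weights())
        if n_arrays:
            index_map[layer.name] = (start, start + n_arrays)
        start += n_arrays
    return index_map

# Usage (with the model from the question):
#   for name, (lo, hi) in weight_index_map(model1).items():
#       print(name, lo, hi)
```

For model 1 this should report conv2d_1 at (0, 2) and conv2d_6 at (10, 12). The length of 54 is consistent with this: the summary lists 27 layers with parameters (24 Conv2D plus 3 Conv2DTranspose), each holding a kernel and a bias.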

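Edit: To extend this to all the required layers, the simplest approach is to create one list of the layer names to copy from in model1 (model1_layers) and a second list of the layer names to copy to in model2 (model2_layers), in matching order. Then loop over both lists as follows: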

# model1_layers and model2_layers hold layer names in matching order
for src_name, dst_name in zip(model1_layers, model2_layers):
    layer_weights = model1.get_layer(src_name).get_weights()
    model2.get_layer(dst_name).set_weights(layer_weights)
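One caveat from the two summaries: conv2d_6 in model 1 has 160 parameters while conv2d_3 in model 2 has 448, because the first sees a 1-channel input and the second a 3-channel input. Their kernel shapes therefore differ, and set_weights would reject that particular pair. From conv2d_7 -> conv2d_4 onward the parameter counts line up. Below is a rough sketch of the loop with a shape check that skips mismatched pairs instead of failing. The function name copy_weights_by_name and the name_pairs list are my own illustration, with the pairs read off the summaries in the question.

```python
def copy_weights_by_name(src_model, dst_model, name_pairs):
    """Copy weights layer by layer; return the pairs skipped because
    their weight shapes do not match."""
    skipped = []
    for src_name, dst_name in name_pairs:
        src_weights = src_model.get_layer(src_name).get_weights()
        dst_layer = dst_model.get_layer(dst_name)
        src_shapes = [w.shape for w in src_weights]
        dst_shapes = [w.shape for w in dst_layer.get_weights()]
        if src_shapes != dst_shapes:
            # Leave the destination layer at its current initialization.
            skipped.append((src_name, dst_name))
            continue
        dst_layer.set_weights(src_weights)
    return skipped

# Pairs read off the summaries: model1 conv2d_i -> model2 conv2d_(i - 3),
# and model1 conv2d_transpose_i -> model2 conv2d_transpose_(i + 3).
name_pairs = [(f"conv2d_{i}", f"conv2d_{i - 3}") for i in range(6, 25)]
name_pairs += [
    (f"conv2d_transpose_{i}", f"conv2d_transpose_{i + 3}") for i in range(1, 4)
]

# Usage (with the models from the question):
#   skipped = copy_weights_by_name(model1, model2, name_pairs)
#   print("not copied:", skipped)
```

With the models in the question, I would expect the skipped list to contain only the conv2d_6 -> conv2d_3 pair, which then has to be trained from scratch along with the new front-end layers of model 2.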
