Gradients do not exist for most convolutional filters in a subclassed model



This Google Drive folder contains the necessary files; add it to "My Drive": https://drive.google.com/drive/folders/1epROVNfvO10Ksy8CwJQdamSK96JZnWW9?usp=sharing. A Google Colab notebook with the minimal example is here: https://colab.research.google.com/drive/18sMqNn8IpTQLZBlInWSbX0ITd2GWDDkz?usp=sharing

This basic block "module" is, if you like, part of a larger network. However, everything boils down to this block, since it is where the convolutions (in this case, depthwise separable convolutions) are performed. The network appears to train, but while training (and throughout all epochs) the following warning is emitted:

WARNING:tensorflow:Gradients do not exist for variables ['hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block/BB_conv_1_upper_HG0/depthwise_kernel:0', 'hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block/BB_conv_1_upper_HG0/pointwise_kernel:0', 'hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block/BB_conv_2_upper_HG0/depthwise_kernel:0', 'hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block/BB_conv_2_upper_HG0/pointwise_kernel:0', 'hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block/BN_1_upper_HG0/gamma:0', 'hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block/BN_1_upper_HG0/beta:0', 'hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block/BN_12_upper_HG0/gamma:0', 'hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block/BN_12_upper_HG0/beta:0', 'hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_3/BB_conv_1_upper_HG1/depthwise_kernel:0', 'hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_3/BB_conv_1_upper_HG1/pointwise_kernel:0', 'hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_3/BB_conv_2_upper_HG1/depthwise_kernel:0', 'hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_3/BB_conv_2_upper_HG1/pointwise_kernel:0', 'hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_3/BN_1_upper_HG1/gamma:0', 'hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_3/BN_1_upper_HG1/beta:0', 'hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_3/BN_12_upper_HG1/gamma:0', 'hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_3/BN_12_upper_HG1/beta:0', 'batch_normalization/gamma:0', 'batch_normalization/beta:0', 'batch_normalization_1/gamma:0', 'batch_normalization_1/beta:0', 'batch_normalization_2/gamma:0', 'batch_normalization_2/beta:0', 'batch_normalization_3/gamma:0', 'batch_normalization_3/beta:0'] when minimizing the loss.

It comes down to this basic block subclassed model/layer, specifically the separable convolution layers, and I don't know why it complains. As I said, the layer/model is called from other subclassed models higher up in the hierarchy, and many instances of this subclassed model/layer exist. It is initialized similarly to this:

self.upper = BasicBlock(ncIn, ncIn, batchNorm_type=1, name=f'upper_{name}') in the subclassed model's __init__ method

and called similarly to this:

out_upper = self.upper([feat]) in the subclassed model's call method

I don't see why this would cause a problem. If you have any ideas, please feel free to share them.
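
For reference, here is a minimal sketch of how one could list the variables that end up with no gradient at all for a single batch. `model`, `inputs`, and `loss_fn` are stand-ins, not names from the code below:

import tensorflow as tf

# Stand-ins: `model` is the built network, `inputs` is a dummy batch, and
# `loss_fn` is whatever loss is being minimized.
with tf.GradientTape() as tape:
    outputs = model(inputs)
    loss = loss_fn(outputs)
grads = tape.gradient(loss, model.trainable_variables)
for var, grad in zip(model.trainable_variables, grads):
    if grad is None:
        print("no gradient:", var.name)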

Below is a minimal example that shows the problem. It is called minima_example.py and contains the full network structure:

import tensorflow as tf
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
import tensorflow_addons as tfa
import numpy as np
import sys

#https://github.com/tensorflow/addons/blob/master/docs/tutorials/layers_normalizations.ipynb
class LightingNet(tf.keras.Model):
    def __init__(self, ncIn, ncOut, ncMiddle, dynamic=True, name=None):
        super(LightingNet, self).__init__(dynamic=True)
        self.ncIn = ncIn
        self.ncOut = ncOut
        self.ncMiddle = ncMiddle

        self.FC1 = SeparableConv2D(self.ncMiddle, kernel_size=(1,1), strides=(1,1), use_bias=False, name=f'FC1_Lighting_{name}')
        self.relu1 = PReLU()
        self.FC2 = SeparableConv2D(self.ncOut, kernel_size=(1,1), strides=(1,1), use_bias=False, name=f'FC2_Lighting_{name}')

    
    def call(self, inputs):
        feat = inputs[0]
        L_t= inputs[1]
        count= inputs[2]
        skip_count = inputs[3]
        x = feat[:,:,:,0:self.ncIn] 
        _, row, col, _ = x.shape
        f = tf.math.reduce_mean(x, axis=(1,2), keepdims=True )
        L_hat = self.relu1(self.FC1(f))
        L_hat = self.FC2(L_hat)

        return L_hat


class BasicBlock(tf.keras.Model):
    def __init__(self, ncIn, ncOut, batchNorm_type=0, strides=(1,1), downSample=None, dynamic=True, name=None):
        super(BasicBlock, self).__init__(dynamic=True)
        self.ncIn = ncIn
        self.ncOut = ncOut
        self.conv_1 = SeparableConv2D(ncOut, kernel_size=(3,3), strides=(1,1), padding="same", use_bias=False, name=f'BB_conv_1_{name}')
        self.conv_2 = SeparableConv2D(ncOut, kernel_size=(3,3), strides=(1,1), padding="same", use_bias=False, name=f'BB_conv_2_{name}') #these are the same. No idea why they decided to do a separate function for it
        if batchNorm_type == 0:
            self.bn = BatchNormalization(name=f'BN_0_{name}')
        else:
            self.bn = BatchNormalization(name=f'BN_0_{name}')
    def call(self, inputs):
            x = inputs[0]
            #print(x.shape)
            out = self.conv_1(x)
            out = self.bn(out)
            out = Activation('relu')(out) 
            out = self.conv_2(out)
            out = self.bn(out) 
            
            out = Activation('relu')(out)
            return out


class HourglassBlock(tf.keras.Model): #should it be something akin to that of NN.module but for Keras. Do they work the same?
    def __init__(self, ncIn, ncOut, next, skipLayer=True, dynamic=True, name=None):
        super(HourglassBlock, self).__init__(dynamic=True)
        self.skipLayer = True
        self.upper = BasicBlock(ncIn, ncIn, batchNorm_type=1, name=f'upper_{name}') 

        #Need a better name scheme for the layers, low1,2 are TERRIBLE names.
        self.downSample = MaxPool2D(pool_size=(2, 2), strides=(2,2), name=f'downSample_{name}')
        self.low1 = BasicBlock(ncIn, ncOut, name=f'lower_{name}')
        self.next = next  #aka middle
    
    #not sure if these require output_shapes too
    def call(self, inputs):
        feat = inputs[0]
        L_t= inputs[1]
        count = inputs[2]
        skip_count= inputs[3]
        out_upper = self.upper([feat])   
        out_lower = self.downSample(feat)  
        out_lower = self.low1([out_lower])
        L_hat = self.next([out_lower, L_t, count + 1, skip_count])
        return [L_hat] 

class HourglassNet(tf.keras.Model):
    
    def __init__(self, gray = True):
        super(HourglassNet, self).__init__(dynamic=True)

        self.nrSH_in = 27 #number of input spherical harmonics coeff.
        self.baseFilter = 16
        self.nrSH_out = 9 if gray else 27 #nr output SH
        self.ncPre = self.baseFilter #This is the amount required for the pre-convolution step

        self.ncHG3 = self.baseFilter #this is the amount of output channels for the first and last hourglass block
        self.ncHG2 = self.baseFilter * 2
        self.ncHG1 = self.baseFilter * 4
        self.ncHG0 = self.baseFilter * 8 + self.nrSH_in # Bottleneck layer. 
      
        self.pre_conv = SeparableConv2D(self.ncPre, kernel_size=(5,5), strides=(1,1), padding="same", name="pre_conv")
        self.pre_bn = BatchNormalization(name="pre_bn")  
       

        self.light = LightingNet(self.nrSH_in, self.nrSH_out, 128, "LIGHT") 
        self.HG0 = HourglassBlock(self.ncHG1, self.ncHG0, self.light, name = "HG0") 
        self.HG1 = HourglassBlock(self.ncHG2, self.ncHG1, self.HG0, name = "HG1") 
        self.HG2 = HourglassBlock(self.ncHG3, self.ncHG2, self.HG1, name= "HG2") 
        self.HG3 = HourglassBlock(self.ncPre, self.ncHG3, self.HG2, name="HG3") 

    def compute_output_shape(self, input_shape):
        return [tf.TensorShape((1,1,1,9))] #Must somehow add L_hat to the dimensions

    def call(self, inputs):
        x = inputs[0]
        L_t= inputs[1]
        skip_count= 0
        feat = self.pre_conv(x)
        feat = self.pre_bn(feat)
        feat = Activation("relu")(feat)
        L_hat = self.HG3([feat, L_t, 0, skip_count])
        return [L_hat]
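
As a sanity check on the variable names in the warning, a single BasicBlock can be built in isolation and its variables listed. This is just a sketch; the 64 channels match the `upper` block of HG0 (ncHG1 = 64) and the input is a dummy batch:

import numpy as np
from minima_example import BasicBlock

# One BasicBlock on its own, called the same way the hourglass blocks call it
# (a single-element list as input).
bb = BasicBlock(64, 64, batchNorm_type=1, name='upper_HG0')
_ = bb([np.zeros((1, 64, 64, 64), dtype=np.float32)])
for v in bb.trainable_variables:
    print(v.name, v.shape)  # depthwise/pointwise kernels and BN gamma/beta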

Below is the training code:

import os
import sys
#os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
#tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
from tensorflow import keras
from tensorflow.keras import layers
#import matplotlib.pyplot as plt
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
#import PIL
import numpy as np 
import cv2
#from LightingNet import HourglassNet
from minima_example import HourglassNet
#import os
import pandas as pd
import functools
import re
from itertools import chain

class LightingNet:

    def __init__(self, checkpoint_file = None):
        self.generator_model = None
        self.gan_model = self.create_generator()
        self.df_train = pd.read_csv("train.lst", sep=" ")
        self.df_valid = pd.read_csv("val.lst", sep=" ")
        self.const_path = ""
        self.train_list_ptr = 0
    
    def create_generator(self):
        self.generator_model = HourglassNet() 
        in_src = Input(shape=(512,512,1))
        in_L_t = Input(shape=(1,1,9))
        in_L_gt = Input(shape=(1,1,9))
        L_hat  = self.generator_model([in_src, in_L_t])
        self.generator_model.build([(1,512,512,1), (1,1,1,9)])
        gan_model = Model([in_src, in_L_t, in_L_gt], [])
        opt = Adam()
        loss = tf.norm(tf.math.subtract(in_L_gt, L_hat), ord=2) 
        gan_model.add_loss(loss)
        gan_model.compile(optimizer=opt)
        return gan_model

    def train(self, n_epochs = 10, n_batch = 1, n_patch = 32):
        _, inp, _ = self.df_train
        epoch = 0
        while True:
            real_samples = self.generate_real_samples(n_batch, n_patch)
            [X_real_img, X_real_sh, X_gt_sh], y_real = real_samples
            X_real_img = np.concatenate(X_real_img)
            X_real_sh = np.concatenate(X_real_sh)
            X_gt_sh = np.concatenate(X_gt_sh)
            dict_loss = self.gan_model.train_on_batch([X_real_img, X_real_sh, X_gt_sh], [], return_dict= True)
            
            if self.train_list_ptr + n_batch >= self.df_train[inp].shape[0]:
               print("epoch completed")
               self.train_list_ptr = 0 #reset list ptr 
               epoch += 1 


    def generate_real_samples(self, n_samples, patch_shape):
        """
        train list looks like: 

        DIR_NAME INPUT_NAME.png TARGET_NAME.png
        DIR_NAME INPUT_NAME.png TARGET_NAME.png
        . . .
        """
        dir, inp, targ = self.df_train
        train_inp_list = self.df_train[inp].to_numpy()
        train_targ_list = self.df_train[targ].to_numpy()

        while(True):
            index = [i for i in range(self.train_list_ptr, self.train_list_ptr + n_samples)]
            assert len(index) == n_samples 
            img_samples = train_inp_list[index]
            sh_samples = train_targ_list[index] #index file from row ptr to ptr + n_samples
            self.train_list_ptr += n_samples #move ptr by n_samples 
            if all([os.path.isdir(f'{self.const_path}DPR_dataset/{re.search("(.+?)_", sample).group(1)}') for sample in img_samples ]):
                break
        
        img_list = [cv2.imread(f'{self.const_path}DPR_dataset/{re.search("(.+?)_", img).group(1)}/{img}') for img in img_samples]
        sh_gt_list = [np.loadtxt(f'{self.const_path}DPR_dataset/{re.search("(.+?)_", img).group(1)}/{re.search("(.+?)_", img).group(1)}_light_{re.search(".+?_([0-9]+)", img).group(1)}.txt') for img in img_samples]
        sh_list = [np.loadtxt(f'{self.const_path}DPR_dataset/{re.search("(.+?)_", img).group(1)}/{re.search("(.+?)_", img).group(1)}_light_{re.search(".+?_([0-9]+)", img).group(1)}.txt') for img in sh_samples]
        funcs = [(lambda x : x[0:9]), (lambda x : x * 0.7), (lambda x : np.reshape(x, (1,1,1,9))), (lambda x : x.astype(np.float32))] 
        sh_list = [functools.reduce((lambda x, y: y(x)), funcs, sh) for sh in sh_list]
        sh_gt_list = [functools.reduce((lambda x, y: y(x)), funcs, sh) for sh in sh_gt_list]
        img_list = list(map(self.inputPreprocessing, img_list))
        y = np.ones((n_samples, patch_shape, patch_shape, 1))

        return [img_list, sh_list, sh_gt_list ], y
        

    def inputPreprocessing(self, image):
        image = cv2.resize(image, (512, 512))
        LAB = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
        inputL = LAB[:, :, 0]
        inputL = inputL.astype(np.float32)/255.0
        inputL = inputL[None, ..., None]
        return inputL

if __name__ == '__main__':
    print("STARTING")
    g_model = LightingNet()
    g_model.train(1, 1, 32)

I have tested this code and it works. The following files should live in the folder structure DPR_dataset/imgHQ00000/: an image named "imgHQ00000_00.jpg", plus a txt file with the following contents:

1.084125496282453138e+00
-4.642676300617166185e-01
2.837846795150648915e-02
6.765292733937575687e-01
-3.594067725393816914e-01
4.790996460111427574e-02
-2.280054643781863066e-01
-8.125983081159608712e-02
2.881082012687687932e-01

The txt file should be named "imgHQ00000_light_00.txt".
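
For completeness, this is roughly how the `funcs` pipeline in generate_real_samples turns that txt file into the (1, 1, 1, 9) float32 array fed to the network (a sketch, using the file name described above):

import functools
import numpy as np

sh = np.loadtxt("DPR_dataset/imgHQ00000/imgHQ00000_light_00.txt")  # shape (9,)
funcs = [
    (lambda x: x[0:9]),                       # keep the first 9 coefficients
    (lambda x: x * 0.7),                      # same scaling as the training code
    (lambda x: np.reshape(x, (1, 1, 1, 9))),  # add batch/spatial dims
    (lambda x: x.astype(np.float32)),
]
sh = functools.reduce((lambda x, y: y(x)), funcs, sh)
print(sh.shape, sh.dtype)                     # (1, 1, 1, 9) float32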

A file named "train.lst" should also be located at the same level as the training code and the minima_example.py file, and it should contain the following lines (and only these lines):

imgHQ00000 imgHQ00000_00.jpg imgHQ00000_00.jpg
imgHQ00000 imgHQ00000_00.jpg imgHQ00000_00.jpg
imgHQ00000 imgHQ00000_00.jpg imgHQ00000_00.jpg
imgHQ00000 imgHQ00000_00.jpg imgHQ00000_00.jpg
imgHQ00000 imgHQ00000_00.jpg imgHQ00000_00.jpg
imgHQ00000 imgHQ00000_00.jpg imgHQ00000_00.jpg
imgHQ00000 imgHQ00000_00.jpg imgHQ00000_00.jpg
imgHQ00000 imgHQ00000_00.jpg imgHQ00000_00.jpg
imgHQ00000 imgHQ00000_00.jpg imgHQ00000_00.jpg
imgHQ00000 imgHQ00000_00.jpg imgHQ00000_00.jpg
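
And this is how the regexes in generate_real_samples resolve one row of train.lst into the paths above (a sketch, with const_path assumed empty):

import re

img = "imgHQ00000_00.jpg"
folder = re.search("(.+?)_", img).group(1)        # 'imgHQ00000'
index = re.search(".+?_([0-9]+)", img).group(1)   # '00'
print(f"DPR_dataset/{folder}/{img}")                        # image path
print(f"DPR_dataset/{folder}/{folder}_light_{index}.txt")   # SH file path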

Executing this:

    for var in self.generator_model.trainable_variables:
        print(var.name + ':                                   ' + str(tf.norm(var).numpy()))

gives the following:

pre_conv/depthwise_kernel:0:                                   0.9576778
pre_conv/pointwise_kernel:0:                                   1.464122
pre_conv/bias:0:                                   0.0009615466
pre_bn/gamma:0:                                   3.9996755
pre_bn/beta:0:                                   0.0038261134
hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/lighting_net/FC1_Lighting_None/depthwise_kernel:0:                                   1.4065807
hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/lighting_net/FC1_Lighting_None/pointwise_kernel:0:                                   6.7508173
hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/lighting_net/p_re_lu/alpha:0:                                   0.0076647564
hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/lighting_net/FC2_Lighting_None/depthwise_kernel:0:                                   1.3878155
hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/lighting_net/FC2_Lighting_None/pointwise_kernel:0:                                   4.050495
hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block/BB_conv_1_upper_HG0/depthwise_kernel:0:                                   1.4663663
hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block/BB_conv_1_upper_HG0/pointwise_kernel:0:                                   8.038067
hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block/BB_conv_2_upper_HG0/depthwise_kernel:0:                                   1.39229
hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block/BB_conv_2_upper_HG0/pointwise_kernel:0:                                   8.036673
hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block/BN_0_upper_HG0/gamma:0:                                   8.0
hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block/BN_0_upper_HG0/beta:0:                                   0.0
hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block_1/BB_conv_1_lower_HG0/depthwise_kernel:0:                                   1.3681538
hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block_1/BB_conv_1_lower_HG0/pointwise_kernel:0:                                   9.556786
hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block_1/BB_conv_2_lower_HG0/depthwise_kernel:0:                                   1.3960428
hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block_1/BB_conv_2_lower_HG0/pointwise_kernel:0:                                   12.413813
hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block_1/BN_0_lower_HG0/gamma:0:                                   12.4494915
hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block_1/BN_0_lower_HG0/beta:0:                                   0.010747912
hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_2/BB_conv_1_upper_HG1/depthwise_kernel:0:                                   1.3634498
hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_2/BB_conv_1_upper_HG1/pointwise_kernel:0:                                   5.6760592
hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_2/BB_conv_2_upper_HG1/depthwise_kernel:0:                                   1.3929691
hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_2/BB_conv_2_upper_HG1/pointwise_kernel:0:                                   5.6887717
hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_2/BN_0_upper_HG1/gamma:0:                                   5.656854
hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_2/BN_0_upper_HG1/beta:0:                                   0.0
hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_3/BB_conv_1_lower_HG1/depthwise_kernel:0:                                   1.3567474
hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_3/BB_conv_1_lower_HG1/pointwise_kernel:0:                                   6.561019
hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_3/BB_conv_2_lower_HG1/depthwise_kernel:0:                                   1.408531
hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_3/BB_conv_2_lower_HG1/pointwise_kernel:0:                                   8.045188
hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_3/BN_0_lower_HG1/gamma:0:                                   7.9996595
hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_3/BN_0_lower_HG1/beta:0:                                   0.007245279
hourglass_block_3/hourglass_block_2/basic_block_4/BB_conv_1_upper_HG2/depthwise_kernel:0:                                   1.3706092
hourglass_block_3/hourglass_block_2/basic_block_4/BB_conv_1_upper_HG2/pointwise_kernel:0:                                   3.8657312
hourglass_block_3/hourglass_block_2/basic_block_4/BB_conv_2_upper_HG2/depthwise_kernel:0:                                   1.2623239
hourglass_block_3/hourglass_block_2/basic_block_4/BB_conv_2_upper_HG2/pointwise_kernel:0:                                   3.9968336
hourglass_block_3/hourglass_block_2/basic_block_4/BN_0_upper_HG2/gamma:0:                                   4.0
hourglass_block_3/hourglass_block_2/basic_block_4/BN_0_upper_HG2/beta:0:                                   0.0
hourglass_block_3/hourglass_block_2/basic_block_5/BB_conv_1_lower_HG2/depthwise_kernel:0:                                   1.3684276
hourglass_block_3/hourglass_block_2/basic_block_5/BB_conv_1_lower_HG2/pointwise_kernel:0:                                   4.6307783
hourglass_block_3/hourglass_block_2/basic_block_5/BB_conv_2_lower_HG2/depthwise_kernel:0:                                   1.3033983
hourglass_block_3/hourglass_block_2/basic_block_5/BB_conv_2_lower_HG2/pointwise_kernel:0:                                   5.725787
hourglass_block_3/hourglass_block_2/basic_block_5/BN_0_lower_HG2/gamma:0:                                   5.656453
hourglass_block_3/hourglass_block_2/basic_block_5/BN_0_lower_HG2/beta:0:                                   0.005336759
hourglass_block_3/basic_block_6/BB_conv_1_upper_HG3/depthwise_kernel:0:                                   1.4134688
hourglass_block_3/basic_block_6/BB_conv_1_upper_HG3/pointwise_kernel:0:                                   3.8999567
hourglass_block_3/basic_block_6/BB_conv_2_upper_HG3/depthwise_kernel:0:                                   1.4348195
hourglass_block_3/basic_block_6/BB_conv_2_upper_HG3/pointwise_kernel:0:                                   4.0438123
hourglass_block_3/basic_block_6/BN_0_upper_HG3/gamma:0:                                   4.0
hourglass_block_3/basic_block_6/BN_0_upper_HG3/beta:0:                                   0.0
hourglass_block_3/basic_block_7/BB_conv_1_lower_HG3/depthwise_kernel:0:                                   1.4054896
hourglass_block_3/basic_block_7/BB_conv_1_lower_HG3/pointwise_kernel:0:                                   3.9783711
hourglass_block_3/basic_block_7/BB_conv_2_lower_HG3/depthwise_kernel:0:                                   1.3978381
hourglass_block_3/basic_block_7/BB_conv_2_lower_HG3/pointwise_kernel:0:                                   3.8882372
hourglass_block_3/basic_block_7/BN_0_lower_HG3/gamma:0:                                   4.0003786
hourglass_block_3/basic_block_7/BN_0_lower_HG3/beta:0:                                   0.0038626757
