包含必要的文件。将此添加到“我的驱动器”中。 https://drive.google.com/drive/folders/1epROVNfvO10Ksy8CwJQdamSK96JZnWW9?usp=sharing 谷歌colab链接,最简单的例子:https://colab.research.google.com/drive/18sMqNn8IpTQLZBlInWSbX0ITd2GWDDkz?usp=sharing
如果您愿意,这个基本块“模块”是更大网络的一部分。然而,这一切归结为这个块,因为这是执行卷积的地方(在本例中,是深度可分离卷积)。网络似乎能够进行训练,但是,在训练时(以及在所有时期),会发出警告:
WARNING:tensorflow:Gradients do not exist for variables ['hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block/BB_conv_1_upper_HG0/depthwise_kernel:0', 'hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block/BB_conv_1_upper_HG0/pointwise_kernel:0', 'hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block/BB_conv_2_upper_HG0/depthwise_kernel:0', 'hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block/BB_conv_2_upper_HG0/pointwise_kernel:0', 'hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block/BN_1_upper_HG0/gamma:0', 'hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block/BN_1_upper_HG0/beta:0', 'hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block/BN_12_upper_HG0/gamma:0', 'hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block/BN_12_upper_HG0/beta:0', 'hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_3/BB_conv_1_upper_HG1/depthwise_kernel:0', 'hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_3/BB_conv_1_upper_HG1/pointwise_kernel:0', 'hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_3/BB_conv_2_upper_HG1/depthwise_kernel:0', 'hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_3/BB_conv_2_upper_HG1/pointwise_kernel:0', 'hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_3/BN_1_upper_HG1/gamma:0', 'hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_3/BN_1_upper_HG1/beta:0', 'hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_3/BN_12_upper_HG1/gamma:0', 'hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_3/BN_12_upper_HG1/beta:0', 'batch_normalization/gamma:0', 'batch_normalization/beta:0', 'batch_normalization_1/gamma:0', 'batch_normalization_1/beta:0', 'batch_normalization_2/gamma:0', 'batch_normalization_2/beta:0', 
'batch_normalization_3/gamma:0', 'batch_normalization_3/beta:0'] when minimizing the loss.
它可以归结为这个基本的块子类模型/层,特别是可分离卷积层。我不知道它为什么抱怨。层/模型是从层次结构中更高的其他子类模型调用的 正如我所说,层/模型是从层次结构中更高的其他子类模型调用的。此子类模型/层存在许多实例。其初始化与此类似:
self.upper = BasicBlock(ncIn, ncIn, batchNorm_type=1, name=f'upper_{name}')
。在子类模型的init方法中
并像这样被调用:
out_upper = self.upper([feat])
在子类模型的调用方法中
我不明白这为什么会引起问题。如果您有任何想法,请随时提出
下面是一个显示该问题的简单示例。这称为minima_example.py
,包含完整的网络结构
import tensorflow as tf
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
import tensorflow_addons as tfa
import numpy as np
import sys
#https://github.com/tensorflow/addons/blob/master/docs/tutorials/layers_normalizations.ipynb
class LightingNet(tf.keras.Model):
    """Regresses spherical-harmonics (SH) lighting coefficients from a feature map.

    Sits at the bottleneck of the hourglass: global-average-pools the first
    `ncIn` channels of the incoming feature map, then applies two 1x1
    separable convolutions (acting as fully-connected layers on the pooled
    vector) to produce `ncOut` SH coefficients of shape (B, 1, 1, ncOut).
    """

    def __init__(self, ncIn, ncOut, ncMiddle, dynamic=True, name=None):
        super(LightingNet, self).__init__(dynamic=True)
        self.ncIn = ncIn          # number of feature channels consumed
        self.ncOut = ncOut        # number of SH coefficients produced
        self.ncMiddle = ncMiddle  # hidden width between FC1 and FC2
        # 1x1 separable convs on a (B,1,1,C) tensor behave like dense layers.
        self.FC1 = SeparableConv2D(self.ncMiddle, kernel_size=(1, 1), strides=(1, 1),
                                   use_bias=False, name=f'FC1_Lighting_{name}')
        self.relu1 = PReLU()
        self.FC2 = SeparableConv2D(self.ncOut, kernel_size=(1, 1), strides=(1, 1),
                                   use_bias=False, name=f'FC2_Lighting_{name}')

    def call(self, inputs):
        """inputs is [feat, L_t, count, skip_count]; only feat is used here.

        The original also unpacked L_t/count/skip_count and the spatial shape,
        but none of them were used — they have been dropped (callers still pass
        the full 4-element list, which remains valid).
        """
        feat = inputs[0]
        x = feat[:, :, :, 0:self.ncIn]  # keep only the first ncIn channels
        # Global average pool over the spatial dims -> (B, 1, 1, ncIn).
        f = tf.math.reduce_mean(x, axis=(1, 2), keepdims=True)
        L_hat = self.relu1(self.FC1(f))
        L_hat = self.FC2(L_hat)
        return L_hat
class BasicBlock(tf.keras.Model):
    """Two (separable conv 3x3 -> batch norm -> ReLU) stages with `ncOut` filters.

    BUG FIX: the original constructed a SINGLE BatchNormalization layer and
    applied it after BOTH convolutions, so the two stages shared one set of
    moving statistics and gamma/beta — almost certainly unintended (the full
    network's variable names, BN_1_... / BN_12_..., show two distinct BN layers
    per block). Each stage now owns its own BN layer.
    """

    def __init__(self, ncIn, ncOut, batchNorm_type=0, strides=(1, 1), downSample=None,
                 dynamic=True, name=None):
        super(BasicBlock, self).__init__(dynamic=True)
        self.ncIn = ncIn
        self.ncOut = ncOut
        self.conv_1 = SeparableConv2D(ncOut, kernel_size=(3, 3), strides=(1, 1),
                                      padding="same", use_bias=False,
                                      name=f'BB_conv_1_{name}')
        self.conv_2 = SeparableConv2D(ncOut, kernel_size=(3, 3), strides=(1, 1),
                                      padding="same", use_bias=False,
                                      name=f'BB_conv_2_{name}')
        # NOTE(review): both branches of batchNorm_type were identical in the
        # original; presumably type 1 was meant to select a different
        # normalization (tfa is imported at module level) — TODO confirm intent.
        if batchNorm_type == 0:
            self.bn_1 = BatchNormalization(name=f'BN_1_{name}')
            self.bn_2 = BatchNormalization(name=f'BN_2_{name}')
        else:
            self.bn_1 = BatchNormalization(name=f'BN_1_{name}')
            self.bn_2 = BatchNormalization(name=f'BN_2_{name}')
        self.relu = Activation('relu')

    def call(self, inputs):
        """inputs is a one-element list [x]; returns the activated feature map."""
        x = inputs[0]
        out = self.relu(self.bn_1(self.conv_1(x)))
        out = self.relu(self.bn_2(self.conv_2(out)))
        return out
class HourglassBlock(tf.keras.Model):
    """One level of the nested hourglass.

    Runs a BasicBlock on the full-resolution features (the "upper"/skip branch)
    and, in parallel, downsamples and runs a BasicBlock before delegating to
    `next` (the next-inner level, or the LightingNet at the bottleneck).
    """

    def __init__(self, ncIn, ncOut, next, skipLayer=True, dynamic=True, name=None):
        super(HourglassBlock, self).__init__(dynamic=True)
        # BUG FIX: the original hard-coded `self.skipLayer = True`, silently
        # ignoring the constructor argument. Default is True, so behavior at
        # existing call sites is unchanged.
        self.skipLayer = skipLayer
        self.upper = BasicBlock(ncIn, ncIn, batchNorm_type=1, name=f'upper_{name}')
        self.downSample = MaxPool2D(pool_size=(2, 2), strides=(2, 2),
                                    name=f'downSample_{name}')
        self.low1 = BasicBlock(ncIn, ncOut, name=f'lower_{name}')
        self.next = next  # inner block ("middle"): next hourglass level or LightingNet

    def call(self, inputs):
        feat = inputs[0]
        L_t = inputs[1]
        count = inputs[2]
        skip_count = inputs[3]
        # NOTE(review): out_upper is computed but never contributes to the
        # returned L_hat, so TensorFlow correctly warns "Gradients do not exist"
        # for every variable in self.upper (at every level). In the full DPR
        # network the upper branch is merged with the upsampled lower branch;
        # that merge was stripped from this minimal example. Either wire
        # out_upper into the output or drop self.upper to silence the warning.
        out_upper = self.upper([feat])
        out_lower = self.downSample(feat)
        out_lower = self.low1([out_lower])
        L_hat = self.next([out_lower, L_t, count + 1, skip_count])
        return [L_hat]
class HourglassNet(tf.keras.Model):
    """Nested hourglass network that maps a (B,512,512,1) image (plus target SH)
    to predicted spherical-harmonics lighting coefficients of shape (1,1,1,9)."""

    def __init__(self, gray=True):
        super(HourglassNet, self).__init__(dynamic=True)
        self.nrSH_in = 27                   # number of input spherical-harmonics coeff.
        self.baseFilter = 16
        self.nrSH_out = 9 if gray else 27   # number of output SH coefficients
        self.ncPre = self.baseFilter        # channels produced by the pre-convolution
        self.ncHG3 = self.baseFilter        # output channels of the outermost level
        self.ncHG2 = self.baseFilter * 2
        self.ncHG1 = self.baseFilter * 4
        self.ncHG0 = self.baseFilter * 8 + self.nrSH_in  # bottleneck channels
        self.pre_conv = SeparableConv2D(self.ncPre, kernel_size=(5, 5), strides=(1, 1),
                                        padding="same", name="pre_conv")
        self.pre_bn = BatchNormalization(name="pre_bn")
        # BUG FIX: the original called LightingNet(..., 128, "LIGHT"), which bound
        # "LIGHT" to the 4th positional parameter `dynamic` and left name=None —
        # visible as "FC1_Lighting_None" in the trainable-variable listing.
        self.light = LightingNet(self.nrSH_in, self.nrSH_out, 128, name="LIGHT")
        # Levels are nested innermost-first: HG0 wraps the lighting net, HG3 is outermost.
        self.HG0 = HourglassBlock(self.ncHG1, self.ncHG0, self.light, name="HG0")
        self.HG1 = HourglassBlock(self.ncHG2, self.ncHG1, self.HG0, name="HG1")
        self.HG2 = HourglassBlock(self.ncHG3, self.ncHG2, self.HG1, name="HG2")
        self.HG3 = HourglassBlock(self.ncPre, self.ncHG3, self.HG2, name="HG3")

    def compute_output_shape(self, input_shape):
        # Required because the model runs with dynamic=True (eager-only).
        return [tf.TensorShape((1, 1, 1, 9))]

    def call(self, inputs):
        x = inputs[0]
        L_t = inputs[1]
        skip_count = 0
        feat = self.pre_conv(x)
        feat = self.pre_bn(feat)
        feat = Activation("relu")(feat)
        L_hat = self.HG3([feat, L_t, 0, skip_count])
        return [L_hat]
以下是培训代码:
import os
import sys
#os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
#tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
from tensorflow import keras
from tensorflow.keras import layers
#import matplotlib.pyplot as plt
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
#import PIL
import numpy as np
import cv2
#from LightingNet import HourglassNet
from minima_example import HourglassNet
#import os
import pandas as pd
import functools
import re
from itertools import chain
class LightingNet:
    """Training harness: builds the HourglassNet generator, reads the train/val
    lists, and performs batched training with an L2 loss on SH coefficients.

    NOTE(review): this class shadows the LightingNet model class defined in
    minima_example.py; only HourglassNet is imported here so there is no clash,
    but renaming one of them would avoid confusion.
    """

    def __init__(self, checkpoint_file=None):
        self.generator_model = None
        self.gan_model = self.create_generator()
        self.df_train = pd.read_csv("train.lst", sep=" ")
        self.df_valid = pd.read_csv("val.lst", sep=" ")
        self.const_path = ""
        self.train_list_ptr = 0  # rolling index into the training list

    def create_generator(self):
        """Wrap HourglassNet in a Model whose loss (via add_loss) is the L2
        distance between ground-truth and predicted SH coefficients."""
        self.generator_model = HourglassNet()
        in_src = Input(shape=(512, 512, 1))
        in_L_t = Input(shape=(1, 1, 9))
        in_L_gt = Input(shape=(1, 1, 9))
        L_hat = self.generator_model([in_src, in_L_t])
        # NOTE(review): build() receives (1,1,1,9) for the SH input while the
        # Input layer above declares (1,1,9) plus the batch dim — confirm these
        # are meant to agree.
        self.generator_model.build([(1, 512, 512, 1), (1, 1, 1, 9)])
        gan_model = Model([in_src, in_L_t, in_L_gt], [])
        opt = Adam()
        loss = tf.norm(tf.math.subtract(in_L_gt, L_hat), ord=2)
        gan_model.add_loss(loss)
        gan_model.compile(optimizer=opt)
        return gan_model

    def train(self, n_epochs=10, n_batch=1, n_patch=32):
        # Iterating a DataFrame yields its column labels; `inp` is the
        # input-image column name.
        _, inp, _ = self.df_train
        epoch = 0
        # BUG FIX: the original looped `while True` and never honoured
        # n_epochs; stop after the requested number of passes over the list.
        while epoch < n_epochs:
            real_samples = self.generate_real_samples(n_batch, n_patch)
            [X_real_img, X_real_sh, X_gt_sh], y_real = real_samples
            X_real_img = np.concatenate(X_real_img)
            X_real_sh = np.concatenate(X_real_sh)
            X_gt_sh = np.concatenate(X_gt_sh)
            dict_loss = self.gan_model.train_on_batch(
                [X_real_img, X_real_sh, X_gt_sh], [], return_dict=True)
            if self.train_list_ptr + n_batch >= self.df_train[inp].shape[0]:
                print("epoch completed")
                self.train_list_ptr = 0  # reset list pointer for the next pass
                epoch += 1

    def generate_real_samples(self, n_samples, patch_shape):
        """Load the next n_samples (image, SH, ground-truth SH) triples.

        train list looks like:
        DIR_NAME INPUT_NAME.png TARGET_NAME.png
        DIR_NAME INPUT_NAME.png TARGET_NAME.png
        . . .
        """
        _dir_col, inp, targ = self.df_train  # column labels (renamed: `dir` shadowed a builtin)
        train_inp_list = self.df_train[inp].to_numpy()
        train_targ_list = self.df_train[targ].to_numpy()
        # Skip forward until every sample in the window has an existing directory.
        # NOTE(review): the pointer is not wrapped here, so running past the end
        # of the list would raise — train() resets it only after the batch.
        while True:
            index = list(range(self.train_list_ptr, self.train_list_ptr + n_samples))
            assert len(index) == n_samples
            img_samples = train_inp_list[index]
            sh_samples = train_targ_list[index]  # slice from row ptr to ptr + n_samples
            self.train_list_ptr += n_samples     # advance ptr by n_samples
            if all([os.path.isdir(f'{self.const_path}DPR_dataset/{re.search("(.+?)_", sample).group(1)}') for sample in img_samples ]):
                break
        img_list = [cv2.imread(f'{self.const_path}DPR_dataset/{re.search("(.+?)_", img).group(1)}/{img}') for img in img_samples]
        sh_gt_list = [np.loadtxt(f'{self.const_path}DPR_dataset/{re.search("(.+?)_", img).group(1)}/{re.search("(.+?)_", img).group(1)}_light_{re.search(".+?_([0-9]+)", img).group(1)}.txt') for img in img_samples]
        sh_list = [np.loadtxt(f'{self.const_path}DPR_dataset/{re.search("(.+?)_", img).group(1)}/{re.search("(.+?)_", img).group(1)}_light_{re.search(".+?_([0-9]+)", img).group(1)}.txt') for img in sh_samples]
        # Pipeline applied to every SH vector: keep first 9 coeffs, scale,
        # reshape to (1,1,1,9), cast to float32.
        funcs = [(lambda x: x[0:9]), (lambda x: x * 0.7),
                 (lambda x: np.reshape(x, (1, 1, 1, 9))),
                 (lambda x: x.astype(np.float32))]
        sh_list = [functools.reduce((lambda x, y: y(x)), funcs, sh) for sh in sh_list]
        sh_gt_list = [functools.reduce((lambda x, y: y(x)), funcs, sh) for sh in sh_gt_list]
        img_list = list(map(self.inputPreprocessing, img_list))
        y = np.ones((n_samples, patch_shape, patch_shape, 1))
        return [img_list, sh_list, sh_gt_list], y

    def inputPreprocessing(self, image):
        """Resize to 512x512, keep the L channel of LAB, scale to [0,1],
        and add batch + channel dims -> (1, 512, 512, 1) float32."""
        image = cv2.resize(image, (512, 512))
        LAB = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
        inputL = LAB[:, :, 0]
        inputL = inputL.astype(np.float32) / 255.0
        inputL = inputL[None, ..., None]
        return inputL
if __name__ == '__main__':
    # Entry point: build the trainer and run a short training session.
    print("STARTING")
    trainer = LightingNet()
    trainer.train(1, 1, 32)
我已经测试过这段代码,它可以运行。名为“imgHQ00000_00.jpg”的图片文件应放在文件夹结构 DPR_dataset/imgHQ00000/ 下,
同一文件夹中还应有一个内容如下的 txt 文件:
1.084125496282453138e+00
-4.642676300617166185e-01
2.837846795150648915e-02
6.765292733937575687e-01
-3.594067725393816914e-01
4.790996460111427574e-02
-2.280054643781863066e-01
-8.125983081159608712e-02
2.881082012687687932e-01
应该命名为“imgHQ00000_light_00.txt”
名为“train.lst”的文件也应与培训代码和minima_example.py文件位于同一级别上,并且应包含以下行(并且仅包含这些行):
imgHQ00000 imgHQ00000_00.jpg imgHQ00000_00.jpg
imgHQ00000 imgHQ00000_00.jpg imgHQ00000_00.jpg
imgHQ00000 imgHQ00000_00.jpg imgHQ00000_00.jpg
imgHQ00000 imgHQ00000_00.jpg imgHQ00000_00.jpg
imgHQ00000 imgHQ00000_00.jpg imgHQ00000_00.jpg
imgHQ00000 imgHQ00000_00.jpg imgHQ00000_00.jpg
imgHQ00000 imgHQ00000_00.jpg imgHQ00000_00.jpg
imgHQ00000 imgHQ00000_00.jpg imgHQ00000_00.jpg
imgHQ00000 imgHQ00000_00.jpg imgHQ00000_00.jpg
imgHQ00000 imgHQ00000_00.jpg imgHQ00000_00.jpg
执行此操作:
for i, var in enumerate(self.generator_model.trainable_variables):
print(self.generator_model.trainable_variables[i].name + ': ' + str(tf.norm(self.generator_model.trainable_variables[i]).numpy()))
提供以下信息:
pre_conv/depthwise_kernel:0: 0.9576778
pre_conv/pointwise_kernel:0: 1.464122
pre_conv/bias:0: 0.0009615466
pre_bn/gamma:0: 3.9996755
pre_bn/beta:0: 0.0038261134
hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/lighting_net/FC1_Lighting_None/depthwise_kernel:0: 1.4065807
hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/lighting_net/FC1_Lighting_None/pointwise_kernel:0: 6.7508173
hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/lighting_net/p_re_lu/alpha:0: 0.0076647564
hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/lighting_net/FC2_Lighting_None/depthwise_kernel:0: 1.3878155
hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/lighting_net/FC2_Lighting_None/pointwise_kernel:0: 4.050495
hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block/BB_conv_1_upper_HG0/depthwise_kernel:0: 1.4663663
hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block/BB_conv_1_upper_HG0/pointwise_kernel:0: 8.038067
hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block/BB_conv_2_upper_HG0/depthwise_kernel:0: 1.39229
hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block/BB_conv_2_upper_HG0/pointwise_kernel:0: 8.036673
hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block/BN_0_upper_HG0/gamma:0: 8.0
hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block/BN_0_upper_HG0/beta:0: 0.0
hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block_1/BB_conv_1_lower_HG0/depthwise_kernel:0: 1.3681538
hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block_1/BB_conv_1_lower_HG0/pointwise_kernel:0: 9.556786
hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block_1/BB_conv_2_lower_HG0/depthwise_kernel:0: 1.3960428
hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block_1/BB_conv_2_lower_HG0/pointwise_kernel:0: 12.413813
hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block_1/BN_0_lower_HG0/gamma:0: 12.4494915
hourglass_block_3/hourglass_block_2/hourglass_block_1/hourglass_block/basic_block_1/BN_0_lower_HG0/beta:0: 0.010747912
hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_2/BB_conv_1_upper_HG1/depthwise_kernel:0: 1.3634498
hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_2/BB_conv_1_upper_HG1/pointwise_kernel:0: 5.6760592
hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_2/BB_conv_2_upper_HG1/depthwise_kernel:0: 1.3929691
hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_2/BB_conv_2_upper_HG1/pointwise_kernel:0: 5.6887717
hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_2/BN_0_upper_HG1/gamma:0: 5.656854
hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_2/BN_0_upper_HG1/beta:0: 0.0
hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_3/BB_conv_1_lower_HG1/depthwise_kernel:0: 1.3567474
hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_3/BB_conv_1_lower_HG1/pointwise_kernel:0: 6.561019
hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_3/BB_conv_2_lower_HG1/depthwise_kernel:0: 1.408531
hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_3/BB_conv_2_lower_HG1/pointwise_kernel:0: 8.045188
hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_3/BN_0_lower_HG1/gamma:0: 7.9996595
hourglass_block_3/hourglass_block_2/hourglass_block_1/basic_block_3/BN_0_lower_HG1/beta:0: 0.007245279
hourglass_block_3/hourglass_block_2/basic_block_4/BB_conv_1_upper_HG2/depthwise_kernel:0: 1.3706092
hourglass_block_3/hourglass_block_2/basic_block_4/BB_conv_1_upper_HG2/pointwise_kernel:0: 3.8657312
hourglass_block_3/hourglass_block_2/basic_block_4/BB_conv_2_upper_HG2/depthwise_kernel:0: 1.2623239
hourglass_block_3/hourglass_block_2/basic_block_4/BB_conv_2_upper_HG2/pointwise_kernel:0: 3.9968336
hourglass_block_3/hourglass_block_2/basic_block_4/BN_0_upper_HG2/gamma:0: 4.0
hourglass_block_3/hourglass_block_2/basic_block_4/BN_0_upper_HG2/beta:0: 0.0
hourglass_block_3/hourglass_block_2/basic_block_5/BB_conv_1_lower_HG2/depthwise_kernel:0: 1.3684276
hourglass_block_3/hourglass_block_2/basic_block_5/BB_conv_1_lower_HG2/pointwise_kernel:0: 4.6307783
hourglass_block_3/hourglass_block_2/basic_block_5/BB_conv_2_lower_HG2/depthwise_kernel:0: 1.3033983
hourglass_block_3/hourglass_block_2/basic_block_5/BB_conv_2_lower_HG2/pointwise_kernel:0: 5.725787
hourglass_block_3/hourglass_block_2/basic_block_5/BN_0_lower_HG2/gamma:0: 5.656453
hourglass_block_3/hourglass_block_2/basic_block_5/BN_0_lower_HG2/beta:0: 0.005336759
hourglass_block_3/basic_block_6/BB_conv_1_upper_HG3/depthwise_kernel:0: 1.4134688
hourglass_block_3/basic_block_6/BB_conv_1_upper_HG3/pointwise_kernel:0: 3.8999567
hourglass_block_3/basic_block_6/BB_conv_2_upper_HG3/depthwise_kernel:0: 1.4348195
hourglass_block_3/basic_block_6/BB_conv_2_upper_HG3/pointwise_kernel:0: 4.0438123
hourglass_block_3/basic_block_6/BN_0_upper_HG3/gamma:0: 4.0
hourglass_block_3/basic_block_6/BN_0_upper_HG3/beta:0: 0.0
hourglass_block_3/basic_block_7/BB_conv_1_lower_HG3/depthwise_kernel:0: 1.4054896
hourglass_block_3/basic_block_7/BB_conv_1_lower_HG3/pointwise_kernel:0: 3.9783711
hourglass_block_3/basic_block_7/BB_conv_2_lower_HG3/depthwise_kernel:0: 1.3978381
hourglass_block_3/basic_block_7/BB_conv_2_lower_HG3/pointwise_kernel:0: 3.8882372
hourglass_block_3/basic_block_7/BN_0_lower_HG3/gamma:0: 4.0003786
hourglass_block_3/basic_block_7/BN_0_lower_HG3/beta:0: 0.0038626757
通过重新构建网络并将所有层依次放置而不是一个模型的多个实例来解决此问题。所以从开始到结束的一切都在一个子类模型中
相关问题 更多 >
编程相关推荐