TensorFlow: weights and biases in a GAN model do not update

Posted 2024-09-25 00:28:25


I built this GAN to recover missing data in a WSN (wireless sensor network). For some reason, the weights and biases in this network never update: the generator and discriminator losses stay constant from the first iteration on (I sketch the check I use right after the code). I have tried many things to fix this, and none of them worked. Here is my code:

import tensorflow as tf
import time
import numpy as np
import os
import pandas as pd
import csv
import matplotlib.pyplot as plt
from tensorflow.python.ops import variables
from tensorflow.python.framework import ops
from sklearn.preprocessing import MinMaxScaler


def xavier_init(size):
    in_dim = size[0]
    xavier_stddev = 1. / tf.sqrt(in_dim / 2.) 
    return tf.random_normal(shape=size, stddev=xavier_stddev)


# load the data
timeline = pd.read_csv('./all_final.csv', usecols=[0], engine='python')
dataframe = pd.read_csv('./all_final.csv', usecols=[1], engine='python')
dataType = pd.read_csv('./all_final.csv', usecols=[2], engine='python')
dataset = dataframe.values
datasample = dataType.values
# cast integers to float
dataset = dataset.astype('float32')
datasample = datasample.astype('float32')
# normalize the data
scaler = MinMaxScaler(feature_range=(0, 1))
dataset = scaler.fit_transform(dataset)


dataset = dataset.reshape([2000, 100])
datasample = datasample.reshape([2000, 100])

trainG_size = int(len(dataset) * 0.4)
trainG_list = dataset[:trainG_size]
trainG_mask = datasample[:trainG_size]
trainD_list = dataset[trainG_size:]

batch_size = 25
batch_size_D = 25


x = tf.placeholder('float', [None, 100])
y_ = tf.placeholder('float', [None, 100])
x_real = tf.placeholder('float', [None, 100])

x_image = tf.reshape(x, [-1, 1, 100, 1])
x_r = tf.reshape(x_real, [-1, 1, 100, 1])


y_reverse = tf.subtract(tf.ones([batch_size, 100], dtype=tf.float32), y_)
y_reverse = tf.reshape(y_reverse, [-1, 1, 100, 1])

y_1 = tf.reshape(y_, [-1, 1, 100, 1])

# Generator
filter1 = tf.Variable(xavier_init([1, 5, 1, 4]))
bias1 = tf.Variable(xavier_init([4]))
conv1 = tf.nn.conv2d(x_image, filter1, strides=[1, 1, 1, 1], padding='SAME')
h_conv1 = tf.nn.bias_add(conv1, bias1)

dilate_filter1 = tf.Variable(xavier_init([1, 3, 4, 4]))
dilate_bias1 = tf.Variable(xavier_init([4]))
dilate_conv1 = tf.nn.atrous_conv2d(h_conv1, dilate_filter1, 4, padding='SAME')
dilate_h_conv1 = tf.nn.bias_add(dilate_conv1, dilate_bias1)
dilate_h_conv1 = tf.nn.leaky_relu(dilate_h_conv1)


# resize back
filter11 = tf.Variable(xavier_init([1, 3, 4, 1]))
bias11 = tf.Variable(xavier_init([1]))
conv11 = tf.nn.conv2d(dilate_h_conv1, filter11, strides=[1, 1, 1, 1], padding='SAME')
h_conv11 = tf.nn.bias_add(conv11, bias11)
h_conv11 = tf.nn.leaky_relu(h_conv11)
h_conv11_f = tf.reshape(h_conv11, [batch_size, 100])

x_fake1 = tf.multiply(h_conv11, y_1) + tf.multiply(x_image, y_reverse)
x_fake = tf.reshape(x_fake1, [-1, 1, 100, 1])



theta_G = [filter1, bias1, dilate_filter1, dilate_bias1, filter11, bias11]

# Discriminator
dfilter1 = tf.Variable(xavier_init([1, 5, 1, 2]))
dbias1 = tf.Variable(xavier_init([2]))

dfilter2 = tf.Variable(xavier_init([1, 5, 2, 4]))
dbias2 = tf.Variable(xavier_init([4]))

linearw = tf.Variable(xavier_init([88, 1]))
linearb = tf.Variable(xavier_init([1]))


def discriminator(xx):
    dconv1 = tf.nn.conv2d(xx, dfilter1, strides=[1, 1, 2, 1], padding='VALID')
    dh_conv1 = tf.nn.bias_add(dconv1, dbias1)
    dh_conv1 = tf.nn.leaky_relu(dh_conv1)
    dconv2 = tf.nn.conv2d(dh_conv1, dfilter2, strides=[1, 1, 2, 1], padding='VALID')
    dh_conv2 = tf.nn.bias_add(dconv2, dbias2)
    dh_conv2 = tf.nn.leaky_relu(dh_conv2)
    flatten = tf.contrib.layers.flatten(dh_conv2)
    output1 = tf.matmul(flatten, linearw) + linearb
    output1 = tf.nn.sigmoid(output1)

    return output1

output = discriminator(x_r)
output_fake = discriminator(x_fake1)

theta_D = [dfilter1, dbias1, dfilter2, dbias2, linearw, linearb]

D_loss_real = tf.reduce_mean(tf.keras.backend.binary_crossentropy(output=output, target=tf.ones_like(output)))
D_loss_fake = tf.reduce_mean(tf.keras.backend.binary_crossentropy(output=output_fake, target=tf.zeros_like(output_fake)))
D_loss = D_loss_real + D_loss_fake
G_loss = tf.reduce_mean(tf.keras.backend.binary_crossentropy(output=output_fake, target=tf.ones_like(output_fake)))

train_step_D = tf.train.AdamOptimizer(0.0002, 0.5).minimize(D_loss)
train_step_G = tf.train.AdamOptimizer(0.0002, 0.5).minimize(G_loss)


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    e = []
    f = []

    # get all variables
    var_list = (variables.trainable_variables() + ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
    print('variables')
    for v in var_list:
        print('  ', v.name)

    for it in range(100):
        start_time = time.time()

        rnd_indices = np.random.randint(0, len(trainG_list), batch_size)
        x_train = trainG_list[rnd_indices]
        y_train = trainG_mask[rnd_indices]

        rnd_indices_D = np.random.randint(0, len(trainD_list), batch_size_D)
        x_r2 = trainD_list[rnd_indices_D]

        Gloss, Dloss, gdata, weight = sess.run([G_loss, D_loss, x_fake1, linearb],
                                               feed_dict={x: x_train, y_: y_train, x_real: x_r2})
        Gloss, Dloss, gdata, weight = sess.run([G_loss, D_loss, x_fake1, linearb],
                                               feed_dict={x: x_train, y_: y_train, x_real: x_r2})
        e.append(it + 1)
        f.append(Gloss)
        print('Iteration:', it)
        print('G_loss', Gloss)
        print('D_loss', Dloss)
        end_time = time.time()
        print('time: ', (end_time - start_time))
        start_time = end_time
    generated, real, fake = sess.run([x_fake1, output, output_fake], feed_dict={x: trainG_list[0:25, :], y_: trainG_mask[0:25, :], x_real: trainD_list})
    generated = generated.reshape([25, 100])
    generated = scaler.inverse_transform(generated)
    print(generated[0])
    print(real)
    print(fake)
    plt.plot(e, f)
    plt.show()
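
To make the "weights do not update" claim concrete, this is roughly the check I run inside the same `with tf.Session()` block, reusing `sess`, `filter1` and the feed values from the loop above (a minimal sketch; the print statements are not part of the snippet above):

# snapshot one of the generator weights, run a few more iterations, compare
w_before = sess.run(filter1)
for _ in range(10):
    sess.run([G_loss, D_loss], feed_dict={x: x_train, y_: y_train, x_real: x_r2})
w_after = sess.run(filter1)
# this prints 0.0 for me, i.e. the weights are bit-for-bit identical
print('max abs change in filter1:', np.abs(w_after - w_before).max())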

Here is part of the dataset I am using:

Unnamed: 0,sensor_01,mask
0,47.09201,0
1,47.09201,0
2,47.35243,0
3,47.09201,0
4,48.22405178,1
5,47.09201,0
6,48.22405178,1
7,47.13541,0
8,48.22405178,1
9,47.17882,0
10,47.48264,0
11,47.91666,0
12,48.22405178,1
13,48.4375,0
14,48.56771,0
15,48.3941,0
16,48.3941,0
17,48.4809,0
18,48.61111,0
19,48.22405178,1
20,49.08854,0
21,49.21875,0
22,48.78472,0
23,49.08854,0
24,48.22405178,1
25,48.22405178,1
26,49.08854,0
27,48.22405178,1
28,48.35069,0
29,48.22405178,1
30,48.22405178,1
31,47.78646,0
32,48.22405178,1
33,48.4375,0
34,48.22405178,1
35,48.22405178,1
36,48.22405178,1
37,48.69791,0
38,48.65451,0
39,48.82813,0
40,49.26215,0
41,49.26215,0
42,49.04514,0
43,48.22405178,1
44,48.87152,0
45,49.08854,0
46,49.13194,0
47,48.22405178,1
48,49.26215,0
49,49.04514,0
50,49.43576,0
51,49.34896,0
52,48.22405178,1
53,49.30555,0
54,49.30555,0
55,49.26215,0
56,49.26215,0
57,49.43576,0
58,49.21875,0
59,49.30555,0
60,49.08854,0
61,49.08854,0
62,49.04514,0
63,48.22405178,1
64,47.78646,0
65,48.22405178,1
66,47.09201,0
67,47.74305,0
68,46.48438,0
69,48.22405178,1
70,48.04688,0
71,48.22405178,1
72,48.4375,0
73,47.78646,0
74,48.22405178,1
75,47.69965,0
76,47.56944,0
77,47.35243,0
78,47.48264,0
79,48.22405178,1
80,47.30902,0
81,48.22405178,1
82,47.26563,0
83,47.56944,0
84,47.26563,0
85,47.56944,0
86,48.22405178,1
87,47.30902,0
88,47.30902,0
89,47.35243,0
90,47.26563,0
91,47.26563,0
92,47.48264,0
93,47.69965,0
94,47.78646,0
95,47.91666,0
96,47.91666,0
97,48.22405178,1
98,48.17708,0
99,48.22405178,1
100,48.26389,0
101,48.35069,0
102,48.4375,0
103,48.35069,0
104,48.3941,0
105,48.90601027,1
106,48.61111,0
107,48.78472,0
108,48.82813,0
109,48.91493,0
110,48.90601027,1
111,48.82813,0
112,48.78472,0
113,48.69791,0
114,48.90601027,1
115,48.90601027,1
116,48.90601027,1
117,48.65451,0
118,48.4809,0
119,48.90601027,1
120,48.78472,0
121,48.78472,0
122,48.90601027,1
123,48.65451,0
124,48.69791,0
125,48.90601027,1
126,48.78472,0
127,48.90601027,1
128,48.91493,0
129,48.90601027,1
130,48.91493,0
131,48.91493,0
132,48.87152,0
133,48.82813,0
134,48.87152,0
135,48.87152,0
136,48.90601027,1
137,49.08854,0
138,48.90601027,1
139,49.08854,0
140,49.13194,0
141,49.13194,0
142,49.21875,0
143,49.26215,0
144,49.34896,0
145,48.90601027,1
146,49.34896,0
147,48.90601027,1
148,49.43576,0
149,49.43576,0
150,49.43576,0
151,48.90601027,1
152,49.56597,0
153,49.47916,0
154,49.52257,0
155,49.47916,0
156,49.47916,0
157,49.47916,0
158,49.52257,0
159,49.52257,0
160,49.52257,0
161,48.90601027,1
162,49.43576,0
163,49.26215,0
164,48.90601027,1
165,48.26389,0
166,47.69965,0
167,47.48264,0
168,48.90601027,1
169,48.90601027,1
170,48.69791,0
171,48.78472,0
172,48.87152,0
173,48.91493,0
174,48.61111,0
175,48.56771,0
176,48.87152,0
177,48.90601027,1
178,48.87152,0
179,48.90601027,1
180,49.13194,0
181,48.90601027,1
182,49.08854,0
183,49.00174,0
184,49.00174,0
185,48.90601027,1
186,48.91493,0
187,48.90601027,1
188,48.3941,0
189,48.26389,0
190,48.90601027,1
191,48.90601027,1
192,48.91493,0
193,48.87152,0
194,48.90601027,1
195,48.87152,0
196,48.91493,0
197,48.82813,0
198,48.82813,0
199,49.00174,0
200,48.37831303,1
201,49.04514,0
202,48.69791,0
203,48.56771,0
204,48.37831303,1
205,48.4375,0
206,48.37831303,1
207,48.61111,0
208,48.37831303,1
209,48.78472,0
210,48.37831303,1
211,48.65451,0
212,48.82813,0
213,48.37831303,1
214,48.61111,0
215,48.3941,0
216,48.13368,0
217,47.69965,0
218,47.48264,0
219,47.39583,0
220,48.37831303,1
221,48.22049,0
222,48.37831303,1
223,48.37831303,1
224,48.61111,0
225,48.65451,0
226,48.37831303,1
227,48.61111,0
228,48.56771,0
229,48.78472,0
230,48.37831303,1
231,48.87152,0
232,48.69791,0
233,48.37831303,1
234,48.37831303,1
235,48.82813,0
236,48.69791,0
237,48.37831303,1
238,48.37831303,1
239,48.87152,0
240,48.78472,0
241,48.37831303,1
242,48.82813,0
243,48.78472,0
244,48.37831303,1
245,48.78472,0
246,48.4375,0
247,48.56771,0
248,48.61111,0
249,48.78472,0
250,48.69791,0
251,48.91493,0
252,48.61111,0
253,48.4809,0
254,48.61111,0
255,48.37831303,1
256,48.4809,0
257,48.56771,0
258,48.37831303,1
259,48.87152,0
260,48.37831303,1
261,48.37831303,1
262,48.82813,0
263,48.69791,0
264,48.69791,0
265,48.37831303,1
266,48.69791,0
267,48.78472,0
268,48.37831303,1
269,48.65451,0
270,48.37831303,1
271,47.96007,0
272,47.69965,0
273,47.74305,0
274,47.26563,0
275,47.04861,0
276,46.8316,0
277,48.37831303,1
278,46.70139,0
279,46.74479,0
280,48.37831303,1
281,47.69965,0
282,48.37831303,1
283,48.04688,0
284,48.17708,0
285,48.17708,0
286,48.37831303,1
287,48.37831303,1
288,48.37831303,1
289,48.26389,0
290,48.26389,0
291,48.37831303,1
292,48.4375,0
293,48.4375,0
294,48.4809,0
295,48.56771,0
296,48.37831303,1
297,48.4809,0
298,48.37831303,1
299,48.37831303,1
300,48.35069,0
301,48.35069,0
302,48.17708,0
303,48.04688,0
304,47.78646,0
305,47.2357845,1
306,47.82986,0
307,48.13368,0
308,48.22049,0
309,48.26389,0
310,48.26389,0
311,48.35069,0
312,48.17708,0
313,47.82986,0
314,47.2357845,1
315,46.65799,0
316,46.70139,0
317,46.65799,0
318,47.2357845,1
319,46.65799,0
320,47.2357845,1
321,46.70139,0
322,46.52777,0
323,47.2357845,1
324,46.65799,0
325,46.44097,0
326,46.31076,0
327,47.13541,0
328,47.2357845,1
329,47.39583,0
330,47.56944,0
331,47.2357845,1
332,47.48264,0
333,47.2357845,1
334,47.82986,0
335,47.74305,0
336,47.96007,0
337,47.2357845,1
338,47.61285,0
339,47.35243,0
340,47.52604,0
341,47.09201,0
342,47.2357845,1
343,46.875,0
344,47.2357845,1
345,47.2357845,1
346,46.26736,0
347,47.09201,0
348,46.44097,0
349,46.22396,0
350,47.2357845,1
351,46.44097,0
352,46.44097,0
353,46.31076,0
354,46.39757,0
355,46.39757,0
356,46.44097,0
357,46.26736,0
358,47.48264,0
359,47.26563,0
360,46.74479,0
361,46.74479,0
362,46.39757,0
363,46.8316,0
364,47.2357845,1
365,47.52604,0
366,47.52604,0
367,47.2357845,1
368,47.2357845,1
369,47.74305,0
370,47.56944,0
371,47.56944,0
372,47.48264,0
373,47.2357845,1
374,47.61285,0
375,47.61285,0
376,47.61285,0
377,47.78646,0
378,47.61285,0
379,47.52604,0
380,47.69965,0
381,47.52604,0
382,47.13541,0
383,47.04861,0
384,46.70139,0
385,46.26736,0
386,46.18055,0
387,47.61285,0
388,47.39583,0
389,47.74305,0
390,46.74479,0
391,47.52604,0
392,46.31076,0
393,47.2357845,1
394,46.9618,0
395,47.35243,0
396,47.13541,0
397,47.2357845,1
398,47.48264,0
399,47.2357845,1
400,46.70139,0
401,46.90056479,1
402,46.52777,0
403,46.875,0
404,46.61458,0
405,47.39583,0
406,46.90056479,1
407,46.90056479,1
408,46.90056479,1
409,46.90056479,1
410,47.04861,0
411,47.13541,0
412,47.04861,0
413,47.17882,0
414,47.17882,0
415,47.26563,0
416,47.04861,0
417,46.90056479,1
418,47.13541,0
419,47.30902,0
420,47.61285,0
421,46.90056479,1
422,47.17882,0
423,47.56944,0
424,46.90056479,1
425,47.26563,0
426,47.26563,0
427,47.17882,0
428,47.17882,0
429,47.17882,0
430,47.30902,0
431,46.90056479,1
432,47.13541,0
433,46.90056479,1
434,47.17882,0
435,47.26563,0
436,47.35243,0
437,46.90056479,1
438,47.09201,0
439,46.90056479,1
440,47.13541,0
441,47.04861,0
442,46.90056479,1
443,47.26563,0
444,46.90056479,1
445,47.35243,0
446,47.30902,0
447,46.90056479,1
448,47.26563,0
449,46.90056479,1
450,47.09201,0
451,46.70139,0
452,46.65799,0
453,46.65799,0
454,46.90056479,1
455,46.44097,0
456,46.90056479,1
457,46.8316,0
458,46.90056479,1
459,46.52777,0
460,46.61458,0
461,46.65799,0
462,46.61458,0
463,46.61458,0
464,46.90056479,1
465,46.70139,0
466,46.8316,0
467,46.8316,0
468,46.90056479,1
469,46.90056479,1
470,47.04861,0
471,46.9618,0
472,46.875,0
473,46.875,0
474,46.9618,0
475,46.875,0
476,46.70139,0
477,46.9184,0
478,46.90056479,1
479,47.09201,0
480,47.17882,0
481,47.17882,0
482,47.04861,0
483,46.90056479,1
484,46.875,0
485,46.90056479,1
486,46.9618,0
487,46.26736,0
488,46.90056479,1
489,46.26736,0
490,46.22396,0
491,46.31076,0
492,46.26736,0
493,46.31076,0
494,46.18055,0
495,46.90056479,1
496,46.26736,0
497,46.31076,0
498,46.18055,0
499,46.22396,0

The mask indicates whether a value needs to be recovered: "1" means the value is missing and needs to be recovered.
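
The generator is only supposed to fill in the positions where the mask is 1; the blending done by `x_fake1` above behaves like this small NumPy sketch (the numbers here are made-up toy values, not taken from the real dataset):

import numpy as np

mask = np.array([0., 1., 0., 1.], dtype=np.float32)                   # 1 = missing, must be recovered
observed = np.array([47.09, 48.22, 47.35, 48.22], dtype=np.float32)   # raw window (missing slots hold a filler value)
generated = np.array([47.20, 47.31, 47.52, 47.63], dtype=np.float32)  # generator output for the same window

# same idea as x_fake1 = G(x) * mask + x * (1 - mask)
recovered = generated * mask + observed * (1. - mask)
print(recovered)  # -> [47.09 47.31 47.35 47.63]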


I have already printed the model's gradients and there is no vanishing gradient, which makes this even stranger. Thanks a lot.
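
For reference, the gradient check I mention was along these lines (a minimal sketch using `tf.gradients` on the losses and variable lists defined above; it has to run inside the same session, with the same feed values as the training loop):

g_grads = tf.gradients(G_loss, theta_G)
d_grads = tf.gradients(D_loss, theta_D)
gg, dg = sess.run([g_grads, d_grads],
                  feed_dict={x: x_train, y_: y_train, x_real: x_r2})
for v, g in zip(theta_G + theta_D, gg + dg):
    # in my runs none of these norms come out as zero or NaN
    print(v.name, 'grad norm:', np.linalg.norm(g))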


Tags: import, output, size, time, init, tf, nn, variable