我想写一个神经网络,寻找一个x^2分布没有一个预定义的模型。准确地说,在[-1,1]中给它一些点,用它们的平方来训练,然后它就必须复制和预测类似的值,例如[-10,10]。 我或多或少做过——没有数据集。但后来我试图修改它以使用数据集并学习如何使用它。现在,我成功地使程序运行,但是输出比以前更差,主要是常数0。你知道吗
以前的版本类似于[-1,1]中的x^2,具有线性延长,这更好。。Previous output 现在蓝线是平的。我们的目标是和红色的一致。。你知道吗
这里的评论都是波兰语的,很抱歉。你知道吗
# square2.py - drugie podejscie do trenowania sieci za pomocą Tensorflow
# cel: nauczyć sieć rozpoznawać rozkład x**2
# analiza skryptu z:
# https://stackoverflow.com/questions/43140591/neural-network-to-predict-nth-square
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.python.framework.ops import reset_default_graph
# def. danych do trenowania sieci
# x_train = (np.random.rand(10**3)*4-2).reshape(-1,1)
# y_train = x_train**2
square2_dane = np.load("square2_dane.npz")
x_train = square2_dane['x_tren'].reshape(-1,1)
y_train = square2_dane['y_tren'].reshape(-1,1)
# zoptymalizować dzielenie danych
# x_train = square2_dane['x_tren'].reshape(-1,1)
# ds_x = tf.data.Dataset.from_tensor_slices(x_train)
# batch_x = ds_x.batch(rozm_paczki)
# iterator = ds_x.make_one_shot_iterator()
# określenie parametrów sieci
wymiary = [50,50,50,1]
epoki = 500
rozm_paczki = 200
reset_default_graph()
X = tf.placeholder(tf.float32, shape=[None,1])
Y = tf.placeholder(tf.float32, shape=[None,1])
weights = []
biases = []
n_inputs = 1
# inicjalizacja zmiennych
for i,n_outputs in enumerate(wymiary):
with tf.variable_scope("layer_{}".format(i)):
w = tf.get_variable(name="W", shape=[n_inputs,n_outputs],initializer = tf.random_normal_initializer(mean=0.0,stddev=0.02,seed=42))
b=tf.get_variable(name="b",shape=[n_outputs],initializer=tf.zeros_initializer)
weights.append(w)
biases.append(b)
n_inputs=n_outputs
def forward_pass(X,weights,biases):
h=X
for i in range(len(weights)):
h=tf.add(tf.matmul(h,weights[i]),biases[i])
h=tf.nn.relu(h)
return h
output_layer = forward_pass(X,weights,biases)
f_strat = tf.reduce_mean(tf.squared_difference(output_layer,Y),1)
f_strat = tf.reduce_sum(f_strat)
# alternatywna funkcja straty
#f_strat2 = tf.reduce_sum(tf.abs(Y-y_train)/y_train)
optimizer = tf.train.AdamOptimizer(learning_rate=0.003).minimize(f_strat)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
# trenowanie
dataset = tf.data.Dataset.from_tensor_slices((x_train,y_train))
dataset = dataset.batch(rozm_paczki)
dataset = dataset.repeat(epoki)
iterator = dataset.make_one_shot_iterator()
ds_x, ds_y = iterator.get_next()
sess.run(optimizer, {X: sess.run(ds_x), Y: sess.run(ds_y)})
saver = tf.train.Saver()
save = saver.save(sess, "./model.ckpt")
print("Model zapisano jako: %s" % save)
# puszczenie sieci na danych
x_test = np.linspace(-1,1,600)
network_outputs = sess.run(output_layer,feed_dict = {X :x_test.reshape(-1,1)})
plt.plot(x_test,x_test**2,color='r',label='y=x^2')
plt.plot(x_test,network_outputs,color='b',label='sieć NN')
plt.legend(loc='right')
plt.show()
我认为问题在于训练数据的输入
sess.run(optimizer, {X: sess.run(ds_x), Y: sess.run(ds_y)})
或者用dSux,dSuy的定义。这是我第一个这样的程序。。
这是行的输出(sees块的insead)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
# trenowanie
for i in range(epoki):
idx = np.arange(len(x_train))
np.random.shuffle(idx)
for j in range(len(x_train)//rozm_paczki):
cur_idx = idx[rozm_paczki*j:(rozm_paczki+1)*j]
sess.run(optimizer,feed_dict = {X:x_train[cur_idx],Y:y_train[cur_idx]})
saver = tf.train.Saver()
save = saver.save(sess, "./model.ckpt")
print("Model zapisano jako: %s" % save)
谢谢!你知道吗
附言:我受到了Neural Network to predict nth square的极大启发
有两个问题共同导致您的模型精度较差,并且都涉及到这一行:
只执行一个训练步骤,因为此代码不在循环中。您的原始代码运行了
len(x_train)//rozm_paczki
个步骤,这应该会取得更大的进展。对
sess.run(ds_x)
和sess.run(ds_y)
的两个调用以不同的步骤运行,这意味着它们将包含来自不同批的不相关的值。对sess.run(ds_x)
或sess.run(ds_y)
的每次调用都将Iterator
移动到下一批,并丢弃在sess.run()
调用中未显式请求的输入元素的任何部分。本质上,您将从批处理i获得X
,从批处理i+1获得Y
(反之亦然),并且模型将在无效数据上训练。如果要从同一批中获取值,则需要在一个sess.run([ds_x, ds_y])
调用中完成。还有两个问题可能会影响效率:
Dataset
没有被洗牌。原始代码在每个时代开始时调用np.random.shuffle()
。您应该在dataset = dataset.repeat()
之前包含一个dataset = dataset.shuffle(len(x_train))
。将值从
Iterator
取回Python(例如,当您执行sess.run(ds_x)
)并将它们反馈到训练步骤时,效率很低。将Iterator.get_next()
操作的输出作为输入直接传递到前馈步骤更有效。把这些放在一起,这里是一个重写版本的程序,它解决了这四个问题,并获得了正确的结果。(不幸的是,我的波兰语不够好,无法保留评论,所以我把它翻译成了英语。)
相关问题 更多 >
编程相关推荐