我已经粘贴了代码的一部分（即与 TensorBoard 相关的全部部分）如下。我只记录了损失这一个标量变量，并且每个 epoch 只添加一次摘要。我一共运行了 3 个 epoch。理想情况下 tfevents 文件应该非常小。但是，实际生成的 tfevents 文件有 1.3GB。我不知道是什么原因导致文件这么大。如果需要的话，很乐意分享剩下的代码。
def do_training(update_op, loss, summary_op):
    """Run the training loop until the dataset iterator is exhausted.

    Args:
        update_op: op that applies one optimizer step.
        loss: scalar loss tensor to report.
        summary_op: merged summary op (here only the "loss" scalar).

    Side effects: writes TensorBoard events to ``logs_path`` and saves a
    checkpoint to ``save_path`` (module-level globals — TODO confirm).
    """
    # NOTE(review): graph=... serializes the entire graph into the events
    # file. If the dataset embeds in-memory arrays as graph constants
    # (e.g. from_tensor_slices on numpy data), that alone can make the
    # tfevents file gigabytes large — verify how `dataset` is built.
    writer = tf.summary.FileWriter(logs_path, graph=tf.get_default_graph())
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Integer division: with `/`, Python 3 yields a float divisor and
        # `step % divisor == 0` holds only at step 0, so summaries/prints
        # would effectively never fire again.
        steps_per_epoch = X_train.shape[0] // batch_size
        loss_value = None  # guard: print below must not NameError on an empty dataset
        step = 0
        try:
            while True:
                if step % steps_per_epoch == 0:
                    # Single combined run: a separate sess.run(summary_op)
                    # would pull a *second* batch from the one-shot iterator,
                    # summarizing different data and consuming the dataset.
                    _, loss_value, summary = sess.run((update_op, loss, summary_op))
                    writer.add_summary(summary, global_step=step)
                    print('Step {} with loss {}'.format(step, loss_value))
                else:
                    _, loss_value = sess.run((update_op, loss))
                step += 1
        except tf.errors.OutOfRangeError:
            # we're through the dataset
            pass
        writer.close()
        saver.save(sess, save_path)
        print('Final loss: {}'.format(loss_value))
def serial_training(model_fn, dataset):
    """Wire model_fn to `dataset`, attach a loss summary, and train.

    Builds the graph (one-shot iterator -> model -> Adam step) and hands
    the resulting ops to do_training.
    """
    it = dataset.make_one_shot_iterator()
    loss = model_fn(lambda: it.get_next())
    tf.summary.scalar("loss", loss)
    merged_summaries = tf.summary.merge_all()
    step_counter = tf.train.get_or_create_global_step()
    train_step = tf.train.AdamOptimizer(learning_rate=0.0002).minimize(
        loss, global_step=step_counter)
    do_training(train_step, loss, merged_summaries)
# Start from a clean default graph before building the model.
tf.reset_default_graph()
# Kick off training for 3 epochs; training_model, training_dataset and
# batch_size are presumably defined elsewhere in the file — not shown here.
serial_training(training_model,training_dataset(epochs=3,batch_size=batch_size))
目前没有回答
相关问题 更多 >
编程相关推荐