如何在tensorflow中交替训练op?

2024-09-30 04:27:05 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在实施交替训练计划。该图包含两个训练操作。培训应该在这两者之间交替进行。在

这与this或{a2}等研究相关

下面是一个小例子。但它似乎每一步都在更新两个操作。我怎样才能明确地在这两者之间交替?在

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
# Import data
mnist = input_data.read_data_sets('/tmp/tensorflow/mnist/input_data', one_hot=True)

# Create the model
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]), name='weights')
b = tf.Variable(tf.zeros([10]), name='biases')
y = tf.matmul(x, W) + b

# Define loss and optimizer
y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
global_step = tf.Variable(0, trainable=False)

tvars1 = [b]
train_step1 = tf.train.GradientDescentOptimizer(0.5).apply_gradients(zip(tf.gradients(cross_entropy, tvars1), tvars1), global_step)
tvars2 = [W]
train_step2 = tf.train.GradientDescentOptimizer(0.5).apply_gradients(zip(tf.gradients(cross_entropy, tvars2), tvars2), global_step)
train_step = tf.cond(tf.equal(tf.mod(global_step,2), 0), true_fn= lambda:train_step1, false_fn=lambda : train_step2)


sess = tf.InteractiveSession()
tf.global_variables_initializer().run()


# Train
for i in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
    if i % 100 == 0:
        print(sess.run([cross_entropy, global_step], feed_dict={x: mnist.test.images,
                                         y_: mnist.test.labels}))

这导致

^{pr2}$

全局步骤迭代到1802,因此每次调用train_step时都会执行两个train操作。(例如,当always false条件为tf.equal(global_step,-1)时,也会发生这种情况。)

我的问题是如何在执行train_step1和{}之间交替执行?在


Tags: inputdatatftensorflowstepbatchtrainvariable
1条回答
网友
1楼 · 发布于 2024-09-30 04:27:05

我认为最简单的方法就是

for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  if i % 2 == 0:
    sess.run(train_step1, feed_dict={x: batch_xs, y_: batch_ys})
  else:
    sess.run(train_step2, feed_dict={x: batch_xs, y_: batch_ys})

但如果需要通过tensorflow条件流进行切换,请按以下方式进行:

^{pr2}$

相关问题 更多 >

    热门问题