我正在尝试实现一个小调整版本的批处理规范化操作;在这个版本中,我需要显式地保留移动平均值,比如均值和方差。为了做到这一点,我正在对Tensorflow中的赋值和控制依赖机制进行一些实验,我遇到了一个神秘的问题。我有以下玩具代码;我在其中尝试测试tf.control_dependencies
是否按预期工作:
dataset = MnistDataSet(validation_sample_count=10000,
load_validation_from="validation_indices")
samples, labels, indices_list, one_hot_labels =
dataset.get_next_batch(batch_size=GlobalConstants.BATCH_SIZE)
samples = np.expand_dims(samples, axis=3)
flat_data = tf.contrib.layers.flatten(GlobalConstants.TRAIN_DATA_TENSOR)
mean = tf.Variable(name="mean", initial_value=tf.constant(100.0, shape=[784], dtype=tf.float32),
trainable=False, dtype=tf.float32)
a = tf.Variable(name="a", initial_value=5.0, trainable=False)
b = tf.Variable(name="b", initial_value=4.0, trainable=False)
c = tf.Variable(name="c", initial_value=0.0, trainable=False)
batch_mean, batch_var = tf.nn.moments(flat_data, [0])
b_op = tf.assign(b, a)
mean_op = tf.assign(mean, batch_mean)
with tf.control_dependencies([b_op, mean_op]):
c = a + b
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
results = sess.run([c, mean], feed_dict={GlobalConstants.TRAIN_DATA_TENSOR: samples})
我只是加载一个数据批处理,每个条目有784个维度,计算它的矩并尝试将batch_mean
存储到变量mean
。我通常也将变量a
的值存储到b
中。在
在最后一行中,当我运行c
和mean
的值时,我看到c
为10,这是预期值。但是mean
仍然是100的向量,不包含批处理平均值。就像mean_op = tf.assign(mean, batch_mean)
没有执行一样。在
这是什么原因?据我所知,tf.control_dependencies
调用中的所有操作都必须在以下上下文中的任何操作之前执行;我在这里显式地调用c
,这是在上下文中。我错过什么了吗?在
这是
tf.Session.run()
的known "feature"。c
和mean
操作是独立的,因此mean
可以在c
之前计算(这将更新mean
)。在以下是这种效果的简短版本:
b
的第二次求值保证返回[3.0]
。但是第一个run
可能返回[2.0 3.0]
或{相关问题 更多 >
编程相关推荐