张量流:tf.分配不分配任何内容

2024-10-02 04:18:38 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试实现一个小调整版本的批处理规范化操作;在这个版本中,我需要显式地保留移动平均值,比如均值和方差。为了做到这一点,我正在对Tensorflow中的赋值和控制依赖机制进行一些实验,我遇到了一个神秘的问题。我有以下玩具代码;我在其中尝试测试tf.control_dependencies是否按预期工作:

dataset = MnistDataSet(validation_sample_count=10000, 
load_validation_from="validation_indices")
samples, labels, indices_list, one_hot_labels = 
dataset.get_next_batch(batch_size=GlobalConstants.BATCH_SIZE)
samples = np.expand_dims(samples, axis=3)

flat_data = tf.contrib.layers.flatten(GlobalConstants.TRAIN_DATA_TENSOR)
mean = tf.Variable(name="mean", initial_value=tf.constant(100.0, shape=[784], dtype=tf.float32),
               trainable=False, dtype=tf.float32)
a = tf.Variable(name="a", initial_value=5.0, trainable=False)
b = tf.Variable(name="b", initial_value=4.0, trainable=False)
c = tf.Variable(name="c", initial_value=0.0, trainable=False)
batch_mean, batch_var = tf.nn.moments(flat_data, [0])

b_op = tf.assign(b, a)
mean_op = tf.assign(mean, batch_mean)
with tf.control_dependencies([b_op, mean_op]):
    c = a + b

init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

results = sess.run([c, mean], feed_dict={GlobalConstants.TRAIN_DATA_TENSOR: samples})

我只是加载一个数据批处理,每个条目有784个维度,计算它的矩并尝试将batch_mean存储到变量mean。我通常也将变量a的值存储到b中。在

在最后一行中,当我运行cmean的值时,我看到c为10,这是预期值。但是mean仍然是100的向量,不包含批处理平均值。就像mean_op = tf.assign(mean, batch_mean)没有执行一样。在

这是什么原因?据我所知,tf.control_dependencies调用中的所有操作都必须在以下上下文中的任何操作之前执行;我在这里显式地调用c,这是在上下文中。我错过什么了吗?在


Tags: namefalsevaluetfbatchdependenciesmeanvariable
1条回答
网友
1楼 · 发布于 2024-10-02 04:18:38

这是tf.Session.run()known "feature"cmean操作是独立的,因此mean可以在c之前计算(这将更新mean)。在

以下是这种效果的简短版本:

a = tf.Variable(name="a", initial_value=1.0, trainable=False)
b = tf.Variable(name="b", initial_value=0.0, trainable=False)
dependent_op = tf.assign(b, a * 3)
with tf.control_dependencies([dependent_op]):
  c = a + 1

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run([c, b]))
  print(sess.run([b]))

b的第二次求值保证返回[3.0]。但是第一个run可能返回[2.0 3.0]或{}。在

相关问题 更多 >

    热门问题