AdamOptimizer defies tf.control_dependencies

Posted 2024-05-03 16:57:58


For some reason, AdamOptimizer ignores tf.control_dependencies.

Here is a test. I ask TensorFlow to do the following:

  1. Compute the loss
  2. Compute the loss again
  3. Run one step of Adam

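I use tf.control_dependencies to "ensure" that TF runs step 3 after step 2.

If TensorFlow executes these 3 steps in the correct order, the results of step 1 and step 2 should be identical.

But they are not. What is wrong?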


Test:

import numpy as np
import tensorflow as tf

x = tf.get_variable('x', initializer=np.array([1], dtype=np.float64))
loss = x * x

optim = tf.train.AdamOptimizer(1)

## Control Dependencies ##
with tf.control_dependencies([loss]):
    train_op = optim.minimize(loss)

## Run ##
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    for i in range(1000):
        a = sess.run([loss])
        b = sess.run([loss, train_op])[0]
        print(a, b)
        assert np.allclose(a, b)

Result:

[array([1.])] [2.50003137e-14]
AssertionError

The results of step 1 and step 2 are not the same.


2 Answers

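It sounds like you expect sess.run([loss, adam_op]) to run loss first and then adam_op. Alas, sess.run does not work that way: the order of the fetch list does not determine the order of execution. Take this simple example, which prints 1.0 1.0, showing that the set_x op ran before get_x.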

import tensorflow as tf

var_x = tf.get_variable("x", shape=[], initializer=tf.zeros_initializer())
get_x = var_x.read_value()
set_x = var_x.assign(1)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    a, b = sess.run([get_x, set_x])
    print(a, b)
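
To make the point above concrete, here is a toy sketch in plain Python. It is not TensorFlow's real scheduler, and the names Op and run are hypothetical, invented for this illustration. It only assumes that an executor may run ops with no dependency between them in any order, and must run an op's dependencies before the op itself. It says nothing about why depending on loss directly behaves differently from depending on tf.identity(loss) in the question.

```python
# Toy model only: NOT TensorFlow's scheduler. Op and run are hypothetical names.
class Op:
    def __init__(self, name, fn, deps=()):
        self.name = name
        self.fn = fn
        self.deps = list(deps)  # ops that must finish before this one runs


def run(fetches, state, log):
    """Run ops so that every dependency executes first; fetch order is ignored."""
    done = {}

    def visit(op):
        if op.name in done:
            return done[op.name]
        for dep in op.deps:
            visit(dep)
        done[op.name] = op.fn(state)
        log.append(op.name)
        return done[op.name]

    # Walk the fetches in reverse on purpose, to mimic an executor that is
    # free to pick any order among ops that do not depend on each other.
    for op in reversed(fetches):
        visit(op)
    return [done[op.name] for op in fetches]


def read_x(state):
    return state["x"]


def write_x(state):
    state["x"] = 1.0
    return state["x"]


# Case 1: no dependency between the ops, so the write may run before the read.
state, log = {"x": 0.0}, []
get_x = Op("get_x", read_x)
set_x = Op("set_x", write_x)
unordered = run([get_x, set_x], state, log)
unordered_log = list(log)

# Case 2: the write explicitly depends on the read, so the read runs first.
state, log = {"x": 0.0}, []
get_x = Op("get_x", read_x)
set_x = Op("set_x", write_x, deps=[get_x])
ordered = run([get_x, set_x], state, log)
ordered_log = list(log)

print(unordered, unordered_log)
print(ordered, ordered_log)
```

In this toy model the first case returns the already-updated value for both fetches, and the second returns the old value for the read, even though the fetch list is identical in both cases.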

Making step 3 depend on tf.identity(loss) instead of loss magically fixes the problem.

What is going on??

The magic fix:

import numpy as np
import tensorflow as tf

x = tf.get_variable('x', initializer=np.array([1], dtype=np.float64))
loss = x * x

optim = tf.train.AdamOptimizer(1)

## Control Dependencies ##
loss2 = tf.identity(loss)  # < - this #
with tf.control_dependencies([loss2]):
    train_op = optim.minimize(loss)

## Run ##
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    for i in range(1000):
        a = sess.run([loss])
        b = sess.run([loss2, train_op])[0]  # < - loss2
        print(a, b)
        assert np.allclose(a, b)

Result:

[array([1.])] [1.]
[array([2.50003137e-14])] [2.50003137e-14]
[array([0.4489748])] [0.4489748]
...
[array([1.151504e-47])] [1.151504e-47]
[array([4.90468459e-46])] [4.90468459e-46]
