如何在Tensorflow 2.0中使用gradient_override_map?

2024-10-03 23:24:36 发布

您现在位置:Python中文网/ 问答频道 /正文

我尝试在Tensorflow 2.0中使用gradient_override_map。这里有一个example in the documentation,我也将用它作为例子。在

{cd2>中的}可用于计算^ 2中的梯度:

import tensorflow as tf
print(tf.version.VERSION)  # 2.0.0-alpha0

x = tf.Variable(5.0)
with tf.GradientTape() as tape:
    s_1 = tf.square(x)
print(tape.gradient(s_1, x))

还有一个tf.custom_gradient修饰符,可用于为一个新的函数定义渐变(同样,使用example from the docs):

^{pr2}$

但是,我想替换标准函数的梯度,比如tf.square。我尝试使用以下代码:

@tf.RegisterGradient("CustomSquare")
def _custom_square_grad(op, grad):
  return tf.constant(0)

with tf.Graph().as_default() as g:
    x = tf.Variable(5.0)
    with g.gradient_override_map({"Square": "CustomSquare"}):
        with tf.GradientTape() as tape:
            s_2 = tf.square(x, name="Square")

    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())            
        print(sess.run(tape.gradient(s_2, x)))

但是,有两个问题:梯度替换似乎不起作用(它被计算为10.0而不是0.0),我需要借助于session.run()来执行图形。在“原生”TensorFlow 2.0中有没有实现这一点的方法?在

在TensorFlow 1.12.0中,以下内容生成所需的输出:

import tensorflow as tf
print(tf.__version__)  # 1.12.0

@tf.RegisterGradient("CustomSquare")
def _custom_square_grad(op, grad):
  return tf.constant(0)

x = tf.Variable(5.0)

g = tf.get_default_graph()
with g.gradient_override_map({"Square": "CustomSquare"}):
    s_2 = tf.square(x, name="Square")
grad = tf.gradients(s_2, x)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(grad))

Tags: runmaptfaswithsess梯度print
1条回答
网友
1楼 · 发布于 2024-10-03 23:24:36

TensorFlow 2.0中没有内置机制来覆盖范围内内置运算符的所有渐变。但是,如果能够修改对内置运算符的每个调用的调用位置,则可以使用tf.custom_gradient修饰符,如下所示:

@tf.custom_gradient
def custom_square(x):
  def grad(dy):
    return tf.constant(0.0)
  return tf.square(x), grad

with tf.Graph().as_default() as g:
  x = tf.Variable(5.0)
  with tf.GradientTape() as tape:
    s_2 = custom_square(x)

  with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())            
    print(sess.run(tape.gradient(s_2, x)))

相关问题 更多 >