tf.GradientTape()不适用于切片输出

2024-10-02 12:22:52 发布

您现在位置:Python中文网/ 问答频道 /正文

下面是我试图运行的一段代码:

import tensorflow as tf

a = tf.constant([[1, 2], [2, 3]], dtype=tf.float32)
b = tf.constant([[1, 2], [2, 3]], dtype=tf.float32)

with tf.GradientTape() as tape1, tf.GradientTape() as tape2:
    tape1.watch(a)
    tape2.watch(a)
    
    c = a * b

grad1 = tape1.gradient(c, a)
grad2 = tape2.gradient(c[:, 0], a)
print(grad1)
print(grad2)

这是输出:

tf.Tensor(
[[1. 2.]
 [2. 3.]], shape=(2, 2), dtype=float32)
None

正如您所观察到的,tf.GradientTape()无法处理切片输出。有什么办法吗


Tags: 代码importtfaswatchprintdtypeconstant
1条回答
网友
1楼 · 发布于 2024-10-02 12:22:52

是的,你对张量所做的一切都需要在磁带上下文中发生。您可以这样相对轻松地修复它:

import tensorflow as tf

a = tf.constant([[1, 2], [2, 3]], dtype=tf.float32)
b = tf.constant([[1, 2], [2, 3]], dtype=tf.float32)

with tf.GradientTape() as tape1, tf.GradientTape() as tape2:
    tape1.watch(a)
    tape2.watch(a)
    
    c = a * b
    c_sliced = c[:, 0]

grad1 = tape1.gradient(c, a)
grad2 = tape2.gradient(c_sliced, a)
print(grad1)
print(grad2)

相关问题 更多 >

    热门问题