如何在急切模式下迭代tf.tensor

2024-05-26 00:33:46 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图在渴望模式下迭代张量,但我不能

当然,你会做一些类似的事情:

probs = tf.convert_to_tensor(np.array([[1,2,3], [4,5,6], [7,8,9]]))
indexs = tf.convert_to_tensor(np.array([1, 2, 3]))

@tf.function
def iterate_tensor(probs, indexs):
    return [output[label] for output, label in zip(probs, indexs)]
iterate_tensor(probs, indexs)

但这会产生错误OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed:

我试过的另一件事是:

probs = tf.convert_to_tensor(np.array([[1,2,3], [4,5,6], [7,8,9]]))
indexs = tf.convert_to_tensor(np.array([1, 2, 3]))

@tf.function
def iterate_tensor(probs, indexs):
    return tf.map_fn(lambda i: i[0][i[1]], (probs, indexs), dtype=(tf.int64, tf.int64))

iterate_tensor(probs, indexs)

给出错误ValueError: The two structures don't have the same nested structure.


Tags: toconvertoutputreturntfdef错误np
1条回答
网友
1楼 · 发布于 2024-05-26 00:33:46

这似乎有效:

probs = tf.convert_to_tensor(np.array([[1,2,3], [4,5,6], [7,8,9]]))
indexs = tf.convert_to_tensor(np.array([1,1,1]))

@tf.function
def iterate_tensor(probs, indexs):
    return tf.linalg.diag_part(tf.gather(probs, indexs, axis=1))
iterate_tensor(probs, indexs)

输出:<tf.Tensor: shape=(3,), dtype=int64, numpy=array([2, 5, 8])>

相关问题 更多 >