Tensorflow如何使用批次维度执行tf.gather

2024-10-01 13:29:32 发布

您现在位置:Python中文网/ 问答频道 /正文

不幸的是,我不知道如何表述这个问题的标题,也许有人可以改变它

如何以优雅的方式替换以下for循环

#tensor.shape -> (batchsize,100)
#indices.shape -> (batchsize,100) 
liste = []
for i in range(tensor.shape[0]):
        liste.append(tf.gather(tensor[i,:], indices[i,:10]))

new_tensor = tf.stack(liste)
        

Tags: in标题newforstacktf方式range
1条回答
网友
1楼 · 发布于 2024-10-01 13:29:32

这应该可以做到:

new_tensor = tf.gather(tensor, axis=-1, indices=indices[:, :10], batch_dims=1)

这里有一个最小的可复制示例:

import tensorflow as tf

# for version 1.x
#tf.enable_eager_execution()

tensor = tf.random.normal((2, 10))
indices = tf.random.uniform(shape=[2, 10], minval=0, maxval=4, dtype=tf.int32)

liste = []
for i in range(tensor.shape[0]):
        liste.append(tf.gather(tensor[i,:], indices[i,:5]))

new_tensor = tf.stack(liste)

print('tensor: ')
print(tensor)

print('new_tensor: ')
print(new_tensor)

new_tensor_v2 = tf.gather(tensor, axis=-1, indices=indices[:, :5], batch_dims=1)
print('new_tensor_v2: ')
print(new_tensor_v2)

相关问题 更多 >