如何在tensorflow中随机选择索引而不是最大值(tf.arg_max)

2024-10-01 07:43:46 发布

您现在位置:Python中文网/ 问答频道 /正文

我可以使用tf.arg_max选择矩阵中的最大索引, 在tensorflow中是否有一个函数,用于在前n中随机选择索引

test = np.array([
[1, 2, 3],
 [2, 3, 4], 
 [5, 4, 3], 
 [8, 7, 2]])
argmax0 = tf.arg_max(test, 0)
# argmax0 = [3, 3, 1]

我需要一个函数为每个数组随机选择前2名中的索引。 例如: 第一个colmuns[1,2,5,8],top2是[5,8],只需从[5,8]中随机选择一个即可。 所以最终答案可能是[3,2,0],[2,2,0],[3,3,1],[3,2,0]或更多


Tags: 函数答案testtftensorflownparg矩阵
2条回答

您可以使用.argsort方法,然后从该方法中获取最前面的N

N = 3
a = [1, 3, 4, 5, 2]
top_N_indicies = a.argsort()[-N:][::-1]

# top_N_indicies = [3, 2, 1]

获取topk值:

values, _ = tf.math.top_k(test, 2)
<tf.Tensor: shape=(4, 2), dtype=int32, numpy=
array([[3, 2],
       [4, 3],
       [5, 4],
       [8, 7]])>

洗牌每行中的值:

shuffled = tf.map_fn(tf.random.shuffle, values)
<tf.Tensor: shape=(4, 2), dtype=int32, numpy=
array([[2, 3],
       [4, 3],
       [4, 5],
       [7, 8]])>

选择每个无序行的第一行:

tf.gather(shuffled, [0], axis=1)
<tf.Tensor: shape=(4, 1), dtype=int32, numpy=
array([[2],
       [4],
       [4],
       [7]])>

复制/可复制代码:

import tensorflow as tf
import numpy as np

test = np.array([
    [1, 2, 3],
    [2, 3, 4],
    [5, 4, 3],
    [8, 7, 2]])

values, _ = tf.math.top_k(test, 2)

shuffled = tf.map_fn(tf.random.shuffle, values)

tf.gather(shuffled, [0], axis=1)

相关问题 更多 >