沿多个维度的Tensorflow argmax

2024-09-28 22:10:14 发布

您现在位置:Python中文网/ 问答频道 /正文

我对tensorflow还不熟悉,我正在尝试得到张量中最大值的索引。代码如下:

def select(input_layer):

    shape = input_layer.get_shape().as_list()

    rel = tf.nn.relu(input_layer)
    print (rel)
    redu = tf.reduce_sum(rel,3)
    print (redu)

    location2 = tf.argmax(redu, 1)
    print (location2)

sess = tf.InteractiveSession()
I = tf.random_uniform([32, 3, 3, 5], minval = -541, maxval = 23, dtype = tf.float32)
matI, matO = sess.run([I, select(I, 3)])
print(matI, matO)

输出如下:

^{pr2}$

由于argmax函数中的维数=1,Tensor("ArgMax:0") = (32,3)的形状。在应用argmax之前,有没有什么方法可以在不做reshape的情况下得到argmax输出张量size=(32,)?在


Tags: 代码layerinputtftensorflowselectrelsess
1条回答
网友
1楼 · 发布于 2024-09-28 22:10:14

您可能不希望输出的大小为(32,),因为当您沿着多个方向argmax时,您通常希望得到所有缩减维度的最大坐标。在您的例子中,您需要一个大小为(32,2)的输出。在

您可以这样做一个二维argmax

import numpy as np
import tensorflow as tf

x = np.zeros((10,9,8))
# pick a random position for each batch image that we set to 1
pos = np.stack([np.random.randint(9,size=10), np.random.randint(8,size=10)])

posext = np.concatenate([np.expand_dims([i for i in range(10)], axis=0), pos])
x[tuple(posext)] = 1

a = tf.argmax(tf.reshape(x, [10, -1]), axis=1)
pos2 = tf.stack([a // 8, tf.mod(a, 8)]) # recovered positions, one per batch image

sess = tf.InteractiveSession()
# check that the recovered positions are as expected
assert (pos == pos2.eval()).all(), "it did not work"

相关问题 更多 >