为什么tf.argmax()返回错误的索引?

2024-10-01 04:48:43 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试使用tf.argmax()函数获取logit的最大索引。我的代码如下所示:

import tensorflow as tf
import numpy as np

   
logits = tf.random_uniform([1,3,3,21], maxval=255, dtype=tf.float32, seed=0)
logits = logits / tf.norm(logits)


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    logits_eval = sess.run(logits)
    logits_argmax_np = np.argmax(logits_eval, axis=-1)

    ypredT = tf.argmax(logits, axis=-1)
    logits_argmax_tf = ypredT.eval()
   

我可以使用np.argmax()获得正确的索引,但我不知道为什么tf.argmax()返回了错误的索引。
先谢谢你

编辑:我使用的是tensorflow 1.13


Tags: 函数runimporttftensorflowasevalnp
1条回答
网友
1楼 · 发布于 2024-10-01 04:48:43

这里的问题是,每次调用sess.run时,都是在自下而上单独执行会话。由于生成的是随机数,因此每次运行不会产生相同的结果,因此每次运行的argmax不同。但他们正在做同样的事情

要了解这一点,您可以从同一个会话执行中获取两个argmax,使用方括号从同一个sess获取ypredT张量和logits张量。运行:

# tensorflow graph
logits = tf.random_uniform([1,3,3,21], maxval=255, dtype=tf.float32, seed=0) 
logits = logits / tf.norm(logits)
ypredT = tf.argmax(logits, axis=-1) # tensorflow argmax

# run session 
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    logits_eval, logits_argmax_tf = sess.run([logits, ypredT])

# after session has closed
logits_argmax_np = np.argmax(logits_eval, axis=-1) # get numpy argmax

print(logits_argmax_np)
print(logits_argmax_tf)

输出:

[[[14  0  4]
  [ 0  8  0]
  [10 12  3]]]
[[[14  0  4]
  [ 0  8  0]
  [10 12  3]]]

相关问题 更多 >