我正在尝试使用tf.argmax()
函数获取logit的最大索引。我的代码如下所示:
import tensorflow as tf
import numpy as np
logits = tf.random_uniform([1,3,3,21], maxval=255, dtype=tf.float32, seed=0)
logits = logits / tf.norm(logits)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
logits_eval = sess.run(logits)
logits_argmax_np = np.argmax(logits_eval, axis=-1)
ypredT = tf.argmax(logits, axis=-1)
logits_argmax_tf = ypredT.eval()
我可以使用np.argmax()
获得正确的索引,但我不知道为什么tf.argmax()
返回了错误的索引。
先谢谢你
编辑:我使用的是tensorflow 1.13
这里的问题是,每次调用
sess.run
时,都是在自下而上单独执行会话。由于生成的是随机数,因此每次运行不会产生相同的结果,因此每次运行的argmax不同。但他们正在做同样的事情要了解这一点,您可以从同一个会话执行中获取两个argmax,使用方括号从同一个sess获取
ypredT
张量和logits
张量。运行:输出:
相关问题 更多 >
编程相关推荐