我在试着评估单个图像。到目前为止还有效。我得到了每个类的概率和正确的标签。但是当我试图用tf.argmax(label, 1)
获得类时,我总是得到类“0”。在
...
image, label = ...
# label: Tensor("..", shape=(1, 1), dtype=int32)
logits = model(image)
# logits: Tensor("..", shape=(1, 10), dtype=float32)
predic = tf.nn.softmax(logits)
arg_log = tf.argmax(logits, 1)
arg_lbl = tf.argmax(label, 1)
...
pre, lbl, a_log, a_lbl = sess.run([predic, label, arg_log, arg_lbl])
print(pre)
# [[2.0451562e-06 # class 0
# 6.1964911e-06 # class 1
# 4.1852250e-06 # class 2
# 9.9847549e-01 # class 3 - We have a winner :)
# 8.2492170e-07 # class 4
# 3.1969071e-06 # class 5
# 1.5037126e-03 # class 6
# 1.6847488e-07 # class 7
# 6.7177882e-07 # class 8
# 3.4959594e-06]] # class 9
print(lbl)
# [[3]]
print(a_log)
# [3]
print(a_lbl)
# [0] # Why i dont get "3"?
...
对于每个数据点,我总是得到“0”。我想继续使用tf.equal()
,但是标签的argmax值错误,这当然是不可能的。有什么想法吗?公司名称:
我每次都得到“0”,因为我得到了索引!我把问题改为: 如何使用tf.argmax公司()在张量流中的形状=(1,1)的张量上? 收件人: 如何使用tf.相等()在TensorFlow中具有形状=(1,1)的标签张量?在
我找到了一个解决方案,将标签强制转换为
int64
,然后在tf.equal()
中使用它。在任何形状(1,1)数组都将只包含一个元素。该元素必须是数组中的最大元素。在
基于documentation,},以及其他参数。在
tf.argmax()
接受一个input
,以及一个{如果您的标签的形状为[1,1],您希望从argmax跨轴1得到什么?只有一个条目。在
最有可能的情况是,您希望将标签与argmaxed结果进行比较。所以:
相关问题 更多 >
编程相关推荐