擅长:python、mysql、java
<p>基于<a href="https://www.tensorflow.org/api_docs/python/tf/argmax" rel="nofollow noreferrer">documentation</a>,<code>tf.argmax()</code>接受一个<code>input</code>,以及一个{<cd3>},以及其他参数。在</p>
<p>如果您的标签的形状为[1,1],您希望从argmax跨轴1得到什么?只有一个条目。在</p>
<p>最有可能的情况是,您希望将标签与argmaxed结果进行比较。所以:</p>
<pre><code>...
image, label = ...
# label: Tensor("..", shape=(1, 1), dtype=int32)
logits = model(image)
# logits: Tensor("..", shape=(1, 10), dtype=float32)
predic = tf.nn.softmax(logits)
arg_log = tf.argmax(logits, 1)
...
pre, lbl, a_log, a_lbl = sess.run([predic, label, arg_log, arg_lbl])
cor_pre = tf.equal(arg_log, tf.cast(label, tf.int64))
</code></pre>