擅长:python、mysql、java
<p>简而言之,这是因为从_生成器可以展平NumPy数组而不是张量。在</p>
<p>下面是一个较短的代码,将重现错误:</p>
<pre><code>import tensorflow as tf
import numpy as np
print(tf.__version__)
def g():
img = tf.random_uniform([3])
# img = np.random.rand(3)
# img = tf.convert_to_tensor(img)
yield img
dataset = tf.data.Dataset.from_generator(g, tf.float64, tf.TensorShape([3]))
iterator = dataset.make_one_shot_iterator()
next_iterator = iterator.get_next()
sess = tf.Session()
sess.run(next_iterator)
</code></pre>
<p>版本1.14中的错误消息非常有用。(确切的代码行会因版本不同而改变,但我已经检查了1.12和1.13,原因是一样的。)</p>
^{pr2}$
<p>当生成的元素是张量时,from_generator将把它展平为<code>output_types</code>。转换功能不起作用。在</p>
<p>要解决这个问题,只需在生成器生成张量时不要使用<code>from_generator</code>。您可以使用<code>from_tensors</code>或<code>from_tensor_slices</code>。在</p>
<pre><code>img = tf.random_uniform([3])
dataset = tf.data.Dataset.from_tensors(img).repeat()
iterator = dataset.make_initializable_iterator()
next_iterator = iterator.get_next()
sess = tf.Session()
sess.run(iterator.initializer)
sess.run(next_iterator)
</code></pre>