擅长:python、mysql、java
<p>这似乎是由于tf.map_fn无法确定输入张量的类型规范(当输入为keras辛同态输入时)</p>
<p>现在我已经有好几次遇到tensorflow ops和Keras符号输入张量的问题。在自定义层中包装有问题的代码似乎可以解决问题</p>
<p>例如,将上面的代码替换为该代码将成功执行:</p>
<pre><code>import tensorflow as tf
from tensorflow.keras.layers import Input, Layer
from tensorflow.keras.models import Model
x = Input(shape=(10,))
class MapLayer(Layer):
def call(self, input):
return tf.map_fn(lambda x : x * 2, input, fn_output_signature=tf.float32)
y = MapLayer()(x)
model = Model(inputs=x, outputs=y)
</code></pre>