<p>我遇到了一个非常类似的问题我有一个定制的<code>Layer</code>,它</p>
<ul>
<li>当它在<code>keras</code>中时工作非常好(keras.version=2.24)使用tf后端(tf.version=1.14条)</li>
<li>当我切换到<code>tf.keras</code>时,引发与问题中描述的相同的错误消息(tf.version=1.14条)</li>
</ul>
<p>@Alexandre Passos的建议很管用,但我花了一段时间才弄明白他到底是什么意思。一、 因此,请在下面分享我的经验,以帮助未来的朋友。在</p>
<p>下面是一个定制层的示例</p>
<pre class="lang-py prettyprint-override"><code>
class CustomizedLayer(Layer) :
def __init__(self, num_feats, **kwargs) :
self.num_feats = num_feats
super(CustomizedLayer, self).__init__(**kwargs)
def build(self, input_shape) :
weight_shape = (input_shape[-1], self.num_feats) # < - problem
self.weight = self.add_weight(shape=weight_shape,
...)
super(CustomizedLayer, self).build(input_shape)
def call(self, inputs) :
...
</code></pre>
<p>问题发生在<code>tf.keras</code>,因为传递的<code>input_shape</code>是<code>tf.Dimension</code>类型,而不是{<cd7>}类型,这是{<cd8>}所必需的,也是使用<code>keras</code>时传递的类型。在</p>
<p>因此,同一层定义在<code>keras</code>中工作,而在<code>tf.keras</code>中不起作用。在</p>
<p>因此,您需要做的就是将<code>tf.Dimension</code>对象转换为<code>int</code>,即将<code>weight_shape</code>的行重写为</p>
^{pr2}$