具有多输出形状的Tensorflow子类keras

2024-05-06 06:58:53 发布

您现在位置:Python中文网/ 问答频道 /正文

class MyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(1)
        self.build(input_shape=[None, 1])

    def call(self, inputs, **kwargs):
        return self.dense(inputs)

MyModel().summary()

enter image description here

模型图也不起作用:

tf.keras.utils.plot_model(model, to_file='model_1.png', show_shapes=True)

我在几个tensorflow版本2.3.0、2.3.1和2.4.1上尝试了这段代码,每次output shape都是multiple!是虫子吗?有办法吗


Tags: selfmodelinitlayerstfdefclassmymodel
1条回答
网友
1楼 · 发布于 2024-05-06 06:58:53

这不是错误。通常,我们不能对子类模型的结构做任何假设。这就是为什么在模型子类API中.summary()无法获得与函数类顺序类API相同的输出形状

但这里有一个解决方法来实现这一点。您可以通过以下方法实现这一点

import tensorflow as tf 

class MyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(1)
        self.build(input_shape=[None, 1])

    def call(self, inputs, **kwargs):
        return self.dense(inputs)

    def build_graph(self):
        x = tf.keras.layers.Input(shape=(1))
        return tf.keras.Model(inputs=[x], outputs=self.call(x))

MyModel().build_graph().summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 1)]               0         
_________________________________________________________________
dense_3 (Dense)              (None, 1)                 2         
=================================================================
Total params: 2
Trainable params: 2
Non-trainable params: 0
_________________________________________________________________

与绘制模型相同

tf.keras.utils.plot_model(
    MyModel().build_graph()                     
)

相关问题 更多 >