定义tf.keras中的模型块

2024-10-06 12:23:17 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在试验我的模型的架构,我想有几个预定义的层块,我可以随意混合。我认为为每个块结构创建一个不同的类会更容易,我认为在克拉斯特遣部队是一条路要走。所以我做了以下(玩具的例子,但很长。对不起!)。你知道吗

class PoolingBlock(Model):
    def __init__(self, filters, stride, name):
        super(PoolingBlock, self).__init__(name=name)

        self.bn = BatchNormalization()
        self.conv1 = Conv1D(filters=filters, kernel_size=1, padding='same')
        self.mp1 = MaxPooling1D(stride, padding='same')

    def call(self, input_tensor, training=False, mask=None):
        x = self.bn(input_tensor)
        x = tf.nn.relu(x)
        x = self.conv1(x)
        x = self.mp1(x)
        return x

class ModelA(Model):
    def __init__(self, n_dense, filters, stride, name):
        super(ModelA, self).__init__(name=name)

        self.d1 = Dense(n_dense, "DenseLayer1")
        self.pb1 = PoolingBlock(filters, stride, name="PoolingBlock_1")
        self.d2 = Dense(n_dense, "DenseLayer2")

    def call(self, inputs, training=False, mask=None):
        x = inputs
        x = self.d1(x)
        x = self.pb1(x)
        x = self.d2(x)
        return x

model = ModelA(100, 10, 2, 'ModelA')
model.build(input_shape=x.shape)

然后像往常一样继续model.compile(...)model.fit(...)。但在训练时,我收到这样的警告:

WARNING:tensorflow:Entity < bound method PoolingBlock.call of < model.PoolingBlock object at 0x7fe09ca04208 > > could not be transformed and will be executed as-is. Please report this to the AutgoGraph team. When filing the bug, set the verbosity to 10 (on Linux, export AUTOGRAPH_VERBOSITY=10) and attach the full output. Cause: converting < bound method PoolingBlock.call of < model.PoolingBlock object at 0x7fe09ca04208 > >: AttributeError: module 'gast' has no attribute 'Num'

我不明白那是什么意思。我想知道我的训练是否如我所计划的那样进行,这种子类化的方法是否正确可靠,我是否能够以某种方式抑制这个警告。你知道吗


Tags: thenameselfinputmodelinitdefcall