使用tf.keras.layers.concatenate()作为tensorflow中的自定义层

2024-10-03 06:31:22 发布

您现在位置:Python中文网/ 问答频道 /正文

我想使用tensorflow中的自定义层制作U-net。我需要使用tf.keras.layers.concatenate,这就是我的问题。为我可以在方法调用中添加到层的所有其他层输入张量。但是连接层的语法是tf.keras.layers.concatenate(输入,axis),我需要类似于tf.keras.layers.concatenate(axis)(输入)的东西,但它不起作用。有人能帮我吗?
多谢各位

我的代码是这样的:

class MyModel(tf.keras.Model):
  def __init__(self):
    super(MyModel, self).__init__()
    self.block1 = Conv2D(.....)
    self.block2 = BatchNormalization()
    ....etc.....
    self.decoder_concat = tf.keras.layers.concatenate(axis=-1) #that i need but it does not work

  def call(self, inputs):
     x = self.block1(inputs)
     x = self.block2(x)
     ....etc......
     x = self.decoder_concat([x, concatLayer]) #that i need but it does not work

Tags: selfthatinitlayerstfdefetcmymodel
1条回答
网友
1楼 · 发布于 2024-10-03 06:31:22

在这里提供解决方案(答案部分),即使它出现在评论部分,也是为了社区的利益

tf.keras.layers.concatenate更改为tf.keras.layers.Concatenate后,问题已得到解决

tf.keras.layers.Concatenate,用作连接Tensorflow中输入列表的层,其中astf.keras.layers.concatenate充当连接层的功能接口。请参阅更多详情here

请参阅下面更新的代码

class MyModel(tf.keras.Model):
  def __init__(self):
    super(MyModel, self).__init__()
    self.block1 = Conv2D(.....)
    self.block2 = BatchNormalization()
    ....etc.....
    self.decoder_concat = tf.keras.layers.Concatenate(axis=-1) #that i need but it does not work

  def call(self, inputs):
     x = self.block1(inputs)
     x = self.block2(x)
     ....etc......
     x = self.decoder_concat([x, concatLayer]) #that i need but it does not work

相关问题 更多 >