如何修改Keras中自定义递归层的输出形状?

2024-09-28 01:29:00 发布

您现在位置:Python中文网/ 问答频道 /正文

我想在Keras中编写一个自定义LSTM层,在那里我想修改cell states。我想我需要使用方法compute_output_shape来指定状态的形状。但是,当我尝试更改输出形状时,什么也没有发生。你知道吗

下面的代码将输出形状更改为错误的值。我会期待代码给我的错误,但它是运行文件。你知道吗

import tensorflow as tf
from tensorflow.keras.layers import LSTM, LSTMCell, Input, RNN
from tensorflow.keras.models import Model


class HRU_LSTMCell(LSTM):

    def __init__(self,
                 no_of_hrus=None,
                 **kwargs):

        super(HRU_LSTMCell, self).__init__(**kwargs)
        self.no_of_hrus = no_of_hrus


    def build(self, input_shape):
        super(HRU_LSTMCell, self).build(input_shape)


    def compute_output_shape(self, input_shape):
        print('calculating output shape')
        return -99 


# define model
inputs1 = Input(shape=(10, 3))
lstm1 = HRU_LSTMCell(units=5, return_state=True)(inputs1)
model = Model(inputs=inputs1, outputs=lstm1)

# check output
h1, h2, c = model.predict(data)

print(h1.shape, h2.shape, c.shape)

我错过了什么


Tags: ofnoimportselfinputoutputmodeltensorflow

热门问题