我想在Keras中编写一个自定义LSTM层,在那里我想修改cell states
。我想我需要使用方法compute_output_shape
来指定状态的形状。但是,当我尝试更改输出形状时,什么也没有发生。你知道吗
下面的代码将输出形状更改为错误的值。我会期待代码给我的错误,但它是运行文件。你知道吗
import tensorflow as tf
from tensorflow.keras.layers import LSTM, LSTMCell, Input, RNN
from tensorflow.keras.models import Model
class HRU_LSTMCell(LSTM):
def __init__(self,
no_of_hrus=None,
**kwargs):
super(HRU_LSTMCell, self).__init__(**kwargs)
self.no_of_hrus = no_of_hrus
def build(self, input_shape):
super(HRU_LSTMCell, self).build(input_shape)
def compute_output_shape(self, input_shape):
print('calculating output shape')
return -99
# define model
inputs1 = Input(shape=(10, 3))
lstm1 = HRU_LSTMCell(units=5, return_state=True)(inputs1)
model = Model(inputs=inputs1, outputs=lstm1)
# check output
h1, h2, c = model.predict(data)
print(h1.shape, h2.shape, c.shape)
我错过了什么
目前没有回答
相关问题 更多 >
编程相关推荐