如何处理LSTMCell中的“输入形状轴1必须等于xx”

2024-10-04 11:35:01 发布

您现在位置:Python中文网/ 问答频道 /正文

我创建了自己的网络模型,并使用维度为[batchez,10,8]的数据来训练这个模型

之后,我想使用维度为[1,40,8]的张量来运行net.predict(x),('x'是张量的名称),但我得到了错误:

Input shape axis 1 must equal 10, got shape [1,40,8]

在我看来,轴1只影响对LSTMCell的调用数量。为什么它应该与10相同?我如何处理这个问题

同时,我还在网络中创建了一个变量来确定它是否正在训练,因为我只想调用LSTMCell一次,以获得输出结果,而它不是在训练。然而,由于上述问题,我似乎无法实现这个目标

所以请帮帮我

这是代码

class lstm_rnn(keras.Model):
    def __init__(self, units):

        super(lstm_rnn, self).__init__()
        
  
        self.state0 = [tf.zeros([batchsz, units]),tf.zeros([batchsz, units])]
        self.state1 = [tf.zeros([batchsz, units]),tf.zeros([batchsz, units])]
        
     
        self.lstm_cell0 = layers.LSTMCell(units, dropout = 0.5)
        self.lstm_cell1 = layers.LSTMCell(units, dropout = 0.5)
    
    def call(self, inputs):
        x = inputs
        real_out = 0
        
  
        axis_size = x.shape[1]
        is_training = True
        if(axis_size == 2):
            is_training = False
        
        print(inputs.shape)
        
       
        if(is_training):
            state0 = self.state0
            state1 = self.state1
            step_cnt = 0
            
            for word in tf.unstack(x, axis = 1):
                out0, state0 = self.lstm_cell0(word, state0)
                out1, state1 = self.lstm_cell1(out0, state1)
                if(step_cnt == 0):
                    real_out = tf.reshape(out1, shape = (20, 1 ,6))
                else:
                    real_out = tf.concat([real_out, tf.reshape(out1, shape = (20, 1, 6))], axis = 1)
                step_cnt = step_cnt + 1
      
        else:
            state0 = [inputs[0], inputs[0]]
            state1 = [inputs[0], inputs[0]]
            info = inputs[1]
            out0, state0 = self.lstm_cell0(info, state0)
            out1, state1 = self.lstm_cell1(out0, state1)
            real_out = out1
        
        return real_out

Tags: selftfzerosoutrealinputsunitsshape