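I built my own network model and trained it on data of shape [batchsz, 10, 8].
Afterwards I wanted to run net.predict(x) on a tensor of shape [1, 40, 8]
('x' is the name of the tensor), but I got this error: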
Input shape axis 1 must equal 10, got shape [1,40,8]
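As I understand it, axis 1 only determines how many times LSTMCell is called. Why does it have to be 10? How can I deal with this?
I also created a variable in the network to tell whether it is training, because when it is not training I only want to call LSTMCell once to get the output. Because of the problem above, though, I can't seem to make that work.
So please help me.
Here is the code: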
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# batchsz is a global defined elsewhere in the script (not shown in the post)
class lstm_rnn(keras.Model):
    def __init__(self, units):
        super(lstm_rnn, self).__init__()
        # initial [h, c] states, created with a fixed batch size
        self.state0 = [tf.zeros([batchsz, units]), tf.zeros([batchsz, units])]
        self.state1 = [tf.zeros([batchsz, units]), tf.zeros([batchsz, units])]
        self.lstm_cell0 = layers.LSTMCell(units, dropout = 0.5)
        self.lstm_cell1 = layers.LSTMCell(units, dropout = 0.5)

    def call(self, inputs):
        x = inputs
        real_out = 0
        axis_size = x.shape[1]
        # an input whose axis 1 has size 2 is treated as the inference case
        is_training = True
        if(axis_size == 2):
            is_training = False
        print(inputs.shape)
        if(is_training):
            state0 = self.state0
            state1 = self.state1
            step_cnt = 0
            # one LSTMCell call per time step along axis 1
            for word in tf.unstack(x, axis = 1):
                out0, state0 = self.lstm_cell0(word, state0)
                out1, state1 = self.lstm_cell1(out0, state1)
                if(step_cnt == 0):
                    real_out = tf.reshape(out1, shape = (20, 1 ,6))
                else:
                    real_out = tf.concat([real_out, tf.reshape(out1, shape = (20, 1, 6))], axis = 1)
                step_cnt = step_cnt + 1
        else:
            # single step: inputs[0] is used as the state, inputs[1] as the input
            state0 = [inputs[0], inputs[0]]
            state1 = [inputs[0], inputs[0]]
            info = inputs[1]
            out0, state0 = self.lstm_cell0(info, state0)
            out1, state1 = self.lstm_cell1(out0, state1)
            real_out = out1
        return real_out
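
One possible direction, offered as a sketch rather than a confirmed fix: the code above ties several things to the training shapes. The initial states are created with a fixed batchsz, the reshape hard-codes (20, 1, 6), and tf.unstack derives its number of steps from the static size of axis 1. The version below assumes the goal is to reuse the same two LSTMCell weights for any batch size and any sequence length, and it swaps the manual tf.unstack loop for keras.layers.RNN, which iterates over axis 1 and creates zero initial states itself. The names StackedLSTM and step are hypothetical, and units=6 is only inferred from the reshape in the question.

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


class StackedLSTM(keras.Model):
    """Two stacked LSTM cells; accepts any batch size and sequence length."""

    def __init__(self, units):
        super().__init__()
        self.cell0 = layers.LSTMCell(units, dropout=0.5)
        self.cell1 = layers.LSTMCell(units, dropout=0.5)
        # layers.RNN runs the stacked cells over axis 1 and builds zero
        # initial states that match the incoming batch size.
        self.rnn = layers.RNN([self.cell0, self.cell1], return_sequences=True)

    def call(self, inputs, training=False):
        # inputs: [batch, time, features] -> outputs: [batch, time, units]
        return self.rnn(inputs, training=training)

    def step(self, word, state0, state1):
        # Hypothetical helper: a single time step with explicit [h, c] states.
        out0, state0 = self.cell0(word, state0, training=False)
        out1, state1 = self.cell1(out0, state1, training=False)
        return out1, state0, state1


units = 6  # assumption, taken from the (20, 1, 6) reshape in the question
model = StackedLSTM(units)

# Same weights, different sequence lengths and batch sizes.
out_train = model(tf.random.normal([4, 10, 8]))   # shape (4, 10, 6)
long_seq = tf.random.normal([1, 40, 8])
out_long = model(long_seq)                        # shape (1, 40, 6)

# Calling the cells exactly once, with zero states for a batch of 1.
state0 = [tf.zeros([1, units]), tf.zeros([1, units])]
state1 = [tf.zeros([1, units]), tf.zeros([1, units])]
out_step, state0, state1 = model.step(long_seq[:, 0, :], state0, state1)  # shape (1, 6)
```

Keeping the single-step path in its own method avoids guessing the mode from the input shape, as the axis_size == 2 check does; the caller decides explicitly whether to run a whole sequence or one step.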