ValueError:输入0与层批处理规范化\u 1不兼容:预期的ndim=3,找到的ndim=2

2024-09-30 14:22:52 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试使用DeepTriage的实现,这是一种用于bug筛选的深入学习方法。This website包括数据集、源代码和纸张。我知道这是一个非常具体的领域,但我会尽量简化。在

the source code中,他们定义了他们的方法“DBRNN-A:具有注意机制和长短期记忆单元(LSTM)的深度双向递归神经网络”,代码部分如下:

input = Input(shape=(max_sentence_len,), dtype='int32')
sequence_embed = Embedding(vocab_size, embed_size_word2vec, input_length=max_sentence_len)(input)

forwards_1 = LSTM(1024, return_sequences=True, dropout_U=0.2)(sequence_embed)
attention_1 = SoftAttentionConcat()(forwards_1)
after_dp_forward_5 = BatchNormalization()(attention_1)

backwards_1 = LSTM(1024, return_sequences=True, dropout_U=0.2, go_backwards=True)(sequence_embed)
attention_2 = SoftAttentionConcat()(backwards_1)
after_dp_backward_5 = BatchNormalization()(attention_2)

merged = merge([after_dp_forward_5, after_dp_backward_5], mode='concat', concat_axis=-1)
after_merge = Dense(1000, activation='relu')(merged)
after_dp = Dropout(0.4)(after_merge)
output = Dense(len(train_label), activation='softmax')(after_dp)                
model = Model(input=input, output=output)
model.compile(loss='categorical_crossentropy', optimizer=Adam(lr=1e-4), metrics=['accuracy']) 

SoftAttentionConcat实现来自here。其余函数来自keras。此外,在the paper中,它们共享的结构如下:

DBRNN-A

在第一个批处理规范化行中,它抛出以下错误:

^{pr2}$

当我使用max_sentence_len=50max_sentence_len=200时,我观察尺寸直到误差点,我看到以下形状:

Input               -> (None, 50)
Embedding           -> (None, 50, 200)
LSTM                -> (None, None, 1024)
SoftAttentionConcat -> (None, 2048) 

那么,有人看到问题了吗?在


Tags: nonetrueinputlenembedmergesentencemax
1条回答
网友
1楼 · 发布于 2024-09-30 14:22:52

我想问题是在Keras结构中使用tensorflow代码,或者是一些版本问题。在

通过使用问题和答案here,我在Keras中实现了注意机制,如下所示:

attention_1 = Dense(1, activation="tanh")(forwards_1)
attention_1 = Flatten()(attention_1)  # squeeze (None,50,1)->(None,50)
attention_1 = Activation("softmax")(attention_1)
attention_1 = RepeatVector(num_rnn_unit)(attention_1)
attention_1 = Permute([2, 1])(attention_1)
attention_1 = multiply([forwards_1, attention_1])
attention_1 = Lambda(lambda xin: K.sum(xin, axis=1), output_shape=(num_rnn_unit,))(attention_1)

last_out_1 = Lambda(lambda xin: xin[:, -1, :])(forwards_1)
sent_representation_1 = concatenate([last_out_1, attention_1])

这很有效。我用于实现的所有源代码都在GitHub中。在

相关问题 更多 >