在每个时间步使用LSTM的单元状态“可视化[和理解]循环网络”

2024-09-28 03:24:49 发布

您现在位置:Python中文网/ 问答频道 /正文

我目前正试图以编程的方式将深度学习方法学习/训练我的问题的方式形象化,例如,偶然发现了Andrej Karpathy、Justin Johnson和Li Fei(例如,https://arxiv.org/abs/1506.02078)撰写的论文“形象化和理解循环网络”

在图2中,他们计算了预测过程中每个时间段LSTM的内存/单元状态(c)的tanh。 我现在想对Keras中的LSTM做同样的事情(或者说PyTorch,但是我对它们的API还很陌生)。API允许我使用参数return_sequences=True获取每个时间步的隐藏状态(h),但仅使用参数return_state=True获取最后一个单元格状态

根据我对Keras API的理解,我目前创建了两个模型:

inputs = Input(shape=(MAX_LEN,))

embd = Embedding(input_dim=MAX_FEATURES, output_dim=128, input_length=MAX_LEN)(inputs)
lstm = LSTM(units=128, return_sequences=True)(embd)

# 'ignore' return_sequences=True in training model
flt = Flatten()(lstm)

drp = Dropout(0.5)(flt)
dns = Dense(N_CLASSES)(drp)
activ = Activation('softmax')(dns)

training_model = Model([inputs], [activ])
state_model = Model([inputs], [activ, lstm])

training_model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
training_model.fit(X_train, y_train)

predictions, hidden_states = state_model.predict(X_test)

如何在每个时间步(在我的例子中是字符)对单元格状态执行相同的操作-最好不必重写/重新实现Keras/PyTorch源代码的“主要”部分? 我发现了这个answer,但还没有让它发挥作用,而且由于我没有足够的声誉,我无法在评论中添加任何问题(我看到了另一个答案,但我想防止“完全”重新实现Keras'LSTMCell)

编辑#1:我对PyTorch的API还很陌生,我该怎么做


Tags: apitruemodelreturn状态trainingpytorchmax

热门问题