Getting the last state from a BiLSTM (BiGRU) in PyTorch

Published 2024-09-26 22:54:01


After reading several articles, I am still confused about whether my implementation for getting the last hidden state from a BiLSTM is correct.

  1. Understanding Bidirectional RNN in PyTorch (TowardsDataScience)
  2. PackedSequence for seq2seq model (PyTorch forums)
  3. What's the difference between “hidden” and “output” in PyTorch LSTM? (StackOverflow)
  4. Select tensor in a batch of sequences (PyTorch forums)

The approach from the last source (4) seems the cleanest to me, but I am still not sure I understood the thread correctly. Am I using the right final hidden states from the forward and backward LSTMs? This is my implementation:

# Assumes at module level:
#   import torch
#   from torch.nn.utils.rnn import pack_sequence
# pos contains indices of words in the embedding matrix.
# seqlengths contains the sequence lengths, so for instance,
# if batch_size is 2, pos=[4,6,9,3,1] and seqlengths=[3,2],
# the batch holds the variable-length samples [4,6,9] and [3,1].

all_in_embs = self.in_embeddings(pos)
# pack_sequence expects sequences sorted by decreasing length
# (or pass enforce_sorted=False on recent PyTorch versions)
in_emb_seqs = pack_sequence(torch.split(all_in_embs, seqlengths, dim=0))
output, lasthidden = self.rnn(in_emb_seqs)
if not self.data_processor.use_gru:
    # LSTM returns (h_n, c_n); GRU returns h_n only
    lasthidden = lasthidden[0]
# u_emb_batch has shape batch_size x embedding_dimension:
# sum of the last states from the forward and backward directions
u_emb_batch = lasthidden[-1, :, :] + lasthidden[-2, :, :]

Is this correct?
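For context, here is a minimal, self-contained sketch (separate from my model, with made-up sizes) that I used to convince myself how `h_n` is laid out for a single-layer bidirectional LSTM: `h_n[-2]` should match the forward half of the output at each sequence's last valid timestep, and `h_n[-1]` should match the backward half at timestep 0.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence

torch.manual_seed(0)
# Hypothetical sizes: input_size=4, hidden_size=3, single layer, bidirectional
rnn = nn.LSTM(input_size=4, hidden_size=3, bidirectional=True)

# Two variable-length sequences (lengths 3 and 2), as in the question,
# already sorted by decreasing length as pack_sequence requires.
seqs = [torch.randn(3, 4), torch.randn(2, 4)]
packed = pack_sequence(seqs)

output, (h_n, c_n) = rnn(packed)
# h_n shape: (num_layers * num_directions, batch, hidden) = (2, 2, 3)
assert h_n.shape == (2, 2, 3)

# Unpack the per-timestep outputs: (max_len, batch, 2 * hidden)
padded, lengths = pad_packed_sequence(output)

for b, L in enumerate(lengths.tolist()):
    fwd_last = padded[L - 1, b, :3]   # forward half at the last valid step
    bwd_first = padded[0, b, 3:]      # backward half at the first step
    assert torch.allclose(h_n[-2, b], fwd_last, atol=1e-6)
    assert torch.allclose(h_n[-1, b], bwd_first, atol=1e-6)

# The summed representation, as in my implementation above
u_emb_batch = h_n[-1] + h_n[-2]
print(u_emb_batch.shape)  # torch.Size([2, 3])
```

If the assertions hold, then for a single-layer model `h_n[-2]` and `h_n[-1]` are indeed the forward and backward final states; with more layers, these indices pick out the last layer only.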


