在TensorFlow Fold中实现TreeLSTM的NaryTreeLSTM版本

2024-06-25 06:57:43 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图通过使用TensorFlow Fold实现本文TreeLSTM。实际上,在Tensorflow Fold中,已经有treelsm的一个例子,但是在BinaryTreeLSTM版本中,这里有一个教程:https://github.com/tensorflow/fold/blob/master/tensorflow_fold/g3doc/sentiment.ipynb

我现在要做的是实现一个真正的narytreestm,这意味着LSTM节点可以是任意数量子节点的父节点,而不是像上面的教程中那样只有2个。在

这是我试图折叠树的尝试,这是上面示例中logits_and_state()的修改版本。”

 def logits_and_state():
  """Creates a block that goes from tokens to (logits, state) tuples."""
  word2vec = (td.GetItem(0) >> td.InputTransform(lookup_word) >>
              td.Scalar('int32') >> word_embedding)

  children_num = 
  children2vec_list = list()
  children2vec_list.append(embed_subtree())
  for i in range(children_num):
    children2vec_list.append(embed_subtree())

  children2vec = tuple(children2vec_list)

  # Trees are binary, so the tree layer takes two states as its input_state.

  zero_state = td.Zeros((tree_lstm.state_size,) * 2)
  # Input is a word vector.
  zero_inp = td.Zeros(word_embedding.output_type.shape[0])

  # word_case = 
  word_case = td.AllOf(word2vec, zero_state)
  children_case = td.AllOf(zero_inp, children2vec)

  tree2vec = td.OneOf(lambda x: 1 if len(x) == 1 else 2), [(1,word_case),(2,children_case)])
  return tree2vec >> tree_lstm >> (output_layer, td.Identity())

children_num是我目前正在努力解决的问题,我不知道要得到那个数字,尽管我知道可以通过td.GetItem(1)==>得到子元素的长度,但它将生成一个包含children数组的块==>;如何求出该块的实数?在

您可能会说,我应该尝试PyTorch或其他一些DL框架,它们也提供了动态计算图,但是在我的例子中,需求对Tensorflow Fold是严格的。在


Tags: tree节点tensorflowfoldnumlist例子word