重用嵌入变量进行推理Tf.估计器美国石油学会

2024-10-16 20:49:58 发布

您现在位置:Python中文网/ 问答频道 /正文

在使用seq2seq体系结构的NMT中,在推理过程中,我们需要在训练阶段训练的嵌入变量作为GreedyEmbeddingHelper或BeamSearchDecoder的输入。在

问题是,在使用估计器API进行训练和推断的上下文中,如何提取经过训练的嵌入变量用于预测?在


Tags: api过程体系结构阶段经过训练seq2seqnmtbeamsearchdecoder
1条回答
网友
1楼 · 发布于 2024-10-16 20:49:58

我根据下面的stackoverflowanswer找到了一个解决方案。对于预测阶段,可以使用tf.contrib.框架.load_variable从训练和保存的Tensorflow模型中检索嵌入变量,如下所示:

if mode == tf.estimator.ModeKeys.PREDICT:
    embeddings = tf.constant(tf.contrib.framework.load_variable('.','embed/embeddings'))
    helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(embedding=embeddings,
    start_tokens=tf.fill([batch_size], 1),end_token=0)

所以在我的例子中,我运行的代码来自包含保存模型的同一个文件夹,我的变量名是“embed/embedding”。通过这个模型训练的Tensordings只适用于这个模型。否则,请参考上面链接的答案。在

要使用estimator API查找变量名,可以使用get_variable_names()方法获取保存在图中的所有变量名的列表。在

相关问题 更多 >