如何使用TensorFlow中的估计器将模型存储在“.pb”文件中?

2024-06-30 15:43:19 发布

您现在位置:Python中文网/ 问答频道 /正文

我用张量流的估计器训练我的模型。似乎应该使用export_savedmodel来生成.pb文件,但我真的不知道如何构造serving_input_receiver_fn。有人有什么想法吗? 欢迎使用示例代码。在

额外问题:

  1. 当我想重新加载模型时,.pb是唯一需要的文件吗?Variable没有必要吗?

  2. 与使用adam优化器的.pb相比,.pb将减少多少模型文件大小?


Tags: 文件代码模型示例inputexportvariablefn
1条回答
网友
1楼 · 发布于 2024-06-30 15:43:19

您可以使用^{}.ckpt+.pbtxt生成{} 如果您使用的是^{},那么您将在^{}中找到这两个文件

python freeze_graph.py \
     input_graph=graph.pbtxt \
     input_checkpoint=model.ckpt-308 \
     output_graph=output_graph.pb
     output_node_names=<output_node>
  1. Is .pb the only file I need when I want to reload the model? Variable unnecessary?

是的,您必须知道您是模型的输入节点和输出节点名称。然后使用^{}加载.pb文件并使用^{}获得输入和输出操作

  1. How much will .pb reduced the model file size compared with .ckpt with adam optimizer?

pb文件不是压缩的.ckpt文件,因此没有“压缩率”。在

但是,有一种推理方法to optimize your .pb file,这种优化可以减少文件大小,因为它删除了图形中只用于训练的部分操作(请参见完整描述here)。在

[comment] how can I get the input and output node names?

可以使用opname参数设置输入和输出节点名称。在

要列出.pbtxt文件中的节点名,请使用以下脚本。在

^{pr2}$

[comment] I found that there is a tf.estimator.Estimator.export_savedmodel(), is that the function to store model in .pb directly? And I'm struggling in it's parameter serving_input_receiver_fn. Any ideas?

export_savedmodel()生成一个SavedModel,这是一种用于TensorFlow模型的通用序列化格式。它应该包含与TensorFlow Serving APIs相匹配的所有内容

serving_input_receiver_fn()是生成SavedModel所需内容的一部分,它通过向图中添加占位符来确定模型的输入签名。在

从医生那里

This function has the following purposes:

  • To add placeholders to the graph that the serving system will feed with inference requests.
  • To add any additional ops needed to convert data from the input format into the feature Tensors expected by the model.

如果您接收的是序列化的tf.Examples(这是一种典型模式)形式的推理请求,那么您可以使用doc中提供的示例。在

feature_spec = {'foo': tf.FixedLenFeature(...),
                'bar': tf.VarLenFeature(...)}

def serving_input_receiver_fn():
  """An input receiver that expects a serialized tf.Example."""
  serialized_tf_example = tf.placeholder(dtype=tf.string,
                                         shape=[default_batch_size],
                                         name='input_example_tensor')
  receiver_tensors = {'examples': serialized_tf_example}
  features = tf.parse_example(serialized_tf_example, feature_spec)
  return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)

[comment] Any idea to list the node names in '.pb'?

这取决于它是如何产生的。在

如果是^{}用法:

import tensorflow as tf

with tf.Session() as sess:
    meta_graph_def = tf.saved_model.loader.load(
        sess,
        [tf.saved_model.tag_constants.SERVING],
        './saved_models/1519232535')
    print [n.name for n in meta_graph_def.graph_def.node]

如果是^{},则使用:

import tensorflow as tf
from tensorflow.python.platform import gfile

with tf.Session() as sess:
    with gfile.FastGFile('model.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        sess.graph.as_default()
        tf.import_graph_def(graph_def, name='')
        print [n.name for n in graph_def.node]

相关问题 更多 >