如何在tensorflow图中查找操作名

2024-07-08 08:07:13 发布

您现在位置:Python中文网/ 问答频道 /正文

我在Python中使用Keras训练了一个模型,我想在Java程序中使用这个训练过的模型。我最初打算在Java中直接使用Keras模型,但是Keras2.0似乎还没有得到很好的支持。因此,我将Keras模型(存储在.h5中)转换为tensorflow模型(存储在.pb中)。现在我想在Java代码中使用这个模型。但是,我需要3个字符串才能成功执行此操作:

  1. “标识要加载的特定metagraphdef的标记”
  2. 将数据送入网络的操作
  3. 获取网络结果的操作

我几乎不知道怎么找到这些弦。在这一点上,我不能修改我的模型太多,特别是因为TensorFlow2.0删除了get_session(),这意味着我需要使用TensorFlow1.0,它在从Keras2.0加载模型时不断给我带来错误。我能够列出我的模型的所有操作,但我不知道近100个操作中哪一个是正确的。我也不知道metagraphdef的标签

我如何找到这3条信息


Tags: 数据字符串代码标记模型程序网络get
1条回答
网友
1楼 · 发布于 2024-07-08 08:07:13

如果使用pip(或类似conda等)安装TensorFlow,那么它应该附带saved_model_cli实用程序

您可以使用它从导出的模型中获得一些见解:

saved_model_cli show  dir <model_dir>  tag_set <tag>  signature_def <signature>

guide中查找更多信息

这是我的一个模型的结果:

The given SavedModel SignatureDef contains the following input(s):
  inputs['float32_Input'] tensor_info:
      dtype: DT_FLOAT
      shape: (-1, 118)
      name: serving_default_float32_Input:0
  inputs['uint8_Input'] tensor_info:
      dtype: DT_UINT8
      shape: (-1, 583)
      name: serving_default_uint8_Input:0
The given SavedModel SignatureDef contains the following output(s):
  outputs['tf_op_layer_ExpandDims'] tensor_info:
      dtype: DT_FLOAT
      shape: (-1, 1)
      name: PartitionedCall:0
Method name is: tensorflow/serving/predict

float32_Inputuint8_InputExpandDims是Python中我的层的名称。要在Java中使用它,我必须使用以下名称:serving_default_float32_Inputserving_default_float32_InputPartitionedCall

相关问题 更多 >

    热门问题