如何使TensorFlow的“import_graph_def”返回张量

2024-10-02 22:37:28 发布

您现在位置:Python中文网/ 问答频道 /正文

如果尝试导入保存的TensorFlow图形定义

import tensorflow as tf
from tensorflow.python.platform import gfile

with gfile.FastGFile(FLAGS.model_save_dir.format(log_id) + '/graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
x, y, y_ = tf.import_graph_def(graph_def, 
                               return_elements=['data/inputs',
                                                'output/network_activation',
                                                'data/correct_outputs'],
                               name='')

返回的值不是预期的Tensors,而是其他一些值:例如,将xas

Tensor("data/inputs:0", shape=(?, 784), dtype=float32)

我明白了

name: "data/inputs_1"
op: "Placeholder"
attr {
  key: "dtype"
  value {
    type: DT_FLOAT
  }
}
attr {
  key: "shape"
  value {
    shape {
    }
  }
}

也就是说,我得到的不是期望的张量x,而是x.op。这让我很困惑,因为documentation似乎说我应该得到一个Tensor(尽管那里有一堆很难理解)。

如何使tf.import_graph_def返回特定的Tensors,然后我可以使用这些tf.import_graph_def(例如,在提供加载的模型或运行分析时)?


Tags: nameimportdatatftensorflowdefasgraph
1条回答
网友
1楼 · 发布于 2024-10-02 22:37:28

名称'data/inputs''output/network_activation''data/correct_outputs'实际上是操作名称。若要使tf.import_graph_def()返回tf.Tensor对象,应将输出索引附加到操作名称,对于单个输出操作,该名称通常为':0'

x, y, y_ = tf.import_graph_def(graph_def, 
                               return_elements=['data/inputs:0',
                                                'output/network_activation:0',
                                                'data/correct_outputs:0'],
                               name='')

相关问题 更多 >