如果尝试导入保存的TensorFlow图形定义
import tensorflow as tf
from tensorflow.python.platform import gfile
with gfile.FastGFile(FLAGS.model_save_dir.format(log_id) + '/graph.pb', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
x, y, y_ = tf.import_graph_def(graph_def,
return_elements=['data/inputs',
'output/network_activation',
'data/correct_outputs'],
name='')
返回的值不是预期的Tensor
s,而是其他一些值:例如,将x
as
Tensor("data/inputs:0", shape=(?, 784), dtype=float32)
我明白了
name: "data/inputs_1"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
}
}
}
也就是说,我得到的不是期望的张量x
,而是x.op
。这让我很困惑,因为documentation似乎说我应该得到一个Tensor
(尽管那里有一堆或很难理解)。
如何使tf.import_graph_def
返回特定的Tensor
s,然后我可以使用这些tf.import_graph_def
(例如,在提供加载的模型或运行分析时)?
名称
'data/inputs'
、'output/network_activation'
和'data/correct_outputs'
实际上是操作名称。若要使tf.import_graph_def()
返回tf.Tensor
对象,应将输出索引附加到操作名称,对于单个输出操作,该名称通常为':0'
:相关问题 更多 >
编程相关推荐