Tensorflow:导入预训练模型(mobilenet、.pb、.ckpt)

2024-10-02 22:29:55 发布

您现在位置:Python中文网/ 问答频道 /正文

我一直在研究在tensorflow中导入一个预训练模型的检查点。 这样做的目的是为了检查它的结构,并将其用于图像 分类。在

具体来说,mobilenet模型found here。我找不到 从各种*.ckpt.*文件导入模型的合理方法,并使用 在一些论坛上我发现了一个Github用户StanislawAntol写的要点 目的是将上述文件转换成一个冻结的模型,ProtoBuf(.pb)文件。这个 要点是here

运行这个脚本会给我一堆.pb文件,我希望我可以使用这些文件 有。事实上,{a3}似乎回答了我的祈祷。在

我一直在尝试以下代码的变体,但没有用。任何物体 由tf.import_graph_def返回的似乎是None类型。在

import tensorflow as tf
from tensorflow.python.platform import gfile

model_filename = LOCATION_OF_PB_FILE

with gfile.FastGFile(model_filename,'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
g_in = tf.import_graph_def(graph_def, name='')

print(g_in)

我有什么遗漏吗?到.pb的整个转换是否错误?在


Tags: 文件模型import目的modelheretftensorflow
1条回答
网友
1楼 · 发布于 2024-10-02 22:29:55

tf.import_graph_def不返回图形,它填充范围中的“默认图形”。有关返回值的详细信息,请参见documentation for ^{}。在

在您的例子中,您可以使用tf.get_default_graph()检查图形。例如:

with gfile.FastGFile(model_filename, 'rb') as f:
  graph_def = tf.GraphDef()
  graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')

g = tf.get_default_graph()
print(len(g.get_operations()))

请参阅documentation for ^{}以获取有关“默认图”概念和范围界定的更多详细信息。在

希望有帮助。在

相关问题 更多 >