Tensorflow pb文件使用--sess.run()无法停止

2024-09-28 16:56:09 发布

您现在位置:Python中文网/ 问答频道 /正文

我的英语不好,但我会尽力描述我的问题。我希望有人能帮助我。 我将基于tensorflow的对象检测模型转换为“.pb”文件,并尝试使用它。但是在我加载pb文件并运行sess.run(sth)之后,它无法停止,并且在调试/运行窗口中没有显示任何消息

操作系统:linux ubuntu 16.04 python:2.7 tensorflow:1.5.0 gpu版本 IDE:pycharm GPU:Tesla P100-PCIE-16GB 我尝试过改变几种不同的输入张量,但所有情况都是一样的。我运行了“nvidia smi”,发现它正常使用GPU。我还尝试了tfdebug工具,在我选择“运行”之后,它也无法停止,并且没有消息给我

#!/usr/bin/python
# -*-coding=utf-8-*-

import os
import cv2
import numpy as np
import tensorflow as tf
from tensorflow.python import debug as tf_debug

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
PIXEL_MEAN = [123.68, 116.779, 103.939]

def SaveWeightsFile():
    saver = tf.train.import_meta_graph("./output/trained_weights/RRPN_20180901_DOTA_v1/voc_98001model.ckpt.meta",clear_devices=True)
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()
    sess = tf.Session()
    saver.restore(sess,
        "./output/trained_weights/RRPN_20180901_DOTA_v1/voc_98001model.ckpt")
    for node in input_graph_def.node:
        # if "det" in node.name:
        #     print(node.name)
        # if "img" in node.name:
        #     print(node.name)
        if "resnet_v1_101" in node.name:
            print(node.name)
        print(node.name)
    output_node_names = "get_batch/ResizeBilinear,postprocess_fastrcnn/Gather_14,postprocess_fastrcnn/Gather_15,postprocess_fastrcnn/Gather_16"
    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        input_graph_def,
        output_node_names.split(",")
    )
    output_graph = "./output_graph/saved_weights.pb"
    with tf.gfile.GFile(output_graph, 'wb') as f:
        f.write(output_graph_def.SerializeToString())
    sess.close()
    print("Done.")


def LoadAndUse():
    frozen_graph = "./trained_weights/output_graph/saved_weights.pb"
    with tf.gfile.GFile(frozen_graph, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(
            graph_def,
            input_map=None,
            return_elements=None,
            name=""
        )
    input_tensor = graph.get_tensor_by_name("get_batch/Cast_3:0")
    det_boxes = graph.get_tensor_by_name("postprocess_fastrcnn/Gather_14:0")
    det_scores = graph.get_tensor_by_name("postprocess_fastrcnn/Gather_15:0")
    det_category =graph.get_tensor_by_name("postprocess_fastrcnn/Gather_16:0")
    img_test=cv2.imread("./images/12.png")
    sess = tf.Session(graph=graph)
    result = sess.run([det_boxes, det_scores, det_category],
             feed_dict={input_tensor: img_test})## can not stop!
    print("Done.")

if __name__ == '__main__':
    SaveWeightsFile()
    LoadAndUse()
    print("Done")

通常,它可以运行graph并给我“det_Box、det_scores、det_category”三个值,但“sess.run()”没有停止,也没有错误消息,它只是始终运行


Tags: nameimportnodeinputoutputgettfdef