我的英语不好,但我会尽力描述我的问题。我希望有人能帮助我。 我将基于tensorflow的对象检测模型转换为“.pb”文件,并尝试使用它。但是在我加载pb文件并运行sess.run(sth)之后,它无法停止,并且在调试/运行窗口中没有显示任何消息
操作系统:linux ubuntu 16.04 python:2.7 tensorflow:1.5.0 gpu版本 IDE:pycharm GPU:Tesla P100-PCIE-16GB 我尝试过改变几种不同的输入张量,但所有情况都是一样的。我运行了“nvidia smi”,发现它正常使用GPU。我还尝试了tfdebug工具,在我选择“运行”之后,它也无法停止,并且没有消息给我
#!/usr/bin/python
# -*-coding=utf-8-*-
import os
import cv2
import numpy as np
import tensorflow as tf
from tensorflow.python import debug as tf_debug
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
PIXEL_MEAN = [123.68, 116.779, 103.939]
def SaveWeightsFile():
saver = tf.train.import_meta_graph("./output/trained_weights/RRPN_20180901_DOTA_v1/voc_98001model.ckpt.meta",clear_devices=True)
graph = tf.get_default_graph()
input_graph_def = graph.as_graph_def()
sess = tf.Session()
saver.restore(sess,
"./output/trained_weights/RRPN_20180901_DOTA_v1/voc_98001model.ckpt")
for node in input_graph_def.node:
# if "det" in node.name:
# print(node.name)
# if "img" in node.name:
# print(node.name)
if "resnet_v1_101" in node.name:
print(node.name)
print(node.name)
output_node_names = "get_batch/ResizeBilinear,postprocess_fastrcnn/Gather_14,postprocess_fastrcnn/Gather_15,postprocess_fastrcnn/Gather_16"
output_graph_def = tf.graph_util.convert_variables_to_constants(
sess,
input_graph_def,
output_node_names.split(",")
)
output_graph = "./output_graph/saved_weights.pb"
with tf.gfile.GFile(output_graph, 'wb') as f:
f.write(output_graph_def.SerializeToString())
sess.close()
print("Done.")
def LoadAndUse():
frozen_graph = "./trained_weights/output_graph/saved_weights.pb"
with tf.gfile.GFile(frozen_graph, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as graph:
tf.import_graph_def(
graph_def,
input_map=None,
return_elements=None,
name=""
)
input_tensor = graph.get_tensor_by_name("get_batch/Cast_3:0")
det_boxes = graph.get_tensor_by_name("postprocess_fastrcnn/Gather_14:0")
det_scores = graph.get_tensor_by_name("postprocess_fastrcnn/Gather_15:0")
det_category =graph.get_tensor_by_name("postprocess_fastrcnn/Gather_16:0")
img_test=cv2.imread("./images/12.png")
sess = tf.Session(graph=graph)
result = sess.run([det_boxes, det_scores, det_category],
feed_dict={input_tensor: img_test})## can not stop!
print("Done.")
if __name__ == '__main__':
SaveWeightsFile()
LoadAndUse()
print("Done")
通常,它可以运行graph并给我“det_Box、det_scores、det_category”三个值,但“sess.run()”没有停止,也没有错误消息,它只是始终运行
目前没有回答
相关问题 更多 >
编程相关推荐