量化后的Tensorflow导入图

2024-06-26 14:24:45 发布

您现在位置:Python中文网/ 问答频道 /正文

在 我尝试使用TransformGraph为一个定制的LSTM模型生成一个八位量化图。如果我只量化权重,图形导入工作正常。应用量化节点后,导入失败,错误如下所示

ValueError: Specified colocation to an op that does not exist during import: lstm1/lstm1/BasicLSTMCellZeroState/zeros in lstm1/lstm1/cond/Switch_2

在 下面列出了用于量化的代码片段

from tensorflow.tools.graph_transforms import TransformGraph
import tensorflow as tf

input_names = ["inp/X"]
output_names = ["out/Softmax"]
#transforms = ["quantize_weights", "quantize_nodes"]
#transforms = ["quantize_weights"]
transforms = ["add_default_attributes",
"strip_unused_nodes",
"remove_nodes(op=Identity, op=CheckNumerics)",
#"fold_constants(ignore_errors=true)",
"fold_batch_norms",
"fold_old_batch_norms",
"quantize_weights",
"quantize_nodes",
"sort_by_execution_order"]
#output_graph_path="/tmp/fixed.pb"
output_graph_path="/tmp/output_graph.pb"
with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()
        with tf.Session() as sess:
            with open(output_graph_path, "rb") as f:

                output_graph_def.ParseFromString(f.read())
                _ = tf.import_graph_def(output_graph_def, name="")

                transformed_graph_def = TransformGraph(output_graph_def, input_names,
                                       output_names, transforms)

                tf.train.write_graph(transformed_graph_def, '/tmp', 'quantized.pb', as_text=False)

在 我也试过用量子化_图形.py,它总是导致键错误,如https://github.com/tensorflow/tensorflow/issues/8025。我相信这个代码已经不再被维护了。有谁能指出如何调试这个问题吗。在


Tags: importoutputnamestftensorflowdefasgraph