为定义输入和输出张量tf.lite.TocoCon公司

2024-10-01 22:25:48 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图将这个CPM-TF模型转换为TFLite,但是要使用TocoConverter,我需要指定输入和输出张量。 https://github.com/timctho/convolutional-pose-machines-tensorflow

我运行了包含的run_freeze_model.py,得到了cpm_hand_frozen.pb(GraphDef?)文件。在

从这篇文章中,我复制了转换ProtoBuf文件和已知输入和输出的代码片段。但是,通过查看模型定义代码,我在找到输入和输出的正确答案时遇到了一些困难。 Tensorflow Convert pb file to TFLITE using python

import tensorflow as tf
import numpy as np
from config import FLAGS

path_to_frozen_graphdef_pb = 'frozen_models/cpm_hand_frozen.pb'


def main(argv):
    input_tensors = [1, FLAGS.input_size, FLAGS.input_size, 3]
    output_tensors = np.zeros(FLAGS.num_of_joints)
    frozen_graph_def = tf.GraphDef()

    with open(path_to_frozen_graphdef_pb, 'rb') as f:
        frozen_graph_def.ParseFromString(f.read())

    tflite_model = tf.contrib.lite.toco_convert(frozen_graph_def, input_tensors, output_tensors)

if __name__ == '__main__':
    tf.app.run()

我对Tensorflow很陌生,但我认为输入应该定义为
[1, FLAGS.input_size, FLAGS.input_size, 3] 在这里发现:https://github.com/timctho/convolutional-pose-machines-tensorflow/blob/master/models/nets/cpm_hand.py#L23

不确定1代表什么,但没有一个不起作用,我想其他参数是图像大小和颜色通道。在

但是,对于该输入,它将返回一个错误: AttributeError: 'int' object has no attribute 'dtype'

我不知道输出应该是什么,除了它应该是一个数组。在


更新1

查看TF文档,我似乎需要将输入定义为张量(很明显!)。 https://www.tensorflow.org/lite/convert/python_api

input_tensors = tf.placeholder(name="img", dtype=tf.float32, shape=(1,FLAGS.input_size, FLAGS.input_size, 3))

这不会返回错误,但我仍然需要确定输入是否正确以及输出应该是什么样子。在


更新2 好吧,我终于把代码片段吐出来了 ^{pr2}$

我希望我做得对。我将尝试对Android上的模型进行推理。 我也认为在输入张量input_placeholder中存在拼写错误。它似乎在代码本身中得到了更正,但是从打印出预先训练的模型中的所有节点名之后,拼写input_placeholer就出现了。在

节点名可以在这里看到:https://github.com/timctho/convolutional-pose-machines-tensorflow/issues/59


我的设置是:

Ubuntu 18.04
CUDA 9.1和cuDNN 7.0
Python 3.6.5
Tensorflow GPU 1.6

推理就像一个魅力,所以设置本身应该没有问题。在


Tags: 代码https模型githubcominputsizetf

热门问题