理解tflite解释器输出并将其映射到类的概率

2024-03-29 12:58:52 发布

您现在位置:Python中文网/ 问答频道 /正文

我已经使用Google预先训练过的ssd模型训练过ssd模型(用于对象检测),并将其转换为tflite。我对它进行了10节课的培训,并将其转换为tflite。 下面是我用来调用转换后的tflite模型来检查结果的代码

import tensorflow as tf
MODEL_PATH = 'tflite_model_path'
IMAGE_PATH = 'image of .jpeg or .png format'

interpreter = tf.lite.Interpreter(model_path=MODEL_PATH)
interpreter.allocate_tensors()
img = cv2.imread(IMAGE_PATH)  
image_np = np.array(img)
resized_image = tf.image.resize(image_np, [320, 320])
input_data = tf.convert_to_tensor(np.expand_dims(resized_image, 0), dtype=tf.uint8)
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])

print(output_data)
[[0.05507272 0.6531384  0.94264597 1.0628431 ]
[0.09443301 0.23600875 0.86310023 0.59367293]
[0.28634408 0.27470273 0.8326082  0.4465307 ]
[0.04376397 0.27534395 0.92427313 0.9200937 ]
[0.00423892 0.824869   0.09153695 0.9980754 ]
[0.63915586 0.6903409  0.9311851  0.97774005]
[0.11331517 0.25821632 0.67732155 0.4566245 ]
[0.4935118  0.27333832 0.82703865 0.4118209 ]
[0.04359788 0.68944013 0.39454672 1.0057622 ]
[0.3145248  0.13302818 1.0334518  0.9483361 ]]]

现在,当使用tflite解释器调用我的tflite模型时,有几件事我不清楚:

  1. 什么output_details正在返回
  2. output_details的形状是tf.Tensor([ 1 10 4], shape=(3,), dtype=int32)。这个形状和这些数字代表什么
  3. 如何将此输出转换为对应于每个类的prob

Tags: path模型imageinputoutputdatamodeltf