Class weights for balancing data in the TensorFlow Object Detection API

Posted 2024-06-18 17:18:00


I am fine-tuning an SSD object detector on the Open Images dataset using the TensorFlow Object Detection API. My training data contains imbalanced classes, e.g.

  1. top (5K images)
  2. dress (50K images)
  3. etc...

I would like to add class weights to the classification loss to improve performance. How can I do that? The following section of the config file seems relevant:

loss {
  classification_loss {
    weighted_sigmoid {
    }
  }
  localization_loss {
    weighted_smooth_l1 {
    }
  }
 ...
  classification_weight: 1.0
  localization_weight: 1.0
}

How can I change the config file to add a per-class weight to the classification loss? If not through the config file, what is the recommended way of doing it?


Tags: object, tensorflow, config file, classification, detector, weights
2 Answers

The Object Detection API losses are defined in: https://github.com/tensorflow/models/blob/master/research/object_detection/core/losses.py

In particular, the following loss classes are implemented:

Classification losses:

  1. WeightedSigmoidClassificationLoss
  2. SigmoidFocalClassificationLoss
  3. WeightedSoftmaxClassificationLoss
  4. WeightedSoftmaxClassificationAgainstLogitsLoss
  5. BootstrappedSigmoidClassificationLoss

Localization losses:

  1. WeightedL2LocalizationLoss
  2. WeightedSmoothL1LocalizationLoss
  3. WeightedIOULocalizationLoss

These loss classes take a weights parameter, of shape [batch_size, num_anchors], which is used to balance anchors (prior boxes), in addition to hard negative mining. Alternatively, focal loss down-weights well-classified examples and focuses training on the hard ones.
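For example, switching to the focal variant only requires changing the classification loss in the pipeline config. A minimal sketch; the alpha/gamma values below are common illustrative defaults, not tuned settings:

loss {
  classification_loss {
    weighted_sigmoid_focal {
      alpha: 0.25
      gamma: 2.0
    }
  }
  localization_loss {
    weighted_smooth_l1 {
    }
  }
  classification_weight: 1.0
  localization_weight: 1.0
}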

The main class imbalance in detection comes from having many more negative examples (bounding boxes without the objects of interest) than positive ones (bounding boxes with an object class). That seems to be why class imbalance among the positive examples (i.e., an uneven distribution of positive class labels) is not implemented as part of the object detection losses.

The API expects the weight of each object (bbox) to be provided directly in the annotation files. Given this requirement, the options for using class weights seem to be:

1) If you have a custom dataset, you can modify the annotation of each object (bbox) to include a weight field as 'object/weight' (see the small sketch after this list).

2) If you don't want to modify the annotations, you can just recreate the tf_records file so that it includes the weights of the bboxes.

3) Modify the code of the API (which looks rather tricky to me).
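For option 1, a minimal sketch: assuming you add a custom <weight> element to each <object> in your Pascal-VOC-style XML annotations (this element is not part of the standard format), the two lines below could replace the class-based weight lookup in the create_example function shown further down:

# Hypothetical: read a per-object <weight> element from the annotation,
# falling back to 1.0 when the element is absent.
weight_node = member.find('weight')
weights.append(float(weight_node.text) if weight_node is not None else 1.0)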

I decided to go with option #2, so here is the code to generate such a weighted tf_records file for a custom dataset with two classes ('top', 'dress') and weights (1.0, 0.1), given a folder of XML annotations:

import io
import hashlib
import random
import xml.etree.ElementTree as ET
import tensorflow as tf
from PIL import Image
from object_detection.utils import dataset_util

# Define the class names and their corresponding weights
class_names = ['top', 'dress', ...]
class_weights = [1.0, 0.1, ...]

def create_example(xml_file):

        tree = ET.parse(xml_file)
        root = tree.getroot()
        image_name = root.find('filename').text
        image_path = root.find('path').text
        file_name = image_name.encode('utf8')
        size=root.find('size')
        width = int(size[0].text)
        height = int(size[1].text)
        xmin = []
        ymin = []
        xmax = []
        ymax = []
        classes = []
        classes_text = []
        truncated = []
        poses = []
        difficult_obj = []
        weights = [] # Important line

        for member in root.findall('object'):

           class_name = member[0].text
           # Skip objects whose class is not in the class/weight lists
           if class_name not in class_names:
               print('E: class not recognized: {}'.format(class_name))
               continue

           xmin.append(float(member[4][0].text) / width)
           ymin.append(float(member[4][1].text) / height)
           xmax.append(float(member[4][2].text) / width)
           ymax.append(float(member[4][3].text) / height)
           difficult_obj.append(0)

           # Look up the per-object weight from its class (important line)
           class_id = class_names.index(class_name)
           weights.append(class_weights[class_id])

           if class_name == 'top':
               classes_text.append('top'.encode('utf8'))
               classes.append(1)
           elif class_name == 'dress':
               classes_text.append('dress'.encode('utf8'))
               classes.append(2)

           truncated.append(0)
           poses.append('Unspecified'.encode('utf8'))

        full_path = image_path 
        with tf.gfile.GFile(full_path, 'rb') as fid:
            encoded_jpg = fid.read()
        encoded_jpg_io = io.BytesIO(encoded_jpg)
        image = Image.open(encoded_jpg_io)
        if image.format != 'JPEG':
           raise ValueError('Image format not JPEG')
        key = hashlib.sha256(encoded_jpg).hexdigest()

        #create TFRecord Example
        example = tf.train.Example(features=tf.train.Features(feature={
            'image/height': dataset_util.int64_feature(height),
            'image/width': dataset_util.int64_feature(width),
            'image/filename': dataset_util.bytes_feature(file_name),
            'image/source_id': dataset_util.bytes_feature(file_name),
            'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
            'image/encoded': dataset_util.bytes_feature(encoded_jpg),
            'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
            'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
            'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
            'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
            'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
            'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
            'image/object/class/label': dataset_util.int64_list_feature(classes),
            'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
            'image/object/truncated': dataset_util.int64_list_feature(truncated),
            'image/object/view': dataset_util.bytes_list_feature(poses),
            'image/object/weight': dataset_util.float_list_feature(weights) # Important line
        })) 
        return example  

def main(_):

    weighted_tf_records_output = 'name_of_records_file.record' # output file
    annotations_path = '/path/to/annotations/folder/*.xml' # input annotations

    writer_train = tf.python_io.TFRecordWriter(weighted_tf_records_output)
    filename_list = tf.train.match_filenames_once(annotations_path)
    init = (tf.global_variables_initializer(), tf.local_variables_initializer())
    sess = tf.Session()
    sess.run(init)
    xml_files = list(sess.run(filename_list))  # avoid shadowing the built-in `list`
    random.shuffle(xml_files)

    for xml_file in xml_files:
        print('-> Processing {}'.format(xml_file))
        example = create_example(xml_file)
        writer_train.write(example.SerializeToString())

    writer_train.close()
    print('-> Successfully converted dataset to TFRecord.')


if __name__ == '__main__':
    tf.app.run()

If you have other types of annotations the code will be very similar, but unfortunately this exact one will not work.
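To sanity-check the generated file, you can read it back and confirm that the 'image/object/weight' feature is populated. A minimal sketch, assuming the same TF 1.x environment and the output filename used above:

import tensorflow as tf

# Iterate over the records and print each example's labels and per-object weights
for record in tf.python_io.tf_record_iterator('name_of_records_file.record'):
    example = tf.train.Example.FromString(record)
    labels = example.features.feature['image/object/class/label'].int64_list.value
    weights = example.features.feature['image/object/weight'].float_list.value
    print(list(labels), list(weights))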
