前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >一种目标检测任务中图像-标注对增强方法

一种目标检测任务中图像-标注对增强方法

作者头像
用户9875047
发布2022-12-07 10:22:00
3690
发布2022-12-07 10:22:00
举报
文章被收录于专栏:机器视觉全栈er机器视觉全栈er

其实,本篇应是深度学习常用图像数据增强库albumentations系列教程(三)的,但是鉴于不如现在的题目直观,还是修改了,原来两篇见如下:

深度学习常用图像数据增强库albumentations系列教程(一)

深度学习常用图像数据增强库albumentations系列教程(二)

本篇是在前面两篇基础上,对目标检测任务中常用的包围框标注数据进行增强。

1. 目标检测任务包围框

目标检测任务中在训练之前要对图像中的目标物体进行标注,比如使用labelimg对目标物体的位置和类别进行标注,生成xml文件(数据是pascal_voc格式)。

albumentations支持四种数据格式: pascal_voc,albumentations, coco和yolo,这四种数据格式使用不同的方法表示包围框的位置。

  • • pascal_voc: 使用[x_min, y_min, x_max, y_max]描述包围框。x_min和y_min是包围框左上角的坐标,y_min和y_max是右下角的坐标,如[138, 103, 161, 471]
  • • albumentations: 使用[x_min, y_min, x_max, y_max]表示,和pascal_voc不同的是albumentations用的是归一化的值去描述,即将横纵坐标除以相应的长宽,如[138/640, 103/480, 161/640, 471/480]
  • • coco: 使用[x_min, y_min, width, height]表示包围框,如[138, 103, 23, 368]
  • • yolo: 使用[x_center, y_center, width, height],前面两个参数是规范化后的包围框的中心位置,如[((138+161)/2)/640, ((103+471)/2)/480, 23/640, 368/480]

2. 目标检测任务图像-标注对数据增强功能实现

针对训练样本量少的情况,我们常常会使用数据增强的方法增加样本量,如图像的旋转、平移、缩放、改变亮度等,针对增强后的图像常常还需要标注,标注工作量较大。尽管有些方法在训练的时候会帮你实现这些功能,我个人还是习惯将标注增强直观展示,确定标注增强的合理性。

图像-标注对增强包括如下流程:

  1. 1. 利用单张或者多张图像进行标注,生成xml文件
  2. 2. 定义增强pipeline
  3. 3. 从文件夹中遍历原始的图像文件和xml文件
  4. 4. 通过增强pipeline得到图像标注增强对用于训练

注意:不是所有的变换都支持包围框标注数据增强的,目前(20220921)支持包围框增强的变换

代码语言:javascript
复制
import random
import cv2
from matplotlib import pyplot as plt
import xml.etree.ElementTree as ET
import albumentations as A
import os
import time
import glob
from tqdm import trange

BOX_COLOR = (255, 0, 0)  # Red
TEXT_COLOR = (255, 255, 255)  # White
# original pictures size:62, then total size is 62*GENERATED_PICS_SIZE
GENERATED_PICS_SIZE = 600


def visualize_bbox(img, bbox, class_name, color=BOX_COLOR, thickness=2):
    """Visualizes a single bounding box on the image"""
    # x_min, y_min, w, h = bbox
    # x_min, x_max, y_min, y_max = int(x_min), int(x_min + w), int(y_min), int(
    #     y_min + h)
    x_min, y_min, x_max, y_max = bbox
    print(x_min, y_min, x_max, y_max)

    cv2.rectangle(img, (int(x_min), int(y_min)), (int(x_max), int(y_max)),
                  color=color, thickness=thickness)

    ((text_width, text_height), _) = cv2.getTextSize(class_name,
                                                     cv2.FONT_HERSHEY_SIMPLEX,
                                                     0.35, 1)
    cv2.rectangle(img, (int(x_min), int(y_min) - int(1.3 * text_height)),
                  (int(x_min) + text_width, int(y_min)), BOX_COLOR, -1)
    cv2.putText(
        img,
        text=class_name,
        org=(int(x_min), int(y_min) - int(0.3 * text_height)),
        fontFace=cv2.FONT_HERSHEY_SIMPLEX,
        fontScale=0.35,
        color=TEXT_COLOR,
        lineType=cv2.LINE_AA,
    )
    return img


def visualize(image, bboxes, category_ids, category_id_to_name):
    img = image.copy()
    for bbox, category_id in zip(bboxes, category_ids):
        class_name = category_id_to_name[category_id]
        img = visualize_bbox(img, bbox, class_name)
    plt.axis('off')
    plt.imshow(img)
    plt.show()


def saveNewAnnotation(new_xml_path, new_jpg_path, xml_path, bboxes, cur_dir):
    in_file = open(os.path.join(xml_path), encoding='utf-8')
    new_file = in_file
    tree = ET.parse(new_file)
    root = tree.getroot()
    root[0].text = "annotation_out"
    root[1].text = new_jpg_path
    root[2].text = cur_dir + '\\annotation_out\\' + new_jpg_path

    idx = 0
    for obj in root.iter('object'):
        obj[4][0].text = str(round(bboxes[idx][0]))
        obj[4][1].text = str(round(bboxes[idx][1]))
        obj[4][2].text = str(round(bboxes[idx][2]))
        obj[4][3].text = str(round(bboxes[idx][3]))
        idx += 1
    tree.write(new_xml_path, 'UTF-8')


def getAnnotation(xml_path):
    '''
    :param xml_path:
    :return: bboxes, category_ids
    '''

    in_file = open(os.path.join(xml_path), encoding='utf-8')
    try:
        tree = ET.parse(in_file)
    except:
        return [], []
    root = tree.getroot()

    bboxes = []
    category_ids = []

    for obj in root.iter('object'):
        cls = obj.find('name').text

        xmlbox = obj.find('bndbox')
        bbox = [int(float(xmlbox.find('xmin').text)),
                int(float(xmlbox.find('ymin').text)),
                int(float(xmlbox.find('xmax').text)),
                int(float(xmlbox.find('ymax').text))]
        bboxes.append(bbox)
        category_ids.append(cls)
    return bboxes, category_ids


def main(cur_dir):
    PICS_PATH = 'annotation_ori'  # 存放图片的文件夹路径
    paths = glob.glob(os.path.join(PICS_PATH, '*.jpg'))
    for i in trange(len(paths)):
        jpg_path = paths[i]
        xml_path = jpg_path.split('.')[0] + ".xml"

        # print(jpg_path.split('.'))
        new_jpg_path_prefix = 'annotation_out\\'

        for i in range(GENERATED_PICS_SIZE):
            image = cv2.imread(jpg_path)
            # print(width, ", ", height)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            new_jpg_path = jpg_path.split('.')[0].split('\\')[-1] + str(
                i + 1).zfill(4) + ".jpg"
            new_xml_path = new_jpg_path_prefix + \
                           jpg_path.split('.')[0].split('\\')[-1] + str(
                i + 1).zfill(4) + ".xml"
            bboxes, category_ids = getAnnotation(xml_path=xml_path)
            if len(bboxes) == 0 & len(category_ids) == 0:
                continue
            category_id_to_name = {}
            for i in range(len(category_ids)):
                category_id_to_name[category_ids[i]] = category_ids[i]
            # 变换操作
            # 水平反转,高斯模糊,gamma变换,亮度变化,
            transform = A.Compose(
                [
                    A.HorizontalFlip(p=0.5),
                    A.Rotate(limit=2, p=0.3),
                    A.ShiftScaleRotate(shift_limit=0.0625,scale_limit=0, rotate_limit=0,p=0.3),
                    A.GaussianBlur(blur_limit=1, p=0.5),
                    A.ColorJitter(brightness=0.05, contrast=0.05,
                                  saturation=0.02,
                                  hue=0.02, always_apply=False, p=1)
                ],
                bbox_params=A.BboxParams(format='pascal_voc',
                                         label_fields=['category_ids']),
            )
            transformed = transform(image=image, bboxes=bboxes,
                                    category_ids=category_ids)
            image = transformed['image']
            bboxes = transformed['bboxes']
            category_ids = transformed['category_ids']
            # print(bboxes)
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            cv2.imencode('.jpg', image)[1].tofile("annotation_out\\" + new_jpg_path)
            # visualize(image, bboxes, category_ids, category_id_to_name)
            saveNewAnnotation(new_xml_path, new_jpg_path, xml_path, bboxes, cur_dir)
        time.sleep(1)

if __name__ == '__main__':
    import os
    cur_dir = os.path.dirname(os.path.abspath(__file__))  # 上级目录
    print(cur_dir)
    main(cur_dir)

通过上述代码,我们会生成大量基于原始图像-标注对的衍生图像-标注对。

建议先生成少量图像-标注对,然后使用labelimg查看下生成的图像-标注对是否正确,确定无问题后,再生成大量图片-标注对

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2022-09-21,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 机器视觉全栈er 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1. 目标检测任务包围框
  • 2. 目标检测任务图像-标注对数据增强功能实现
相关产品与服务
图像识别
腾讯云图像识别基于深度学习等人工智能技术,提供车辆,物体及场景等检测和识别服务, 已上线产品子功能包含车辆识别,商品识别,宠物识别,文件封识别等,更多功能接口敬请期待。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档