8.SSD目标检测之二:制作自己的训练集

最近秋色甚好,一场大风刮散了雾霾,难得几天的好天气,周末回家在大巴上看着高速两旁夕阳照射下黄澄澄的树叶,晕车好像也好了很多。 特地周六赶回来为了周末去拍点素材,周日天气也好,去了陕师大拍了照片和视频。 说正经的,如何来制作数据集。

1.采集照片。

这个不用说,首先是要找照片,如果要训练自己的模型的话,数据采集这里也基本是要亲力亲为的,我自己是想检测无人机,所以百度搜了一部分图片,自己把无人机飞起来然后用相机再拍了一些,去掉一些重复的,最终150张照片。 单反的分辨率已经调到最低但是还是有3000 * 2000,而且无人机飞的较高的话我焦距有限,拍到的照片无人机占比很小。 所以我对照片进行了重新裁剪,这一部分是用lightroom来做的,结束之后全部导出,大小限制在1m。 然后对照片进行重命名,这部分后来发现是不用做的,图片命名为任意名称其实都是可以的,不过为了和VOC2007的数据集保持一致,还是做了重新命名,规则是六位数,最后面是序号,前面不够的话补零。 这个在python里面做的话就比较简单了,用5个零的字符串00000加上索引index,然后最后取末六个字符就可以了。 简单代码:

import os
import cv2
import time
import matplotlib.pyplot as plt
#原图路径和保存图片的路径
imgPath="C:\\Users\\zhxing\\Desktop\\VOCtrainval_06-Nov-2007\\VOCdevkit\\MyDate\\JPEGImages\\img\\"
savePath="C:\\Users\\zhxing\\Desktop\\VOCtrainval_06-Nov-2007\\VOCdevkit\\MyDate\\JPEGImages\\"
imgList=os.listdir(imgPath)

for i in range(1,len(imgList)):
    img=cv2.imread(imgPath+imgList[i])
    str_tmp="000000"+str(i)
    cv2.imwrite(savePath+str_tmp[-6:]+".jpg",img)      #后六位命名
print("done!!")

2.标记照片。

标记的话用软件:LabelImg。 链接:https://pan.baidu.com/s/15Tkwstfumzq8gn5Jb3Vj1Q 提取码:y1d2 使用方法也比较简单,首先在data文件夹下的txt文件下写上所有类别的名称,用英文。 然后打开软件,对每一张照片进行画框,贴标签,保存xml操作。

结合快捷键其实很快:

A: prev image
D: next image
W:creat rectbox
ctrl+s: save xml

图像中有几个目标就标定几个目标,每个目标标签都需要指定,我的类别只有一类所以标记起来挺快的,大概一个小时左右就标记完成了。

3.用xml文件来生成.tfrecord文件。

这个是必须的,tensorflow版本的SSD代码需要使用 .tfrecord文件来做为训练文件(如果是自己写模型的话用矩阵也是可以的)。 需要提前新建tfrecords_文件夹 代码: 需要改的地方主要是各个文件夹以及每个 .tfrecord文件包含xml文件的个数,这个自己设置就好了,跑的非常之快,几秒钟就完事。

import os
import sys
import random
import numpy as np
import tensorflow as tf
import xml.etree.ElementTree as ET  # 操作xml文件



#labels
VOC_LABELS = {
    'none': (0, 'Background'),
    'DJI': (1, 'Product')
}

#标签和图片所在的文件夹
DIRECTORY_ANNOTATIONS = "Annotations\\"
DIRECTORY_IMAGES = "JPEGImages\\"

# 随机种子.
RANDOM_SEED = 4242
SAMPLES_PER_FILES = 10  # 每个.tfrecords文件包含几个.xml样本


def int64_feature(value):
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

def float_feature(value):
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))

def bytes_feature(value):
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))

# 图片处理
def _process_image(directory, name):
    #读取照片
    filename = directory + DIRECTORY_IMAGES + name + '.jpg'
    image_data = tf.gfile.FastGFile(filename, 'rb').read()
    #读取xml文件
    filename = os.path.join(directory, DIRECTORY_ANNOTATIONS, name + '.xml')
    tree = ET.parse(filename)
    root = tree.getroot()

    size = root.find('size')
    shape = [int(size.find('height').text),
             int(size.find('width').text),
             int(size.find('depth').text)]
    bboxes = []
    labels = []
    labels_text = []
    difficult = []
    truncated = []

    for obj in root.findall('object'):
        label = obj.find('name').text
        labels.append(int(VOC_LABELS[label][0]))
        labels_text.append(label.encode('ascii'))  # 变为ascii格式

        if obj.find('difficult'):
            difficult.append(int(obj.find('difficult').text))
        else:
            difficult.append(0)
        if obj.find('truncated'):
            truncated.append(int(obj.find('truncated').text))
        else:
            truncated.append(0)

        bbox = obj.find('bndbox')
        a = float(bbox.find('ymin').text) / shape[0]
        b = float(bbox.find('xmin').text) / shape[1]
        a1 = float(bbox.find('ymax').text) / shape[0]
        b1 = float(bbox.find('xmax').text) / shape[1]
        a_e = a1 - a

        b_e = b1 - b
        if abs(a_e) < 1 and abs(b_e) < 1:
            bboxes.append((a, b, a1, b1))

    return image_data, shape, bboxes, labels, labels_text, difficult, truncated

# 转化样例
def _convert_to_example(image_data, labels, labels_text, bboxes, shape,
                        difficult, truncated):
    xmin = []
    ymin = []
    xmax = []
    ymax = []

    for b in bboxes:
        assert len(b) == 4
        [l.append(point) for l, point in zip([ymin, xmin, ymax, xmax], b)]

    image_format = b'JPEG'
    example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': int64_feature(shape[0]),
        'image/width': int64_feature(shape[1]),
        'image/channels': int64_feature(shape[2]),
        'image/shape': int64_feature(shape),
        'image/object/bbox/xmin': float_feature(xmin),
        'image/object/bbox/xmax': float_feature(xmax),
        'image/object/bbox/ymin': float_feature(ymin),
        'image/object/bbox/ymax': float_feature(ymax),
        'image/object/bbox/label': int64_feature(labels),
        'image/object/bbox/label_text': bytes_feature(labels_text),
        'image/object/bbox/difficult': int64_feature(difficult),
        'image/object/bbox/truncated': int64_feature(truncated),
        'image/format': bytes_feature(image_format),
        'image/encoded': bytes_feature(image_data)}))

    return example


def _add_to_tfrecord(dataset_dir, name, tfrecord_writer):
    image_data, shape, bboxes, labels, labels_text, difficult, truncated = \
        _process_image(dataset_dir, name)
    example = _convert_to_example(image_data, labels, labels_text,
                                  bboxes, shape, difficult, truncated)
    tfrecord_writer.write(example.SerializeToString())

def _get_output_filename(output_dir, name, idx):
    return '%s/%s_%03d.tfrecord' % (output_dir, name, idx)

def run(dataset_dir, output_dir, name='voc_2007_train', shuffling=False):
    if not tf.gfile.Exists(dataset_dir):
        tf.gfile.MakeDirs(dataset_dir)

    path = os.path.join(dataset_dir, DIRECTORY_ANNOTATIONS)
    filenames = sorted(os.listdir(path))  # 排序

    if shuffling:
        random.seed(RANDOM_SEED)
        random.shuffle(filenames)

    i = 0
    fidx = 0
    while i < len(filenames):
        # Open new TFRecord file.
        tf_filename = _get_output_filename(output_dir, name, fidx)
        with tf.python_io.TFRecordWriter(tf_filename) as tfrecord_writer:
            j = 0
            while i < len(filenames) and j < SAMPLES_PER_FILES:
                sys.stdout.write(' Converting image %d/%d \n' % (i + 1, len(filenames)))  # 终端打印,类似print
                sys.stdout.flush()  # 缓冲
                filename = filenames[i]
                img_name = filename[:-4]
                _add_to_tfrecord(dataset_dir, img_name, tfrecord_writer)
                i += 1
                j += 1
            fidx += 1
    print('\nFinished converting the Pascal VOC dataset!')

# 原数据集路径,输出路径以及输出文件名,要根据自己实际做改动
dataset_dir = "C:\\Users\\zhxing\\Desktop\\VOCtrainval_06-Nov-2007\\VOCdevkit\\MyDate\\"
output_dir = "./tfrecords_"
name = "voc_train"

def main(_):
    run(dataset_dir, output_dir, name)

if __name__ == '__main__':
    tf.app.run()

大概生成这样的文件就可以了:

下面就是训练了,不知道能有什么结果!!

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏Java与Android技术栈

利用tess-two和cv4j实现简单的ocr功能、

Tesseract是Ray Smith于1985到1995年间在惠普布里斯托实验室开发的一个OCR引擎,曾经在1995 UNLV精确度测试中名列前茅。但1996...

23110
来自专栏GIS讲堂

ArcGIS Image Server简介以及OL2中的加载

本文讲述Arcgis Image Server相关以及在OL2中如何加载Arcgis Server发布的影像服务。

13120
来自专栏图形学与OpenGL

3.6.2 编程实例-河南地图绘制

#include <iostream> #include <fstream> #include<vector> #include <GL/glut.h> usi...

14010
来自专栏小樱的经验随笔

HDU 1874 畅通工程续【Floyd算法实现】

畅通工程续 Time Limit: 3000/1000 MS (Java/Others)    Memory Limit: 32768/32768 K (Jav...

296100
来自专栏计算机视觉与深度学习基础

Leetcode 5 Longest Palindromic Substring

Given a string S, find the longest palindromic substring in S. You may assume ...

20150
来自专栏机器之心

教程 | 如何将模型部署到安卓移动端,这里有一份简单教程

截至 2018 年,全球活跃的安卓设备已经超过了 20 亿部。安卓手机的迅速普及在很大程度上得益于各种各样的智能应用,从地图到图片编辑器无所不有。随着深度学习技...

41110
来自专栏专知

关于写论文说来简单但做起来难的三条建议

A few years ago, we prepared a series of workshops on writing research papers an...

29750
来自专栏HansBug's Lab

算法模板——单个值欧拉函数

输入N,输出phi(N) 这样的单个值欧拉函数程序一般见于部分数论题,以及有时候求逆元且取模的数不是质数的情况(逆元:A/B=A*Bphi(p)-1 (mod ...

36050
来自专栏落影的专栏

OpenGL ES实践教程(二)摄像头采集数据和渲染

教程 这一篇教程是摄像头采集数据和渲染,包括了三部分内容,渲染部分-OpenGL ES,摄像头采集图像部分-AVFoundation和图像数据创建纹理部分-G...

49450
来自专栏生信宝典

SOM基因表达聚类分析初探

上周的暑期生信黑马培训有老师提出要做SOM分析,最后卡在code plot只能出segment plot却出不来line plot。查了下,没看到解决方案。今天...

20720

扫码关注云+社区

领取腾讯云代金券