4.训练模型之准备训练数据

终于要开始训练识别熊猫的模型了, 第一步是准备好训练数据,这里有三件事情要做:

  • 收集一定数量的熊猫图片。
  • 将图片中的熊猫用矩形框标注出来。
  • 将原始图片和标注文件转换为TFRecord格式的文件。

数据标注

收集熊猫的图片和标注熊猫位置的工作称之为“Data Labeling”,这可能是整个机器学习领域内最低级、最机械枯燥的工作了,有时候大量的 Data Labeling 工作会外包给专门的 Data Labeling 公司做, 以加快速度和降低成本。 当然我们不会把这个工作外包给别人,要从最底层的工作开始!收集熊猫图片倒不是太难,从谷歌和百度图片上收集 200 张熊猫的图片,应该足够训练一个可用的识别模型了。然后需要一些工具来做标注,我使用的是 Mac 版的 RectLabel,常用的还有 LabelImg 和 LabelMe 等。

RectLabel 标注时的界面大概是这样的:

当我们标注完成的时候,它会在 annotations 目录下生产和图片文件名相同的后缀名为 .json 的标注文件。

打开一个标注文件,其内容大概是这样的:

    {
      "filename" : "61.jpg",
      "folder" : "panda_images",
      "image_w_h" : [
        453,
        340
      ],
      "objects" : [
        {
          "label" : "panda",
          "x_y_w_h" : [
            90,
            104,
            364,
            233
          ]
        }
      ]
    }
  • image_w_h:图片的宽和高。
  • objects:图片的中的物体信息、数组。
  • label:在标注的时候指定的物体名称。
  • x_y_w_h:物体位置的矩形框:(xmin、ymin、width、height)。

接下来要做的是耐心的在这 200 张图片上面标出熊猫的位置,这个稍微要花点时间,可以在 这里 找已经标注好的图片数据。

生成 TFRecord

接下来需要一点 Python 代码来将图片和标注文件生成为 TFRecord 文件,TFRecord 文件是由很多tf.train.Example对象序列化以后组成的,先写由一个单独的图片文件生成tf.train.Example对象的函数:

    def create_sample(image_filename, data_dir):
        image_path = os.path.join(data_dir, image_filename)
        annotation_path = os.path.join(data_dir, 'annotations', os.path.splitext(image_filename)[0] + ".json")
        with tf.gfile.GFile(image_path, 'rb') as fid:
            encoded_jpg = fid.read()
        encoded_jpg_io = io.BytesIO(encoded_jpg)
        with open(annotation_path) as fid:
            image_annotation = json.load(fid)
        width = image_annotation['image_w_h'][0]
        height = image_annotation['image_w_h'][1]
        xmins = []
        ymins = []
        xmaxs = []
        ymaxs = []
        classes = []
        classes_text = []
        for obj in image_annotation['objects']:
            classes.append(1)
            classes_text.append('panda')
            box = obj['x_y_w_h']
            xmins.append(float(box[0]) / width)
            ymins.append(float(box[1]) / height)
            xmaxs.append(float(box[0] + box[2] - 1) / width)
            ymaxs.append(float(box[1] + box[3] - 1) / height)
        filename = image_annotation['filename'].encode('utf8')
        tf_example = tf.train.Example(features=tf.train.Features(feature={
            'image/height': dataset_util.int64_feature(height),
            'image/width': dataset_util.int64_feature(width),
            'image/filename': dataset_util.bytes_feature(filename),
            'image/source_id': dataset_util.bytes_feature(filename),
            'image/encoded': dataset_util.bytes_feature(encoded_jpg),
            'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
            'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
            'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
            'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
            'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
            'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
            'image/object/class/label': dataset_util.int64_list_feature(classes),
        }))
        return tf_example

在这里简单说明一下:

  • 通过图片文件名找到对应的标注文件,并读入标注信息。
  • 因为图片中标注的物体都是熊猫,用数字 1 来代表,所以 class 数组里的元素值都为 1,class_text数组的里的元素值都为‘panda’。
  • Object Detection API 里面接受的矩形框输入格式为 (xmin, ymin, xmax, ymax) 和标注文件的 (xmin, ymin, width, height) 不一样,所以要做一下转换。同时需要将这些值归一化:将数值投影到 (0, 1] 的区间内。
  • 将特征组成{特征名:特征值}的 dict 作为参数来创建tf.train.Example。

接下来将tf.train.Example对象序列化,我们写一个可以由图片文件列表生成对应 TFRecord 文件的的函数:

 def create_tf_record(example_file_list, data_dir, output_file_path):
        writer = tf.python_io.TFRecordWriter(output_file_path)
        for filename in example_file_list:
            tf_example = create_sample(filename, data_dir)
            writer.write(tf_example.SerializeToString())
        writer.close()

依次调用create_sample函数然后将生成的tf.train.Example对象依次序列化即可。

最后需要将数据集切分为训练集合测试集,将图片文件打乱,然后按照 7:3 的比例进行切分:

    random.seed(42)
    random.shuffle(all_examples)
    num_examples = len(all_examples)
    num_train = int(0.7 * num_examples)
    train_examples = all_examples[:num_train]
    val_examples = all_examples[num_train:]
    create_tf_record(train_examples, data_dir, os.path.join(output_dir, 'train.record'))
    create_tf_record(val_examples, data_dir, os.path.join(output_dir, 'val.record'))

写完这个脚本以后,最好再写一个测试用例来验证这个脚本,因为我们将会花很长的时间来训练,到时候再发现脚本有 bug 就太浪费时间了,我们主要测试create_sample方法有没有根据输入数据生成正确的tf.train.Example对象:

    def test_dict_to_tf_example(self):
        image_file = '61.jpg'
        data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'test_data')
        example = create_sample(image_file, data_dir)
        self._assertProtoEqual(
            example.features.feature['image/height'].int64_list.value, [340])
        self._assertProtoEqual(
            example.features.feature['image/width'].int64_list.value, [453])
        self._assertProtoEqual(
            example.features.feature['image/filename'].bytes_list.value,
            [image_file])
        self._assertProtoEqual(
            example.features.feature['image/source_id'].bytes_list.value,
            [image_file])
        self._assertProtoEqual(
            example.features.feature['image/format'].bytes_list.value, ['jpeg'])
        self._assertProtoEqual(
            example.features.feature['image/object/bbox/xmin'].float_list.value,
            [90.0 / 453])
        self._assertProtoEqual(
            example.features.feature['image/object/bbox/ymin'].float_list.value,
            [104.0/340])
        self._assertProtoEqual(
            example.features.feature['image/object/bbox/xmax'].float_list.value,
            [1.0])
        self._assertProtoEqual(
            example.features.feature['image/object/bbox/ymax'].float_list.value,
            [336.0/340])
        self._assertProtoEqual(
            example.features.feature['image/object/class/text'].bytes_list.value,
            ['panda'])
        self._assertProtoEqual(
            example.features.feature['image/object/class/label'].int64_list.value,
            [1])

后台回复“准备训练数据”关键字可以获取全部源码。

完成之后运行脚本,传入图片和标注的文件夹路径和输出文件路径:

python create_tf_record.py --image_dir=PATH_OF_IMAGE_SET --output_dir=OUTPUT_DIR

执行完成后会在由output_dir参数指定的目录生成train.record和val.record文件, 分别为训练集和测试集。

生成 label map 文件

最后还需要一个 label map 文件,很简单,因为我们只有一种物体:熊猫

label_map.pbtxt:
    item {
      id: 1
      name: 'panda'
    }

训练一个熊猫识别模型所需要的训练数据就准备完了,接下来开始在 GPU 主机上面开始训练。

本文分享自微信公众号 - 人人都是极客(rrgeek)

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2018-03-06

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏本立2道生

Caffe源码理解1:Blob存储结构与设计

在逻辑上,Blob是个\(N_d\)维张量。当\(N_d=4\)时,Blob的shape定义为\(N * C * H * W\),即\(Num * Channe...

31820
来自专栏信数据得永生

《Scikit-Learn与TensorFlow机器学习实用指南》第9章 启动并运行TensorFlow

619110
来自专栏机器之心

入门 | GPU是如何优化运行机器学习算法的?

422140
来自专栏机器学习AI算法工程

用python实现决策树ID3算法,对隐形眼镜类型预测

本节讲解如何预测患者需要佩戴的隐形眼镜类型。 1、使用决策树预测隐形眼镜类型的一般流程 (1)收集数据:提供的文本文件(数据来源于UCI数据库) (2)准备数据...

47570
来自专栏null的专栏

挑战数据结构和算法面试题——最大间隔

题目来自伯乐在线,欢迎有不同答案的同学来一起讨论。 ? 分析: 本题首先需要理解清楚最大间隔的最小: 最初的间隔为:[1,1,4,1],此时最大间隔为4 删...

32030
来自专栏SeanCheney的专栏

《利用Python进行数据分析·第2版》第9章 绘图和可视化9.1 matplotlib API入门9.2 使用pandas和seaborn绘图9.3 其它的Python可视化工具9.4 总结

信息可视化(也叫绘图)是数据分析中最重要的工作之一。它可能是探索过程的一部分,例如,帮助我们找出异常值、必要的数据转换、得出有关模型的idea等。另外,做一个可...

1.2K90
来自专栏C#

开源免费的.NET图像即时处理的组件ImageProcessor

   承接以前的组件系列,这个组件系列旨在介绍.NET相关的组件,让大家可以在项目中有一个更好的选择组件的介绍绝对不是一篇文章可以叙述完的,因为一个组件是经过开...

34480
来自专栏SeanCheney的专栏

《Scikit-Learn与TensorFlow机器学习实用指南》 第09章 启动并运行TensorFlow

TensorFlow 是一款用于数值计算的强大的开源软件库,特别适用于大规模机器学习的微调。 它的基本原理很简单:首先在 Python 中定义要执行的计算图(例...

15730
来自专栏漫漫深度学习路

pytorch学习笔记(十三):backward过程的底层实现解析

博主水平有限,如有错误,请不吝指出。 pytorch源码注释,欢迎 pr,提 issue 和 star 当我们使用 pytorch 的 python 的接口编写...

746100
来自专栏ATYUN订阅号

Tensorflow 1.3.0版本的变更概述

尽管距离Tensoflow 1.2.1版本发布才仅仅一个月,但是1.3.0版本中的软件已经发生了很多变化。开发人员可以在Tensorflow的Github页面上...

38270

扫码关注云+社区

领取腾讯云代金券