【教程】利用Tensorflow目标检测API确定图像中目标的位置

深度学习提供了另一种解决“Wally在哪儿”(美国漫画)问题的方法。与传统的图像处理计算机视觉方法不同的是,它只使用了少量的标记出Wally位置的示例。

在我的Github repo上发布了具有评估图像和检测脚本的最终训练模型。

  • Github repo地址:https://github.com/tadejmagajna/HereIsWally

这篇文章描述了使用Tensorflow目标检测API来训练神经网络的过程,并使用围绕它构建的Python脚本来寻找Wally。它由以下步骤组成:

  • 通过创建一组标记训练图像来准备数据集,其中标签代表图像中Wally的xy位置;
  • 读取和配置模型以使用Tensorflow目标检测API;
  • 在我们的数据集上训练模型
  • 使用导出的图形对评估图像的模型进行测试

开始之前,请确保按照说明安装Tensorflow目标检测API。

准备数据集

神经网络是深度学习的过程中最值得注意的过程,但遗憾的是,科学家们花费大量时间的准备和格式化训练数据。

最简单的机器学习问题的目标值通常是标量(比如数字检测器)或分类字符串。Tensorflow目标检测API训练数据使用两者的结合。它包括一组图像,并附有特定目标的标签和它们在图像中出现的位置。位置用两点(二维空间)定义,两点足够画一个物体周围的包围盒。

因此,为了创建训练集,我们需要提出一组Wally出现地点的图片。

虽然我可以用LabelImg这样的注释工具,花费数周的时间通过手工标记图像来解决问题,但我发现了一个已经解决了Where’s Wally这个问题的训练集。

Wally训练数据集,最后四列描述了Wally出现在图像中的位置

准备数据集的最后一步是将我们的标签(保存为文本文件)和图像(.jpeg)打包成一个二进制.tfrecord文件(该过程的解释代码地址见段末),但可以找到训练和重新运算求出Wally位置的参数内容。 .tfecord文件在我的Github repo上。

  • Github repo地址:https://github.com/tadejmagajna/HereIsWally
  • 解释地址:http://warmspringwinds.github.io/tensorflow/tf-slim/2016/12/21/tfrecords-guide/

准备模型

Tensorflow目标检测API提供了一组经过多次公开数据集训练的具有不同性能(通常为速度 – 精度折衷)的预训练模型。

虽然模型可以从头开始随机初始化网络权值,但这个过程可能需要几周的时间。我们使用一种称为转移学习的方法来替换该过程。

转移学习包含采用通常训练的模型解决一些一般问题并且重新训练模型以解决我们的问题。转移学习的工作原理是,通过使用在预先训练的模型中获得的知识并将其转移到新的模型中,来代替从头开始训练模型这些无用的重复工作。这为我们节省了大量的时间,将花费在训练上的时间用于获得针对我们问题的知识。

我们使用带有经过COCO数据集训练的Inception v2模型的RCNN,以及它的管道配置文件。该模型包含一个检查点.ckpt文件,我们可以使用该文件开始训练。

  • RCNN地址: http://download.tensorflow.org/models/object_detection/faster_rcnn_inception_v2_coco_2017_11_08.tar.gz
  • 管道配置文件地址: https://github.com/tensorflow/models/blob/master/research/object_detection/samples/configs/ssd_inception_v2_coco.config

下载配置文件后,请确保用指向检查点文件、训练以及评估.tfrecord文件与标签映射文件的路径代替“PATH_TO_BE_CONFIGURED”字段。

需要配置的最终文件是labels.txt映射文件,其中包含所有不同目标的标签。由于我们只是在寻找一种类型的目标,我们的标签文件看起来像这样:

item {
  id: 1
  name: 'waldo'
}

最后,我们最终应该:

  • 具有.ckpt检查点文件的预训练模型;
  • 训练和评估.tfrecord数据集;
  • 标记映射文件;
  • 指向以上文件的管道配置文件。

现在,我们准备开始训练。

训练

Tensorflow目标检测API提供了一个简单易用的Python脚本来重新训练我们的模型。它位于models / research / object_detection中,可以利用下列路径运行:

python train.py –logtostderr –pipeline_config_path= PATH_TO_PIPELINE_CONFIG –train_dir=PATH_TO_TRAIN_DIR

其中PATH_TO_PIPELINE_CONFIG是到管道配置文件的路径,PATH_TO_TRAIN_DIR是一个新创建的目录,我们的新检查点和模型将被存储在该目录中。

train.py的输出应该如下所示:

用最重要的信息来查找损失。这是在训练或验证集中每个示例错误的总和。当然,你希望它尽可能低,这意味着,缓慢下降表示你的模型正在学习(或过度拟合你的训练数据)。你还可以使用Tensorboard来更详细地显示训练数据。

该脚本将在一定数量的步骤后自动存储检查点文件,以便你随时恢复保存的检查点,以防计算机在学习过程中崩溃。

这意味着当你想结束模型的训练时,你可以终止脚本。

但是什么时候停止学习?关于何时停止训练,原则上是当评估集的损失减少或非常低时(在我们的例子中低于0.01)。

测试

现在我们可以通过在一些示例图像上进行测试来实际使用我们的模型。

首先,我们需要使用models/research/object_detection脚本中存储的检查点(位于我们的训练目录中)导出推理图:

python export_inference_graph.py — pipeline_config_path PATH_TO_PIPELINE_CONFIG --trained_checkpoint_prefix PATH_TO_CHECPOINT --output_directory OUTPUT_PATH

我们的Python脚本可以用导出的推理图来查找Wally的位置。

我写了一些简单的Python脚本(基于Tensorflow 目标检测API),你可以在模型上使用它们执行目标检测,并在检测到的目标周围绘制框或将其暴露。

find_wally.py和find_wally_pretty.py都可以在我的Github仓库中找到,可以简单地运行:

  • find_wally.py地址: https://github.com/tadejmagajna/HereIsWally/blob/master/find_wally.py
  • find_wally_pretty.py地址: https://github.com/tadejmagajna/HereIsWally/blob/master/find_wally_pretty.py
  • Github repo 地址: https://github.com/tadejmagajna/HereIsWally
python find_wally.py

或者

python find_wally_pretty.py

在自己的模型或自己的评估图像上使用脚本时,请确保修改model_path和image_path变量。

结语

在我的Github repo 上发布的模型表现非常出色。

模型设法在评估图像中找到Wally,并且对网络上的一些额外的随机例子处理得很好。它未能找到很大的Wally,直观来说,找到小的walley应该更容易解决。这表明我们的模型可能过度适合我们的训练数据,主要是因为训练图像较少。

原文发布于微信公众号 - ATYUN订阅号(atyun_com)

原文发表时间:2017-12-14

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏AI研习社

Github 项目推荐 | ANSI C 的简单神经网络库

Genann是一个经过精心测试的库,用于在 C 中训练和使用前馈人工神经网络(ANN)。它的主要特点是简单、快速、可靠和可魔改(hackable),它只需要提供...

531
来自专栏Python小屋

Python+sklearn使用朴素贝叶斯算法识别中文垃圾邮件

2、读取全部训练集,删除其中的干扰字符,例如【】*。、,等等,然后分词,删除长度为1的单个字。

1235
来自专栏专知

在浏览器上也能训练神经网络?TensorFlow.js带你玩游戏~

【导读】一直以来训练神经网络给我们的印象都是复杂、耗时、对硬件要求高。你有没有想过有一天在浏览器上也能训练神经网络~ 本文通过一篇详细的TensorFlow.j...

820
来自专栏CreateAMind

dcgan人脸生成效果复现-多图及代码学习

https://github.com/carpedm20/DCGAN-tensorflow

811
来自专栏ATYUN订阅号

深度学习与R语言

对于R语言用户来说,深度学习还没有生产级的解决方案(除了MXNET)。这篇文章介绍了R语言的Keras接口,以及如何使用它来执行图像分类。文章结尾会通过提供一些...

3554
来自专栏九彩拼盘的叨叨叨

知识点模板

522
来自专栏null的专栏

优化算法——拟牛顿法之L-BFGS算法

一、BFGS算法    image.png 二、BGFS算法存在的问题    image.png 三、L-BFGS算法思路    image.png im...

3135
来自专栏专知

【干货】手把手教你用苹果Core ML和Swift开发人脸目标识别APP

【导读】CoreML是2017年苹果WWDC发布的最令人兴奋的功能之一。它可用于将机器学习整合到应用程序中,并且全部脱机。CoreML提供的机器学习 API,包...

2676
来自专栏机器之心

资源 | 微软开源MMdnn:实现多个框架之间的模型转换

选自GitHub 作者:Kit CHEN等 机器之心编译 参与:路雪、思源 近日,微软开源 MMdnn,可用于转换、可视化和诊断深度神经网络模型的全面、跨框架解...

3366
来自专栏AI研习社

Github 项目推荐 | 微软开源 MMdnn,模型可在多框架间转换

近期,微软开源了 MMdnn,这是一套能让用户在不同深度学习框架间做相互操作的工具。比如,模型的转换和可视化,并且可以让模型在 Caffe、Keras、MXNe...

3498

扫描关注云+社区