【教程】利用Tensorflow目标检测API确定图像中目标的位置

深度学习提供了另一种解决“Wally在哪儿”(美国漫画)问题的方法。与传统的图像处理计算机视觉方法不同的是,它只使用了少量的标记出Wally位置的示例。

在我的Github repo上发布了具有评估图像和检测脚本的最终训练模型。

  • Github repo地址:https://github.com/tadejmagajna/HereIsWally

这篇文章描述了使用Tensorflow目标检测API来训练神经网络的过程,并使用围绕它构建的Python脚本来寻找Wally。它由以下步骤组成:

  • 通过创建一组标记训练图像来准备数据集,其中标签代表图像中Wally的xy位置;
  • 读取和配置模型以使用Tensorflow目标检测API;
  • 在我们的数据集上训练模型
  • 使用导出的图形对评估图像的模型进行测试

开始之前,请确保按照说明安装Tensorflow目标检测API。

准备数据集

神经网络是深度学习的过程中最值得注意的过程,但遗憾的是,科学家们花费大量时间的准备和格式化训练数据。

最简单的机器学习问题的目标值通常是标量(比如数字检测器)或分类字符串。Tensorflow目标检测API训练数据使用两者的结合。它包括一组图像,并附有特定目标的标签和它们在图像中出现的位置。位置用两点(二维空间)定义,两点足够画一个物体周围的包围盒。

因此,为了创建训练集,我们需要提出一组Wally出现地点的图片。

虽然我可以用LabelImg这样的注释工具,花费数周的时间通过手工标记图像来解决问题,但我发现了一个已经解决了Where’s Wally这个问题的训练集。

Wally训练数据集,最后四列描述了Wally出现在图像中的位置

准备数据集的最后一步是将我们的标签(保存为文本文件)和图像(.jpeg)打包成一个二进制.tfrecord文件(该过程的解释代码地址见段末),但可以找到训练和重新运算求出Wally位置的参数内容。 .tfecord文件在我的Github repo上。

  • Github repo地址:https://github.com/tadejmagajna/HereIsWally
  • 解释地址:http://warmspringwinds.github.io/tensorflow/tf-slim/2016/12/21/tfrecords-guide/

准备模型

Tensorflow目标检测API提供了一组经过多次公开数据集训练的具有不同性能(通常为速度 – 精度折衷)的预训练模型。

虽然模型可以从头开始随机初始化网络权值,但这个过程可能需要几周的时间。我们使用一种称为转移学习的方法来替换该过程。

转移学习包含采用通常训练的模型解决一些一般问题并且重新训练模型以解决我们的问题。转移学习的工作原理是,通过使用在预先训练的模型中获得的知识并将其转移到新的模型中,来代替从头开始训练模型这些无用的重复工作。这为我们节省了大量的时间,将花费在训练上的时间用于获得针对我们问题的知识。

我们使用带有经过COCO数据集训练的Inception v2模型的RCNN,以及它的管道配置文件。该模型包含一个检查点.ckpt文件,我们可以使用该文件开始训练。

  • RCNN地址: http://download.tensorflow.org/models/object_detection/faster_rcnn_inception_v2_coco_2017_11_08.tar.gz
  • 管道配置文件地址: https://github.com/tensorflow/models/blob/master/research/object_detection/samples/configs/ssd_inception_v2_coco.config

下载配置文件后,请确保用指向检查点文件、训练以及评估.tfrecord文件与标签映射文件的路径代替“PATH_TO_BE_CONFIGURED”字段。

需要配置的最终文件是labels.txt映射文件,其中包含所有不同目标的标签。由于我们只是在寻找一种类型的目标,我们的标签文件看起来像这样:

item {
  id: 1
  name: 'waldo'
}

最后,我们最终应该:

  • 具有.ckpt检查点文件的预训练模型;
  • 训练和评估.tfrecord数据集;
  • 标记映射文件;
  • 指向以上文件的管道配置文件。

现在,我们准备开始训练。

训练

Tensorflow目标检测API提供了一个简单易用的Python脚本来重新训练我们的模型。它位于models / research / object_detection中,可以利用下列路径运行:

python train.py –logtostderr –pipeline_config_path= PATH_TO_PIPELINE_CONFIG –train_dir=PATH_TO_TRAIN_DIR

其中PATH_TO_PIPELINE_CONFIG是到管道配置文件的路径,PATH_TO_TRAIN_DIR是一个新创建的目录,我们的新检查点和模型将被存储在该目录中。

train.py的输出应该如下所示:

用最重要的信息来查找损失。这是在训练或验证集中每个示例错误的总和。当然,你希望它尽可能低,这意味着,缓慢下降表示你的模型正在学习(或过度拟合你的训练数据)。你还可以使用Tensorboard来更详细地显示训练数据。

该脚本将在一定数量的步骤后自动存储检查点文件,以便你随时恢复保存的检查点,以防计算机在学习过程中崩溃。

这意味着当你想结束模型的训练时,你可以终止脚本。

但是什么时候停止学习?关于何时停止训练,原则上是当评估集的损失减少或非常低时(在我们的例子中低于0.01)。

测试

现在我们可以通过在一些示例图像上进行测试来实际使用我们的模型。

首先,我们需要使用models/research/object_detection脚本中存储的检查点(位于我们的训练目录中)导出推理图:

python export_inference_graph.py — pipeline_config_path PATH_TO_PIPELINE_CONFIG --trained_checkpoint_prefix PATH_TO_CHECPOINT --output_directory OUTPUT_PATH

我们的Python脚本可以用导出的推理图来查找Wally的位置。

我写了一些简单的Python脚本(基于Tensorflow 目标检测API),你可以在模型上使用它们执行目标检测,并在检测到的目标周围绘制框或将其暴露。

find_wally.py和find_wally_pretty.py都可以在我的Github仓库中找到,可以简单地运行:

  • find_wally.py地址: https://github.com/tadejmagajna/HereIsWally/blob/master/find_wally.py
  • find_wally_pretty.py地址: https://github.com/tadejmagajna/HereIsWally/blob/master/find_wally_pretty.py
  • Github repo 地址: https://github.com/tadejmagajna/HereIsWally
python find_wally.py

或者

python find_wally_pretty.py

在自己的模型或自己的评估图像上使用脚本时,请确保修改model_path和image_path变量。

结语

在我的Github repo 上发布的模型表现非常出色。

模型设法在评估图像中找到Wally,并且对网络上的一些额外的随机例子处理得很好。它未能找到很大的Wally,直观来说,找到小的walley应该更容易解决。这表明我们的模型可能过度适合我们的训练数据,主要是因为训练图像较少。

原文发布于微信公众号 - ATYUN订阅号(atyun_com)

原文发表时间:2017-12-14

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏机器之心

教程 | TensorEditor :一个小白都能快速玩转的神经网络搭建工具

1706
来自专栏深度学习自然语言处理

如何用简单易懂的例子解释隐马尔可夫模型?(进阶篇)

昨天通俗易懂的讲解了什么是HMM,没看的点这里。那么今天就来看看,具体理论是什么以及数学上怎么计算的呢?

651
来自专栏人工智能LeadAI

人脸识别 | 卷积深度置信网络工具箱的使用

本文主要以ORL_64x64人脸数据库识别为例,介绍如何使用基于matlab的CDBN工具箱。至于卷积深度置信网络(CDBN,Convolutional Dee...

2785
来自专栏AI研习社

Github 代码实践:Pytorch 实现的语义分割器

使用Detectron预训练权重输出 *e2e_mask_rcnn-R-101-FPN_2x* 的示例

1232
来自专栏ATYUN订阅号

深度学习图像识别项目(中):Keras和卷积神经网络(CNN)

在下篇文章中,我还会演示如何将训练好的Keras模型,通过几行代码将其部署到智能手机上。

6415
来自专栏木子昭的博客

Tensorflow可视化编程安装Tensoflow1.0将加法运算以图形化方式展示实现简单的线性回归为程序添加作用域模型的保存与恢复(保存会话资源)

安装Tensoflow1.0 Linux/ubuntu: python2.7: pip install https://storage.googleapis.c...

2958
来自专栏AI研习社

YOLO 升级到 v3 版,速度相比 RetinaNet 快 3.8 倍

雷锋网 AI 研习社按,YOLO 是一种非常流行的目标检测算法,速度快且结构简单。日前,YOLO 作者推出 YOLOv3 版,在 Titan X 上训练时,在 ...

1063
来自专栏视觉求索无尽也

【总结经验】炼丹路上的坑与经验

703
来自专栏有趣的Python和你

Python数据分析之方差分析

设某苗圃对一花木种子制定了5种不同的处理方法,每种方法处理了6粒种子进行育苗试验。一年后观察苗高获得资料如下表。已知除处理方法不同外,其他育苗条件相同且苗高的分...

772
来自专栏机器之心

GPU捉襟见肘还想训练大批量模型?谁说不可以

2018 年的大部分时间我都在试图训练神经网络时克服 GPU 极限。无论是在含有 1.5 亿个参数的语言模型(如 OpenAI 的大型生成预训练 Transfo...

912

扫码关注云+社区