【教程】利用Tensorflow目标检测API确定图像中目标的位置

深度学习提供了另一种解决“Wally在哪儿”(美国漫画)问题的方法。与传统的图像处理计算机视觉方法不同的是,它只使用了少量的标记出Wally位置的示例。

在我的Github repo上发布了具有评估图像和检测脚本的最终训练模型。

  • Github repo地址:https://github.com/tadejmagajna/HereIsWally

这篇文章描述了使用Tensorflow目标检测API来训练神经网络的过程,并使用围绕它构建的Python脚本来寻找Wally。它由以下步骤组成:

  • 通过创建一组标记训练图像来准备数据集,其中标签代表图像中Wally的xy位置;
  • 读取和配置模型以使用Tensorflow目标检测API;
  • 在我们的数据集上训练模型
  • 使用导出的图形对评估图像的模型进行测试

开始之前,请确保按照说明安装Tensorflow目标检测API。

准备数据集

神经网络是深度学习的过程中最值得注意的过程,但遗憾的是,科学家们花费大量时间的准备和格式化训练数据。

最简单的机器学习问题的目标值通常是标量(比如数字检测器)或分类字符串。Tensorflow目标检测API训练数据使用两者的结合。它包括一组图像,并附有特定目标的标签和它们在图像中出现的位置。位置用两点(二维空间)定义,两点足够画一个物体周围的包围盒。

因此,为了创建训练集,我们需要提出一组Wally出现地点的图片。

虽然我可以用LabelImg这样的注释工具,花费数周的时间通过手工标记图像来解决问题,但我发现了一个已经解决了Where’s Wally这个问题的训练集。

Wally训练数据集,最后四列描述了Wally出现在图像中的位置

准备数据集的最后一步是将我们的标签(保存为文本文件)和图像(.jpeg)打包成一个二进制.tfrecord文件(该过程的解释代码地址见段末),但可以找到训练和重新运算求出Wally位置的参数内容。 .tfecord文件在我的Github repo上。

  • Github repo地址:https://github.com/tadejmagajna/HereIsWally
  • 解释地址:http://warmspringwinds.github.io/tensorflow/tf-slim/2016/12/21/tfrecords-guide/

准备模型

Tensorflow目标检测API提供了一组经过多次公开数据集训练的具有不同性能(通常为速度 – 精度折衷)的预训练模型。

虽然模型可以从头开始随机初始化网络权值,但这个过程可能需要几周的时间。我们使用一种称为转移学习的方法来替换该过程。

转移学习包含采用通常训练的模型解决一些一般问题并且重新训练模型以解决我们的问题。转移学习的工作原理是,通过使用在预先训练的模型中获得的知识并将其转移到新的模型中,来代替从头开始训练模型这些无用的重复工作。这为我们节省了大量的时间,将花费在训练上的时间用于获得针对我们问题的知识。

我们使用带有经过COCO数据集训练的Inception v2模型的RCNN,以及它的管道配置文件。该模型包含一个检查点.ckpt文件,我们可以使用该文件开始训练。

  • RCNN地址: http://download.tensorflow.org/models/object_detection/faster_rcnn_inception_v2_coco_2017_11_08.tar.gz
  • 管道配置文件地址: https://github.com/tensorflow/models/blob/master/research/object_detection/samples/configs/ssd_inception_v2_coco.config

下载配置文件后,请确保用指向检查点文件、训练以及评估.tfrecord文件与标签映射文件的路径代替“PATH_TO_BE_CONFIGURED”字段。

需要配置的最终文件是labels.txt映射文件,其中包含所有不同目标的标签。由于我们只是在寻找一种类型的目标,我们的标签文件看起来像这样:

item {
  id: 1
  name: 'waldo'
}

最后,我们最终应该:

  • 具有.ckpt检查点文件的预训练模型;
  • 训练和评估.tfrecord数据集;
  • 标记映射文件;
  • 指向以上文件的管道配置文件。

现在,我们准备开始训练。

训练

Tensorflow目标检测API提供了一个简单易用的Python脚本来重新训练我们的模型。它位于models / research / object_detection中,可以利用下列路径运行:

python train.py –logtostderr –pipeline_config_path= PATH_TO_PIPELINE_CONFIG –train_dir=PATH_TO_TRAIN_DIR

其中PATH_TO_PIPELINE_CONFIG是到管道配置文件的路径,PATH_TO_TRAIN_DIR是一个新创建的目录,我们的新检查点和模型将被存储在该目录中。

train.py的输出应该如下所示:

用最重要的信息来查找损失。这是在训练或验证集中每个示例错误的总和。当然,你希望它尽可能低,这意味着,缓慢下降表示你的模型正在学习(或过度拟合你的训练数据)。你还可以使用Tensorboard来更详细地显示训练数据。

该脚本将在一定数量的步骤后自动存储检查点文件,以便你随时恢复保存的检查点,以防计算机在学习过程中崩溃。

这意味着当你想结束模型的训练时,你可以终止脚本。

但是什么时候停止学习?关于何时停止训练,原则上是当评估集的损失减少或非常低时(在我们的例子中低于0.01)。

测试

现在我们可以通过在一些示例图像上进行测试来实际使用我们的模型。

首先,我们需要使用models/research/object_detection脚本中存储的检查点(位于我们的训练目录中)导出推理图:

python export_inference_graph.py — pipeline_config_path PATH_TO_PIPELINE_CONFIG --trained_checkpoint_prefix PATH_TO_CHECPOINT --output_directory OUTPUT_PATH

我们的Python脚本可以用导出的推理图来查找Wally的位置。

我写了一些简单的Python脚本(基于Tensorflow 目标检测API),你可以在模型上使用它们执行目标检测,并在检测到的目标周围绘制框或将其暴露。

find_wally.py和find_wally_pretty.py都可以在我的Github仓库中找到,可以简单地运行:

  • find_wally.py地址: https://github.com/tadejmagajna/HereIsWally/blob/master/find_wally.py
  • find_wally_pretty.py地址: https://github.com/tadejmagajna/HereIsWally/blob/master/find_wally_pretty.py
  • Github repo 地址: https://github.com/tadejmagajna/HereIsWally
python find_wally.py

或者

python find_wally_pretty.py

在自己的模型或自己的评估图像上使用脚本时,请确保修改model_path和image_path变量。

结语

在我的Github repo 上发布的模型表现非常出色。

模型设法在评估图像中找到Wally,并且对网络上的一些额外的随机例子处理得很好。它未能找到很大的Wally,直观来说,找到小的walley应该更容易解决。这表明我们的模型可能过度适合我们的训练数据,主要是因为训练图像较少。

原文发布于微信公众号 - ATYUN订阅号(atyun_com)

原文发表时间:2017-12-14

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏WD学习记录

21个项目玩转深度学习 学习笔记(2)

事实上,必须先读入数据后才能进行计算,假设读入用时0.1s,计算用时0.9秒,那么没过1s,GPU都会有0.1s无事可做,大大降低了运算的效率。

3441
来自专栏专知

【下载】PyTorch 实现的YOLO v2目标检测算法

【导读】目标检测是计算机视觉的重要组成部分,其目的是实现图像中目标的检测。YOLO是基于深度学习方法的端到端实时目标检测系统(YOLO:实时快速目标检测)。YO...

5136
来自专栏人人都是极客

5.训练模型之利用训练的模型识别物体

接下来我们开始训练,这里要做三件事: 将训练数据上传到训练服务器,开始训练。 将训练过程可视化。 导出训练结果导出为可用作推导的模型文件。 配置 Pipelin...

4084
来自专栏杨熹的专栏

TensorFlow-2: 用 CNN 识别数字

---- 本文结构: CNN 建立模型 code ---- 昨天只是用了简单的 softmax 做数字识别,准确率为 92%,这个太低了,今天用 CNN 来提高...

3635
来自专栏素质云笔记

SSD+caffe︱Single Shot MultiBox Detector 目标检测+fine-tuning(二)

承接上一篇SSD介绍:SSD+caffe︱Single Shot MultiBox Detector 目标检测(一) 如果自己要训练SSD模型呢,关键...

1.1K10
来自专栏计算机视觉与深度学习基础

【深度学习】使用tensorflow实现AlexNet

AlexNet是2012年ImageNet比赛的冠军,虽然过去了很长时间,但是作为深度学习中的经典模型,AlexNet不但有助于我们理解其中所使用的很多技巧,...

45810
来自专栏专知

Tensorflow实战系列:手把手教你使用CNN进行图像分类(附完整代码)

【导读】专知小组计划近期推出Tensorflow实战系列,计划教大家手把手实战各项子任务。本教程旨在手把手教大家使用Tensorflow构建卷积神经网络(CNN...

4.4K4
来自专栏机器人网

深度学习三要素:数据、模型、计算

数据来源:主要通过对初始数据图片进行人工标注和机器标注。数据样本非常的重要,好的样本等于成功了一半。

932
来自专栏YoungGy

MMD_5a_Clustering

聚类概述 定义 距离的定义 算法的分类 启发式算法 概述 KEY POINTS 如何代表cluster 如何决定距离远近 没有欧氏距离怎么办 终止条件 总结 K...

2909
来自专栏mathor

“达观杯”文本智能处理挑战赛

 由于提供的数据集较大,一般运行时间再10到15分钟之间,基础电脑配置在4核8G的样子(越消耗内存在6.2G),因此,一般可能会遇到内存溢出的错误

3792

扫码关注云+社区

领取腾讯云代金券