C-SATS工程副总裁教你如何用TensorFlow分类图像 part1

最近在深度学习算法和硬件性能方面的最新进展使研究人员和公司在图像识别,语音识别,推荐引擎和机器翻译等领域取得了巨大的进步。六年前,首次机器在视觉模式识别方面的表现首次超过人类。两年前,Google Brain团队发布了TensorFlow,让深度学习可以应用于大众。TensorFlow超越了许多用于深度学习的复杂工具。

有了TensorFlow,你可以访问具有强大功能的复杂特征。它之所以如此强大,是因为TensorFlow的易用性非常好。

本文由两部分组成,我将介绍如何快速创建用于实际图像识别的卷积神经网络。计算步骤是Embarrassingly parallel和可部署执行逐帧视频分析和temporal-aware视频分析。

这个系列直接讲解最重要的地方。对于命令行和Python的基本理解需要你自己研究。写这篇文章的目的是让大家可以快速入门,并激励大家创建自己的项目。

运行原理

我们将按照以下步骤操作:

1. 标记是管理训练数据的过程。对于花卉,将雏菊的图像拖入“雏菊”目录,将玫瑰拖入“玫瑰”目录等等,以便根据需要选择许多不同的花朵。如果我们不去标记“蕨类植物”,分类器也永远不会返回“蕨类植物”。每个类型都需要大量的例子,所以这是一个重要的但很耗时的过程。为了省时,在这里我们使用预先标记好的数据。

2. 训练是将标记后的数据(图像)输入到模型中。工具将抓取一组随机图像,使用模型来猜测每种花的类型,测试猜测的准确性,并重复此过程,直到大部分训练数据被使用。最后一部分未过使用的图像用于计算训练模型的准确性。

3. 分类是使用模型分类新的图像。例如,输入:IMG207.JPG,输出:雏菊。这是最快,最简单的一步。

训练和分类

在本教程中,我们将训练图像分类器来识别不同类型的花朵。深度学习需要大量的训练数据,所以我们需要大量的分类好的花卉图像。值得庆幸的是,我有现成的,所以我会使用带有很好脚本的分类后的数据集,并使用一个现有的、经过完全训练的图像分类模型,并重新训练模型的最后几层。这种技术被称为迁移学习。

我们正在进行再培训的模型被称为Inception v3,它的介绍论文如下。

  • 介绍论文:https://arxiv.org/abs/1512.00567

从不知道如何从雏菊中分辨出郁金香到训练后可以成功分辨,大约需要20分钟。这就是深度学习的“学习”部分。

安装配置

首先在你选择的平台上安装Docker。

  • https://www.docker.com/community-edition#/download

docker是唯一一个依赖项。在许多TensorFlow教程中也用到了docker(这应该表明这是一个合理的方法)。我也更喜欢这种安装TensorFlow的方法,因为它通过不需要安装一堆依赖项,可以保持主机(笔记本电脑或桌面)的整洁。

安装Docker后,我们准备启动一个TensorFlow容器(container)进行训练和分类。创建一个工作目录在你的硬盘上准备2GB的空闲空间。创建一个名为local的子目录并记录访问这个目录的完整路径。

docker run -v /path/to/local:/notebooks/local --rm -it --name tensorflow 
tensorflow/tensorflow:nightly /bin/bash

以下是这个命令详细解释。

  • -v /path/to/local:/notebooks/local加载你刚刚创建的local目录到容器中合适的位置。如果你使用RHEL,Fedora或其他支持SELinux的系统,附加:Z到允许容器访问目录。(https://www.projectatomic.io/blog/2015/06/using-volumes-with-docker-can-cause-problems-with-selinux/)
  • –rm 告诉Docker在完成后删除容器。
  • -it 附加我们的输入和输出以使容器有交互性。
  • –name tensorflow将我们的容器命名为tensorflow
  • tensorflow/tensorflow:表示从Docker Hub(公共镜像库)的tensorflow/tensorflow中运行nightly而不是最新的镜像(默认是运行最新的)。之所以不用最新的,是因为在撰写本文时最新的包含了破坏TensorBoard的bug。而我们稍后要用TensorBoard进行可视化。
  • /bin/bash表示不运行默认命令;而是运行一个Bash shell。

训练模型

在容器内部,运行这些命令下载并检查训练数据。

curl -O http://download.tensorflow.org/example_images/flower_photos.tgz
echo 'db6b71d5d3afff90302ee17fd1fefc11d57f243f  flower_photos.tgz' | sha1sum -c

如果你没有看到消息flower_photos.tgz: OK,则表示没有正确的文件。如果上述curl或sha1sum步骤失败,请手动下载并分解主机local目录中的训练数据tarball(SHA-1 checksum: db6b71d5d3afff90302ee17fd1fefc11d57f243f)。

现在把训练数据放在适当的地方,然后下载和理智检查再训练脚本。

mv flower_photos.tgz local/
cd local
curl -O https://raw.githubusercontent.com/tensorflow/tensorflow/
10cf65b48e1b2f16eaa82
6d2793cb67207a085d0/tensorflow/examples/image_retraining/retrain.py
echo 'a74361beb4f763dc2d0101cfe87b672ceae6e2f5  retrain.py' | sha1sum -c

查到并确认retrain.py具有正确内容。你会看到retrain.py: OK。

运行再训练脚本。

python retrain.py --image_dir flower_photos --output_graph output_graph.pb 
--output_labels output_labels.txt

如果遇到以下错误,忽略即可。

TypeError: not all arguments converted during string formatting Logged from file
tf_logging.py, line 82

执行retrain.py后,训练图像被自动成的训练、测试和验证数据集。

在输出中,我们希望“训练准确性”和“验证准确性”高一些,“交叉熵”低一些。有关这些术语的详细解释,请访问下方链接。在较好的硬件上的训练需要大约30分钟。

  • 术语:https://www.tensorflow.org/tutorials/image_retraining

看一看你的控制台输出的最后一行:

INFO:tensorflow:Final test accuracy = 89.1% (N=340)

这说明我们的模型十次中有九次能够正确地猜出给定图像中显示的使五种花型中的哪一种。由于训练过程中加入了随机性,你的准确性可能会有所不同。

分类

再加上一个小脚本,我们可以将新的花朵图像添加到模型中,并输出它的猜测。这就是图像分类。

在主机上的local目录中将以下代码保存成classify.py:

import tensorflow as tf, sys
 
image_path = sys.argv[1]
graph_path = 'output_graph.pb'
labels_path = 'output_labels.txt'
 
# Read in the image_data
image_data = tf.gfile.FastGFile(image_path, 'rb').read()
 
# Loads label file, strips off carriage return
label_lines = [line.rstrip() for line
    in tf.gfile.GFile(labels_path)]
 
# Unpersists graph from file
with tf.gfile.FastGFile(graph_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    _ = tf.import_graph_def(graph_def, name='')
 
# Feed the image_data as input to the graph and get first prediction
with tf.Session() as sess:
    softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
    predictions = sess.run(softmax_tensor, 
    {'DecodeJpeg/contents:0': image_data})
    # Sort to show labels of first prediction in order of confidence
    top_k = predictions[0].argsort()[-len(predictions[0]):][::-1]
    for node_id in top_k:
         human_string = label_lines[node_id]
         score = predictions[0][node_id]
         print('%s (score = %.5f)' % (human_string, score))

要测试自己的图像,将其在你的local目录中保存为test.jpg,并在容器中运行python classify.py test.jpg。输出结果如下所示:

sunflowers (score = 0.78311)
daisy (score = 0.20722)
dandelion (score = 0.00605)
tulips (score = 0.00289)
roses (score = 0.00073)

数字表明自信程度。模型有78.311%的确定图像中的花是向日葵。得分越高表示图像越匹配结果。请注意,只显示一个匹配。多标签分类需要不同的方法。

欲了解更多详情,查看此大线,由线解释的classify.py。

分类器脚本中的图形加载代码损坏了,所以我应用了graph_def = tf.GraphDef()等图形加载代码。

我们创造了一个还可以的花朵图像分类器,可以在笔记本电脑上每秒钟处理大约五个图像。

在下一期中,我们将用到这些知识训练不同的图像分类器,并使用TensorBoard观察它。如果你想试试TensorBoard,请保持容器的运行,并确保docker运行没有被终止。

原文发布于微信公众号 - ATYUN订阅号(atyun_com)

原文发表时间:2017-12-23

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏机器之心

深度 | 详解苹果Core ML:如何为iOS创建机器学习应用?

选自developer.apple 机器之心编译 参与:吴攀 在昨天开幕的 WWDC 2017 开发者大会上,苹果宣布了一系列新的面向开发者的机器学习 API...

3357
来自专栏ATYUN订阅号

Machine Box创始人教你快速建立一个ML图像分类器

AiTechYun 编辑:Yining Machine Box的创始人Mat Ryer在medium上分享了一篇博文,意在教你在硬盘上快速的建立一个机器学习图像...

3326
来自专栏人工智能LeadAI

谷歌开放的TensorFlow Object Detection API 效果如何?对业界有什么影响?

熟悉TensorFlow的人都知道,tf在Github上的主页是: https://github.com/tensorflow , 然后这个主页下又有两个比较...

3578
来自专栏机器学习人工学weekly

机器学习人工学weekly-2018/9/23

Rosetta: Understanding text in images and videos with machine learning

685
来自专栏YoungGy

MMD_6b_DecisionTree

overview ? construct 构建决策树的时候需要考虑以下问题: 什么时候停止 如果不停止,那么以什么变量的什么特征构建二叉树 如果停止,那么预测的...

1787
来自专栏素质云笔记

caffe︱ImageData层、DummyData层作为原始数据导入的应用

Part1:caffe的ImageData层 ImageData是一个图像输入层,该层的好处是,直接输入原始图像信息就可以导入分析。 在案例中利用Image...

33710
来自专栏SnailTyan

动手学深度学习——第一课笔记(上)

第一课:从上手到多类分类 课程首先介绍了深度学习的很多应用:例如增强学习、物体识别、语音识别、机器翻译、推荐系统、广告点击预测等。 课程目的:通过动手实现来理解...

2140
来自专栏生信宝典

2018 升级版Jaspar数据库

R包ggseqlogo 绘制seq logo图和Seq logo 在线绘制工具—Weblogo介绍了如何用R脚本和在线工具绘制seq logo图,用于展现转录因...

1122
来自专栏小詹同学

人脸检测——笑脸检测

前边已经详细介绍过人脸检测,其实检测类都可以归属于同一类,毕竟换汤不换药!无论是人脸检测还是笑脸检测,又或者是opencv3以后版本加入的猫脸检测...

4757
来自专栏ATYUN订阅号

Github项目推荐:新型深度网络体系结构去除图像中的雨水痕迹

雨水痕迹会严重降低图像能见度,导致许多当前的计算机视觉算法无法工作。因此去除图像中的雨水是有必要的。

702

扫码关注云+社区