C-SATS工程副总裁教你如何用TensorFlow分类图像 part1

最近在深度学习算法和硬件性能方面的最新进展使研究人员和公司在图像识别,语音识别,推荐引擎和机器翻译等领域取得了巨大的进步。六年前,首次机器在视觉模式识别方面的表现首次超过人类。两年前,Google Brain团队发布了TensorFlow,让深度学习可以应用于大众。TensorFlow超越了许多用于深度学习的复杂工具。

有了TensorFlow,你可以访问具有强大功能的复杂特征。它之所以如此强大,是因为TensorFlow的易用性非常好。

本文由两部分组成,我将介绍如何快速创建用于实际图像识别的卷积神经网络。计算步骤是Embarrassingly parallel和可部署执行逐帧视频分析和temporal-aware视频分析。

这个系列直接讲解最重要的地方。对于命令行和Python的基本理解需要你自己研究。写这篇文章的目的是让大家可以快速入门,并激励大家创建自己的项目。

运行原理

我们将按照以下步骤操作:

1. 标记是管理训练数据的过程。对于花卉,将雏菊的图像拖入“雏菊”目录,将玫瑰拖入“玫瑰”目录等等,以便根据需要选择许多不同的花朵。如果我们不去标记“蕨类植物”,分类器也永远不会返回“蕨类植物”。每个类型都需要大量的例子,所以这是一个重要的但很耗时的过程。为了省时,在这里我们使用预先标记好的数据。

2. 训练是将标记后的数据(图像)输入到模型中。工具将抓取一组随机图像,使用模型来猜测每种花的类型,测试猜测的准确性,并重复此过程,直到大部分训练数据被使用。最后一部分未过使用的图像用于计算训练模型的准确性。

3. 分类是使用模型分类新的图像。例如,输入:IMG207.JPG,输出:雏菊。这是最快,最简单的一步。

训练和分类

在本教程中,我们将训练图像分类器来识别不同类型的花朵。深度学习需要大量的训练数据,所以我们需要大量的分类好的花卉图像。值得庆幸的是,我有现成的,所以我会使用带有很好脚本的分类后的数据集,并使用一个现有的、经过完全训练的图像分类模型,并重新训练模型的最后几层。这种技术被称为迁移学习。

我们正在进行再培训的模型被称为Inception v3,它的介绍论文如下。

  • 介绍论文:https://arxiv.org/abs/1512.00567

从不知道如何从雏菊中分辨出郁金香到训练后可以成功分辨,大约需要20分钟。这就是深度学习的“学习”部分。

安装配置

首先在你选择的平台上安装Docker。

  • https://www.docker.com/community-edition#/download

docker是唯一一个依赖项。在许多TensorFlow教程中也用到了docker(这应该表明这是一个合理的方法)。我也更喜欢这种安装TensorFlow的方法,因为它通过不需要安装一堆依赖项,可以保持主机(笔记本电脑或桌面)的整洁。

安装Docker后,我们准备启动一个TensorFlow容器(container)进行训练和分类。创建一个工作目录在你的硬盘上准备2GB的空闲空间。创建一个名为local的子目录并记录访问这个目录的完整路径。

docker run -v /path/to/local:/notebooks/local --rm -it --name tensorflow 
tensorflow/tensorflow:nightly /bin/bash

以下是这个命令详细解释。

  • -v /path/to/local:/notebooks/local加载你刚刚创建的local目录到容器中合适的位置。如果你使用RHEL,Fedora或其他支持SELinux的系统,附加:Z到允许容器访问目录。(https://www.projectatomic.io/blog/2015/06/using-volumes-with-docker-can-cause-problems-with-selinux/)
  • –rm 告诉Docker在完成后删除容器。
  • -it 附加我们的输入和输出以使容器有交互性。
  • –name tensorflow将我们的容器命名为tensorflow
  • tensorflow/tensorflow:表示从Docker Hub(公共镜像库)的tensorflow/tensorflow中运行nightly而不是最新的镜像(默认是运行最新的)。之所以不用最新的,是因为在撰写本文时最新的包含了破坏TensorBoard的bug。而我们稍后要用TensorBoard进行可视化。
  • /bin/bash表示不运行默认命令;而是运行一个Bash shell。

训练模型

在容器内部,运行这些命令下载并检查训练数据。

curl -O http://download.tensorflow.org/example_images/flower_photos.tgz
echo 'db6b71d5d3afff90302ee17fd1fefc11d57f243f  flower_photos.tgz' | sha1sum -c

如果你没有看到消息flower_photos.tgz: OK,则表示没有正确的文件。如果上述curl或sha1sum步骤失败,请手动下载并分解主机local目录中的训练数据tarball(SHA-1 checksum: db6b71d5d3afff90302ee17fd1fefc11d57f243f)。

现在把训练数据放在适当的地方,然后下载和理智检查再训练脚本。

mv flower_photos.tgz local/
cd local
curl -O https://raw.githubusercontent.com/tensorflow/tensorflow/
10cf65b48e1b2f16eaa82
6d2793cb67207a085d0/tensorflow/examples/image_retraining/retrain.py
echo 'a74361beb4f763dc2d0101cfe87b672ceae6e2f5  retrain.py' | sha1sum -c

查到并确认retrain.py具有正确内容。你会看到retrain.py: OK。

运行再训练脚本。

python retrain.py --image_dir flower_photos --output_graph output_graph.pb 
--output_labels output_labels.txt

如果遇到以下错误,忽略即可。

TypeError: not all arguments converted during string formatting Logged from file
tf_logging.py, line 82

执行retrain.py后,训练图像被自动成的训练、测试和验证数据集。

在输出中,我们希望“训练准确性”和“验证准确性”高一些,“交叉熵”低一些。有关这些术语的详细解释,请访问下方链接。在较好的硬件上的训练需要大约30分钟。

  • 术语:https://www.tensorflow.org/tutorials/image_retraining

看一看你的控制台输出的最后一行:

INFO:tensorflow:Final test accuracy = 89.1% (N=340)

这说明我们的模型十次中有九次能够正确地猜出给定图像中显示的使五种花型中的哪一种。由于训练过程中加入了随机性,你的准确性可能会有所不同。

分类

再加上一个小脚本,我们可以将新的花朵图像添加到模型中,并输出它的猜测。这就是图像分类。

在主机上的local目录中将以下代码保存成classify.py:

import tensorflow as tf, sys
 
image_path = sys.argv[1]
graph_path = 'output_graph.pb'
labels_path = 'output_labels.txt'
 
# Read in the image_data
image_data = tf.gfile.FastGFile(image_path, 'rb').read()
 
# Loads label file, strips off carriage return
label_lines = [line.rstrip() for line
    in tf.gfile.GFile(labels_path)]
 
# Unpersists graph from file
with tf.gfile.FastGFile(graph_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    _ = tf.import_graph_def(graph_def, name='')
 
# Feed the image_data as input to the graph and get first prediction
with tf.Session() as sess:
    softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
    predictions = sess.run(softmax_tensor, 
    {'DecodeJpeg/contents:0': image_data})
    # Sort to show labels of first prediction in order of confidence
    top_k = predictions[0].argsort()[-len(predictions[0]):][::-1]
    for node_id in top_k:
         human_string = label_lines[node_id]
         score = predictions[0][node_id]
         print('%s (score = %.5f)' % (human_string, score))

要测试自己的图像,将其在你的local目录中保存为test.jpg,并在容器中运行python classify.py test.jpg。输出结果如下所示:

sunflowers (score = 0.78311)
daisy (score = 0.20722)
dandelion (score = 0.00605)
tulips (score = 0.00289)
roses (score = 0.00073)

数字表明自信程度。模型有78.311%的确定图像中的花是向日葵。得分越高表示图像越匹配结果。请注意,只显示一个匹配。多标签分类需要不同的方法。

欲了解更多详情,查看此大线,由线解释的classify.py。

分类器脚本中的图形加载代码损坏了,所以我应用了graph_def = tf.GraphDef()等图形加载代码。

我们创造了一个还可以的花朵图像分类器,可以在笔记本电脑上每秒钟处理大约五个图像。

在下一期中,我们将用到这些知识训练不同的图像分类器,并使用TensorBoard观察它。如果你想试试TensorBoard,请保持容器的运行,并确保docker运行没有被终止。

原文发布于微信公众号 - ATYUN订阅号(atyun_com)

原文发表时间:2017-12-23

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏AI研习社

谷歌正式开源 Hinton 胶囊理论代码,即刻用 TensorFlow 实现吧

雷锋网(公众号:雷锋网) AI 研习社消息,相信大家对于「深度学习教父」Geoffery Hinton 在去年年底发表的胶囊网络还记忆犹新,在论文 Dynami...

3116
来自专栏ATYUN订阅号

将Keras权值保存为动画视频,更好地了解模型是如何学习的

将Keras权值矩阵保存为简短的动画视频,从而更好地理解你的神经网络模型是如何学习的。下面是第一个LSTM层的例子,以及一个经过一个学习周期训练的6级RNN模型...

3464
来自专栏量子位

Keras 2正式发布,推出深度整合进TensorFlow的新API

允中 编译整理 量子位·QbitAI 出品 今天,深度学习框架Keras在博客上发表文章,介绍了深度整合进TensorFlow的内部版本tf.keras,以及...

3358
来自专栏量化投资与机器学习

深度学习项目

Github上比较受欢迎的深度学习项目(Top Deep Learning Projects),按照获得星星个数的排名,包括一些教程项目等。 ? ? ? ?

1906
来自专栏ATYUN订阅号

【深度学习】图片风格转换应用程序:使用CoreML创建Prisma

WWDC 2017让我们了解了苹果公司对机器学习的看法以及它在移动设备上的应用。CoreML框架使得将ML模型引入iOS应用程序变得非常容易。 ? 大约一年前,...

4558
来自专栏AI研习社

Github 项目推荐 | ANSI C 的简单神经网络库

Genann是一个经过精心测试的库,用于在 C 中训练和使用前馈人工神经网络(ANN)。它的主要特点是简单、快速、可靠和可魔改(hackable),它只需要提供...

721
来自专栏AI研习社

Github 项目推荐 | 用 JavaScript 实现的神经网络 —— brain.js

不过,一般的开发者应该都不会用神经网络来实现异或的功能吧,所以这里有一个更加实际的例子:训练一个神经网络来识别颜色对比 https://brain.js.org...

1152
来自专栏机器之心

开源 | 深度安卓恶意软件检测系统:用卷积神经网络保护你的手机

选自GitHub 机器之心编译 参与:Panda 恶意软件可以说是我们现代生活的一大威胁,为了保护我们电子设备中的财产和资料安全,我们往往需要寻求安全软件的帮助...

2807
来自专栏AI科技大本营的专栏

重磅消息 | 深度学习框架竞争激烈 TensorFlow也支持动态计算图

今晨 Google 官方发布消息,称 TensorFlow 支持动态计算图。 原文如下: 在大部分的机器学习中,用来训练和分析的数据需要经过一个预处理过程,输入...

2685
来自专栏YoungGy

ML基石_9_LinearRegression

linear regression problem linear regression algorithm 优化问题 求梯度 算法 generalization...

2346

扫码关注云+社区