如何用TensorFlow和Swift写个App识别霉霉?

访问flyai.club,一键创建你的人工智能项目

github | https://github.com/sararob/tswift-detection

在很多歌迷眼里,尤其是喜欢乡村音乐的人,“霉霉”Taylor Swift是一位极具辨识度也绝对不能错过的女歌手。在美国硅谷就有一位非常喜欢 Taylor Swift 的程序媛 Sara Robinson,同时她也是位很厉害的 APP 开发者。喜爱之情难以言表,于是利用机器学习技术开发了一款iOS 应用,可以随时随地识别出 Taylor Swift~~~

不卖关子了,妹子开发的这款应用效果如下:

可以说是“捕捉”Taylor Swift 的神器了。

那么她是怎么做出的?她主要用了谷歌的 TensorFlow Object Detection API,机器学习技术以及 Swift 语言。用 Swift 识别 Taylor Swift,没毛病。下面我们就看看妹子是怎么操作的:

TensorFlow Object Detection API 能让我们识别出照片中物体的位置,所以借助它可以开发出很多好玩又酷炫的应用。之前有不少人用它来识别物体,但我(作者Sara Robinson——译者注)还是对人比较感兴趣,正好手头也有不少人物照片,所以就琢磨着搞个能识别人脸的应用。作为“霉霉”的死忠粉,当然是先做一款识别 Taylor Swift 的应用啦!

下面我会分享从收集“霉霉”照片到制作使用预训练模型识别照片的 iOS 应用的大体步骤:

预处理照片:重新调整照片大小并打上标签,然后切分成训练集和测试集,最后将照片转为 Pascal VOC 格式

将照片转为 TFRecords,输入 TensorFlow Object Detection API

使用 MobileNet 在 CLoud ML Engine 上训练模型

用 Swift 开发一个 iOS 前端,能用预训练模型识别照片

下面是整体的架构示意图:

虽然看着有点麻烦,其实也不是很复杂。

在我详细介绍每个步骤前,有必要解释一些后面会提到的技术名词。

TensorFlow Object Detection API:一款基于 TensorFlow 的框架,用于识别图像中的物体。例如,你可以用很多猫咪照片训练它,训练完后如果你给它展示一张有猫咪的照片,它就会在它认为照片有猫咪的地方标出一个矩形框。

不过,训练识别物体的模型需要花费很长时间和很多数据。幸好 TensorFlow Object Detection 上有 5 个预训练模型,可以很方便的用于迁移学习。什么是迁移学习呢?打个比方,小孩子在刚开始学说话时,父母会让他们学习说很多东西的名字,如果说错了,会纠正他们的错误。比如,小孩第一次学习认识猫咪时,他们会看着爸妈指着猫咪说“猫咪”。这个过程不断重复就会加强他们大脑的学习路径。然后当他们学习怎么认出狗狗时,小孩就不需要再从头学习。他们可以利用和认出猫咪相同的识别过程,但是应用在不同的任务上。迁移学习的工作原理也是如此。

我虽然没时间找几千张标记了 Taylor Swift 名字的照片,然后训练一个模型,但是我可以利用从 TensorFlow Object Detection API 中预训练模型里提取出的特征,这些模型都是用几百万张图像训练而成,我只需调整模型的一些层级,就能用它们完成具体的图像识别任务,比如识别 Taylor Swift。

第一步:预处理照片

首先我从谷歌上下载了 200 张 Taylor Swift 的照片,然后将它们分成两个数据集:训练集和测试集。然后给照片添加标签。测试集用于测试模型识别训练中未见过的照片的准确率。为了让训练更省时一些,我写了个脚本重新调整了所有照片的大小,确保全部照片宽度不超过600px。

因为 Object Detection API 会告诉我们物体在照片中的位置,所以不能仅仅把照片和标签作为训练数据输入进去而已。你还需要输入一个边界框,可以识别出物体在照片中的位置,以及和边界框相关的标签(在我们的数据集中,只用到一个标签:tswift,也就是 Taylor Swift)。

为了给我们的照片生成边界框,我用了 Labelling,这是一个 Python 程序,能让你输入标签图像后为每个照片返回一个带边界框和相关标签的 xml 文件(我整个早上都趴在桌子上忙活着用 Labelling 给 Taylor Swift 的照片打标签,搞得从我旁边路过的人都以关爱智障的眼神望着我)。

最后我在每张照片上定义了一个边界框以及标签 tswift,如下所示:

Labelling 生成 xml 文件的过程如下所示:

现在我手中的照片有了边界框和标签,但是还需要把它们转成 TensorFlow 接受的格式—— TFRecord,图像的一种二进制表示形式。我根据 GitHub 上的一个代码库(https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/using_your_own_dataset.md)写了一个脚本完成这个工作。我的脚本代码地址:https://github.com/sararob/tswift-detection/blob/master/convert_to_tfrecord.py

如果你想用我的脚本,你需要克隆 tensorflow/models 代码库到本地(https://github.com/tensorflow/models),打包Object Detection API:

现在我们可以运行 TFRecord 脚本了。运行如下来自 tensorflow/models/research 目录的命令,输入如下标志(运行两次,一次用于训练数据,一次用于测试数据):

第二步:在 Cloud ML Engine 上训练 Taylor Swift 识别器

我其实也可以在自己的笔记本上训练模型,但这会很耗时间。我要是中途用电脑干点别的,训练就得被迫停止。所以,用云端最好!我们可以用云端训练我们的模型,几个小时就能搞定。然后我用了 Cloud ML Engine 训练我的模型,觉得比用自己的 GPU 都快。

设置 Cloud ML Engine

在所有照片都转为 TFRecord 格式后,我们就可以将它们上传到云端,开始训练。首先,我在 Google Cloud 终端上创建一个项目,启动 Cloud ML Engine:

然后我创建一个 Cloud Storage bucket,用来为模型打包所有资源。确保为 bucket 选择一个区域(不要选 multi-regional):

我在 bucket 中创建了一个 a/data 子目录,用来放置训练和测试用的 TFRecord 文件:

Object Detection API 也需要一个 pbtxt 文件,会将标签映射为一个整数 ID。因为我只有一个标签, 所以 ID 非常短。

添加 MobileNet 检查点用于迁移学习

我现在不是从头训练模型,所以我进行训练时需要指向我要用到的预训练模型。我选择了 MobileNet 模型,它是转为移动端优化了的一系列小型模型。Mobile 能够迅速进行训练和做出预测。我下载了训练中会用到的检查点(http://download.tensorflow.org/models/mobilenet_v1_1.0_224_2017_06_14.tar.gz)。所谓检查点就是一个二进制文件,包含了训练过程中在具体点时TensorFlow模型的状态。下载和解压检查点后,你会看到它包含3个文件:

训练模型时,这些文件全都要用到,所以我把它们放在 Cloud Storage bucket 中的同一 data/ 目录中。

在进行训练工作前,还需要添加一个镜像文件。Object Detection 脚本需要一种方法来找到我们的模型检查点、标签地图和训练数据。我们会用一个配置文件完成这一步。对于这 5 个预训练模型,TF Object Detection 代码库中都有相应的配置文件示例。我选择了用于 MobileNet 模型的那个(https://github.com/tensorflow/models/blob/master/research/object_detection/samples/configs/ssd_mobilenet_v1_coco.config),并更新了Cloud Storage bucket中有相应路径的全部PATH_TO_BE_CONFIGURED 文件夹。除了将我的模型和Cloud Storage中的数据连在一起外,配置文件还能为我的模型配置几个超参数,比如卷积大小、激活函数和时步等等。

在正式训练前,在 /data Cloud Storage bucket 中应该有以下全部文件:

此外,还需要在 bucket 中创建 train/ 和 eval/ 子目录——在执行训练和验证模型时, TensorFlow 写入模型检查点文件的地方。

现在我准备开始训练了,通过 gcloud 命令行工具就可以。注意,你需要从本地克隆 tensorflow/models/research,从该目录中运行训练脚本。

在训练时,我同时也启动了验证模型的工作,也就是用模型未见过的数据验证它的准确率:

通过导航至 Cloud 终端的 ML Engine 的 Jobs 部分,就可以查看模型的验证是否正在正确进行,并检查具体工作的日志:

第三步:部署模型进行预测

如果想将模型部署在 ML Engine 上,我需要将模型的检查点转换为 ProtoBuf。在我的 train/bucket 中,我可以看到从训练过程的几个点中保存出了检查点文件:

检查点文件的第一行会告诉我们最新的检查点路径——我会从本地在检查点中下载这3个文件。每个检查点应该是.index,.meta和.data文件。将它们保存在本地目录中,我就可以使用Objection Detection的export_inference_graph 脚本将它们转换为一个ProtoBuf。如果想运行如下脚本,你需要定义到达你的MobileNet 配置文件的本地路径,从训练阶段中下载的模型检查点的数量,以及你想将导出的图表写入的目录的名字:

在运行该脚本后,你应该会在你的 .pb 输出目录中看到一个 saved_model/ 目录。上传 save_model.pb 文件(不用管其它的生成文件)到你的 Cloud Storage bucket 中的 /data 目录中。

现在我们准备将模型部署到 ML Engine 上,首先用 gcloud 创建你的模型:

然后通过将模型指向你刚上传到Cloud Storage中的保存的模型ProtoBuf,创建你的模型的第一个版本。

等模型部署后,就可以用ML Engine的在线预测 API 来为一个新图像生成预测。

第四步:用 firebase 函数和 Swift 创建一个预测客户端

我用 Swift 写了一个 iOS 客户端,会对模型提出预测请求。客户端会将照片上传至 Cloud Storage,它会触发一个用 Node.js 提出预测请求的 Firebase 函数,并将结果预测照片和数据保存至 Cloud Storage 和 Firestore。

首先,在这个 Swift 客户端中我添加了一个按钮,让用户可以访问手机相册。用户选择照片后,会触发程序将照片上传至 Cloud Storage:

接着我写了在上传至 Cloud Storage bucket 中用于本项目的文件中触发的 firebase 函数,它会取用照片,以 base64 将其编码,然后发送至 ML Engine 用于预测。完整的函数代码请查看这里(https://github.com/sararob/tswift-detection/blob/master/firebase/functions/index.js)。

在 ML Engine 的回应这里,我们得到:

detection_boxes 如果模型识别出照片中有 Taylor Swift,我们用它来定义围绕 Taylor Swift的边界框

detection_scores 返回每个边界框的置信值。我只选用置信值分数高出 70% 的检测。

detection_classes 会告诉我们检测结果相关的标签 ID。在我们的这里例子中会一直只有一个 ID,因为只有一个标签。

在函数中,我用 detection_boxes 在照片上画出边界框以及置信度分数(如果检测到照片上有 Taylor Swift)。然后我将添加了边框的新照片保存至 Cloud Storage,并写出照片到 Cloud Firestore 的文件路径,这样我就能读取路径,在 iOS 应用中下载新照片(带有识别框):

最后,在 iOS 应用中我可以获取照片更新后的 Firestore 路径。如果发现有检测结果,就将照片下载,然后会把照片和检测置信分数展示在应用上。该函数会取代上面第一个 Swift 脚本中的注释:

终于!我们得到了一个能识别 Taylor Swift 的 iOS 应用!

当然,由于只用了 140 张照片训练模型,因此识别准确率不是很高,有时会出错。但是后面有时间的时候,我会用更多照片重新训练模型,提高识别正确率,在 App Store 上架这个应用。

结语

这篇文章信息量还是蛮大的,也想自己做一个这样的 APP,比如能识别抖森或者别的谁?下面就为你总结一下几个重要步骤:

预处理数据:收集目标的照片,用 Labelling 为照片添加标签,并生成带边界框的 xml 文件。然后用脚本将标记后的图像转为 TFRecord 格式。

训练和评估一个 Object Detection 模型:将训练数据和测试数据上传至 Cloud Storage,用Cloud ML Engine 进行训练和评估。

将模型部署到 ML Engine:用 gcloud CLI 将模型部署到 ML Engine。

发出预测请求:用 Firebase 函数向 ML Engine 模型在线发起预测请求。从 APP 到 Firebase Storage 的上传会触发 Firebase 函数。

— End —

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20180522A08G3R00?refer=cp_1026
  • 腾讯「云+社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 yunjia_community@tencent.com 删除。

扫码关注云+社区

领取腾讯云代金券