专栏首页Web技术布道师如何用TensorFlow和Swift写个App识别霉霉?

如何用TensorFlow和Swift写个App识别霉霉?

在很多歌迷眼里,尤其是喜欢乡村音乐的人,“霉霉”Taylor Swift是一位极具辨识度也绝对不能错过的女歌手。在美国硅谷就有一位非常喜欢 Taylor Swift 的程序媛 Sara Robinson,同时她也是位很厉害的 APP 开发者。喜爱之情难以言表,于是利用机器学习技术开发了一款iOS 应用,可以随时随地识别出 Taylor Swift~~~

不卖关子了,妹子开发的这款应用效果如下:

可以说是“捕捉”Taylor Swift 的神器了。

那么她是怎么做出的?她主要用了谷歌的 TensorFlow Object Detection API,机器学习技术以及 Swift 语言。用 Swift 识别 Taylor Swift,没毛病。下面我们就看看妹子是怎么操作的:

TensorFlow Object Detection API 能让我们识别出照片中物体的位置,所以借助它可以开发出很多好玩又酷炫的应用。之前有不少人用它来识别物体,但我(作者Sara Robinson——译者注)还是对人比较感兴趣,正好手头也有不少人物照片,所以就琢磨着搞个能识别人脸的应用。作为“霉霉”的死忠粉,当然是先做一款识别 Taylor Swift 的应用啦!

下面我会分享从收集“霉霉”照片到制作使用预训练模型识别照片的 iOS 应用的大体步骤:

  • 预处理照片:重新调整照片大小并打上标签,然后切分成训练集和测试集,最后将照片转为 Pascal VOC 格式
  • 将照片转为 TFRecords,输入 TensorFlow Object Detection API
  • 使用 MobileNet 在 CLoud ML Engine 上训练模型
  • 用 Swift 开发一个 iOS 前端,能用预训练模型识别照片

下面是整体的架构示意图:

虽然看着有点麻烦,其实也不是很复杂。

在我详细介绍每个步骤前,有必要解释一些后面会提到的技术名词。

TensorFlow Object Detection API:一款基于 TensorFlow 的框架,用于识别图像中的物体。例如,你可以用很多猫咪照片训练它,训练完后如果你给它展示一张有猫咪的照片,它就会在它认为照片有猫咪的地方标出一个矩形框。

不过,训练识别物体的模型需要花费很长时间和很多数据。幸好 TensorFlow Object Detection 上有 5 个预训练模型,可以很方便的用于迁移学习。什么是迁移学习呢?打个比方,小孩子在刚开始学说话时,父母会让他们学习说很多东西的名字,如果说错了,会纠正他们的错误。比如,小孩第一次学习认识猫咪时,他们会看着爸妈指着猫咪说“猫咪”。这个过程不断重复就会加强他们大脑的学习路径。然后当他们学习怎么认出狗狗时,小孩就不需要再从头学习。他们可以利用和认出猫咪相同的识别过程,但是应用在不同的任务上。迁移学习的工作原理也是如此。

我虽然没时间找几千张标记了 Taylor Swift 名字的照片,然后训练一个模型,但是我可以利用从 TensorFlow Object Detection API 中预训练模型里提取出的特征,这些模型都是用几百万张图像训练而成,我只需调整模型的一些层级,就能用它们完成具体的图像识别任务,比如识别 Taylor Swift。

提示:本项目全部代码地址见文末。

第一步:预处理照片

首先我从谷歌上下载了 200 张 Taylor Swift 的照片,然后将它们分成两个数据集:训练集和测试集。然后给照片添加标签。测试集用于测试模型识别训练中未见过的照片的准确率。为了让训练更省时一些,我写了个脚本重新调整了所有照片的大小,确保全部照片宽度不超过600px。

因为 Object Detection API 会告诉我们物体在照片中的位置,所以不能仅仅把照片和标签作为训练数据输入进去而已。你还需要输入一个边界框,可以识别出物体在照片中的位置,以及和边界框相关的标签(在我们的数据集中,只用到一个标签:tswift,也就是 Taylor Swift)。

为了给我们的照片生成边界框,我用了 Labelling,这是一个 Python 程序,能让你输入标签图像后为每个照片返回一个带边界框和相关标签的 xml 文件(我整个早上都趴在桌子上忙活着用 Labelling 给 Taylor Swift 的照片打标签,搞得从我旁边路过的人都以关爱智障的眼神望着我)。

最后我在每张照片上定义了一个边界框以及标签 tswift,如下所示:

Labelling 生成 xml 文件的过程如下所示:

<annotation>
<folder>Desktop</folder>
<filename>tswift.jpg</filename>
<path>/Desktop/tswift.jpg</path>
<source>
 <database>Unknown</database>
</source>
<size>
 <width>1000</width>
 <height>667</height>
 <depth>3</depth>
</size>
<segmented>0</segmented>
<object>
 <name>tswift</name>
 <pose>Unspecified</pose>
 <truncated>0</truncated>
 <difficult>0</difficult>
 <bndbox>
  <xmin>420</xmin>
  <ymin>80</ymin>
  <xmax>582</xmax>
  <ymax>291</ymax>
 </bndbox>
</object></annotation>

现在我手中的照片有了边界框和标签,但是还需要把它们转成 TensorFlow 接受的格式—— TFRecord,图像的一种二进制表示形式。我根据 GitHub 上的一个代码库(https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/using_your_own_dataset.md)写了一个脚本完成这个工作。我的脚本代码地址:https://github.com/sararob/tswift-detection/blob/master/convert_to_tfrecord.py

如果你想用我的脚本,你需要克隆 tensorflow/models 代码库到本地(https://github.com/tensorflow/models),打包Object Detection API:

# From tensorflow/models/research/
python setup.py sdist
(cd slim && python setup.py sdist)

现在我们可以运行 TFRecord 脚本了。运行如下来自 tensorflow/models/research 目录的命令,输入如下标志(运行两次,一次用于训练数据,一次用于测试数据):

python convert_labels_to_tfrecords.py \--output_path=train.record \
--images_dir=path/to/your/training/images/ \--labels_dir=path/to/training/label/xml/

第二步:在 Cloud ML Engine 上训练 Taylor Swift 识别器

我其实也可以在自己的笔记本上训练模型,但这会很耗时间。我要是中途用电脑干点别的,训练就得被迫停止。所以,用云端最好!我们可以用云端训练我们的模型,几个小时就能搞定。然后我用了 Cloud ML Engine 训练我的模型,觉得比用自己的 GPU 都快。

设置 Cloud ML Engine

在所有照片都转为 TFRecord 格式后,我们就可以将它们上传到云端,开始训练。首先,我在 Google Cloud 终端上创建一个项目,启动 Cloud ML Engine:

然后我创建一个 Cloud Storage bucket,用来为模型打包所有资源。确保为 bucket 选择一个区域(不要选 multi-regional):

我在 bucket 中创建了一个 a/data 子目录,用来放置训练和测试用的 TFRecord 文件:

Object Detection API 也需要一个 pbtxt 文件,会将标签映射为一个整数 ID。因为我只有一个标签, 所以 ID 非常短。

添加 MobileNet 检查点用于迁移学习

我现在不是从头训练模型,所以我进行训练时需要指向我要用到的预训练模型。我选择了 MobileNet 模型,它是转为移动端优化了的一系列小型模型。Mobile 能够迅速进行训练和做出预测。我下载了训练中会用到的检查点(http://download.tensorflow.org/models/mobilenet_v1_1.0_224_2017_06_14.tar.gz)。所谓检查点就是一个二进制文件,包含了训练过程中在具体点时TensorFlow模型的状态。下载和解压检查点后,你会看到它包含3个文件:

训练模型时,这些文件全都要用到,所以我把它们放在 Cloud Storage bucket 中的同一 data/ 目录中。

在进行训练工作前,还需要添加一个镜像文件。Object Detection 脚本需要一种方法来找到我们的模型检查点、标签地图和训练数据。我们会用一个配置文件完成这一步。对于这 5 个预训练模型,TF Object Detection 代码库中都有相应的配置文件示例。我选择了用于 MobileNet 模型的那个(https://github.com/tensorflow/models/blob/master/research/object_detection/samples/configs/ssd_mobilenet_v1_coco.config),并更新了Cloud Storage bucket中有相应路径的全部PATH_TO_BE_CONFIGURED 文件夹。除了将我的模型和Cloud Storage中的数据连在一起外,配置文件还能为我的模型配置几个超参数,比如卷积大小、激活函数和时步等等。

在正式训练前,在 /data Cloud Storage bucket 中应该有以下全部文件:

此外,还需要在 bucket 中创建 train/ 和 eval/ 子目录——在执行训练和验证模型时, TensorFlow 写入模型检查点文件的地方。

现在我准备开始训练了,通过 gcloud 命令行工具就可以。注意,你需要从本地克隆 tensorflow/models/research,从该目录中运行训练脚本。

在训练时,我同时也启动了验证模型的工作,也就是用模型未见过的数据验证它的准确率:

通过导航至 Cloud 终端的 ML Engine 的 Jobs 部分,就可以查看模型的验证是否正在正确进行,并检查具体工作的日志:

第三步:部署模型进行预测

如果想将模型部署在 ML Engine 上,我需要将模型的检查点转换为 ProtoBuf。在我的 train/bucket 中,我可以看到从训练过程的几个点中保存出了检查点文件:

检查点文件的第一行会告诉我们最新的检查点路径——我会从本地在检查点中下载这3个文件。每个检查点应该是.index,.meta和.data文件。将它们保存在本地目录中,我就可以使用Objection Detection的export_inference_graph 脚本将它们转换为一个ProtoBuf。如果想运行如下脚本,你需要定义到达你的MobileNet 配置文件的本地路径,从训练阶段中下载的模型检查点的数量,以及你想将导出的图表写入的目录的名字:

# Run this script from tensorflow/models/research: python object_detection/export_inference_graph.py \    --input_type encoded_image_string_tensor \    --pipeline_config_path ${LOCAL_PATH_TO_MOBILENET_CONFIG} \    --trained_checkpoint_prefix model.ckpt-${CHECKPOINT_NUMBER} \--output_directory ${PATH_TO_YOUR_OUTPUT}.pb

在运行该脚本后,你应该会在你的 .pb 输出目录中看到一个 saved_model/ 目录。上传 save_model.pb 文件(不用管其它的生成文件)到你的 Cloud Storage bucket 中的 /data 目录中。

现在我们准备将模型部署到 ML Engine 上,首先用 gcloud 创建你的模型:

gcloud ml-engine models create tswift_detector

然后通过将模型指向你刚上传到Cloud Storage中的保存的模型ProtoBuf,创建你的模型的第一个版本。

等模型部署后,就可以用ML Engine的在线预测 API 来为一个新图像生成预测。

gcloud ml-engine versions create v1 --model=tswift_detector --origin=gs://${YOUR_GCS_BUCKET}/data  --runtime-version=1.4

第四步:用 firebase 函数和 Swift 创建一个预测客户端

我用 Swift 写了一个 iOS 客户端,会对模型提出预测请求。客户端会将照片上传至 Cloud Storage,它会触发一个用 Node.js 提出预测请求的 Firebase 函数,并将结果预测照片和数据保存至 Cloud Storage 和 Firestore。

首先,在这个 Swift 客户端中我添加了一个按钮,让用户可以访问手机相册。用户选择照片后,会触发程序将照片上传至 Cloud Storage:

let firestore = Firestore.firestore()func imagePickerController(_ picker: UIImagePickerController, didFinishPickingMediaWithInfo info: [String : Any]) {   let imageURL = info[UIImagePickerControllerImageURL] as? URL
   let imageName = imageURL?.lastPathComponent
   let storageRef = storage.reference().child("images").child(imageName!)   storageRef.putFile(from: imageURL!, metadata: nil) { metadata, error in
       if let error = error {
           print(error)
       } else {
           print("Photo uploaded successfully!")
           // TODO: create a listener for the image's prediction data in Firestore
           }
       }
   }
   dismiss(animated: true, completion: nil)}

接着我写了在上传至 Cloud Storage bucket 中用于本项目的文件中触发的 firebase 函数,它会取用照片,以 base64 将其编码,然后发送至 ML Engine 用于预测。完整的函数代码请查看这里(https://github.com/sararob/tswift-detection/blob/master/firebase/functions/index.js)。

function cmlePredict(b64img, callback) {
   return new Promise((resolve, reject) => {
       google.auth.getApplicationDefault(function (err, authClient, projectId) {
           if (err) {
               reject(err);
           }
           if (authClient.createScopedRequired && authClient.createScopedRequired()) {
               authClient = authClient.createScoped([
                   'https://www.googleapis.com/auth/cloud-platform'
               ]);
           }
           var ml = google.ml({
               version: 'v1'
           });
           const params = {
               auth: authClient,
               name: 'projects/sara-cloud-ml/models/tswift_detector',
               resource: {
                   instances: [
                   {
                       "inputs": {
                       "b64": b64img
                       }
                   }
                   ]
               }
           };           ml.projects.predict(params, (err, result) => {
               if (err) {
                   reject(err);
               } else {
                   resolve(result);
               }
           });
       });
   });
}

在 ML Engine 的回应这里,我们得到:

  • detection_boxes 如果模型识别出照片中有 Taylor Swift,我们用它来定义围绕 Taylor Swift的边界框
  • detection_scores 返回每个边界框的置信值。我只选用置信值分数高出 70% 的检测。
  • detection_classes 会告诉我们检测结果相关的标签 ID。在我们的这里例子中会一直只有一个 ID,因为只有一个标签。

在函数中,我用 detection_boxes 在照片上画出边界框以及置信度分数(如果检测到照片上有 Taylor Swift)。然后我将添加了边框的新照片保存至 Cloud Storage,并写出照片到 Cloud Firestore 的文件路径,这样我就能读取路径,在 iOS 应用中下载新照片(带有识别框):

const admin = require('firebase-admin');
admin.initializeApp(functions.config().firebase);
const db = admin.firestore();let outlinedImgPath = `outlined_img/${filePath.slice(7)}`;
let imageRef = db.collection('predicted_images').doc(filePath);imageRef.set({
   image_path: outlinedImgPath,
   confidence: confidence
});bucket.upload('/tmp/path/to/new/image', {destination: outlinedImgPath});

最后,在 iOS 应用中我可以获取照片更新后的 Firestore 路径。如果发现有检测结果,就将照片下载,然后会把照片和检测置信分数展示在应用上。该函数会取代上面第一个 Swift 脚本中的注释:

self.firestore.collection("predicted_images").document(imageName!)
   .addSnapshotListener { documentSnapshot, error in
       if let error = error {
           print("error occurred\(error)")
       } else {
           if (documentSnapshot?.exists)! {
               let imageData = (documentSnapshot?.data())
               self.visualizePrediction(imgData: imageData)
           } else {
               print("waiting for prediction data...")
           }
       }
}

终于!我们得到了一个能识别 Taylor Swift 的 iOS 应用!

当然,由于只用了 140 张照片训练模型,因此识别准确率不是很高,有时会出错。但是后面有时间的时候,我会用更多照片重新训练模型,提高识别正确率,在 App Store 上架这个应用。

结语

这篇文章信息量还是蛮大的,也想自己做一个这样的 APP,比如能识别抖森或者别的谁?下面就为你总结一下几个重要步骤:

  • 预处理数据:收集目标的照片,用 Labelling 为照片添加标签,并生成带边界框的 xml 文件。然后用脚本将标记后的图像转为 TFRecord 格式。
  • 训练和评估一个 Object Detection 模型:将训练数据和测试数据上传至 Cloud Storage,用Cloud ML Engine 进行训练和评估。
  • 将模型部署到 ML Engine:用 gcloud CLI 将模型部署到 ML Engine。
  • 发出预测请求:用 Firebase 函数向 ML Engine 模型在线发起预测请求。从 APP 到 Firebase Storage 的上传会触发 Firebase 函数。

本项目代码地址:

https://github.com/sararob/tswift-detection

本文分享自微信公众号 - PHP技术大全(phpgod)

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2018-05-21

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • Hybrid App 应用开发中 9 个必备知识点复习

    我们大前端团队内部 ?每周一练 的知识复习计划继续加油,本篇文章是 《Hybrid APP 混合应用专题》 主题的第二期和第三期的合集。

    ConardLi
  • 【总结】ios端被忽略的文件容错测试

    iphone沙盒模型的有四个文件夹:分别是 documents,Library,tmp,app包。手动保存的文件在documents文件里,NSUserdef...

    用户5521279
  • 分享集锦:设计模式讲解、Node.js 教程、Swift UI、Java 开发

    最近这段时间比较忙,产出内容频率低了一些,等这周忙完后,后面会抽空写几篇 GitHub 专题文章,敬请期待。

    GitHubDaily
  • UILabel显示定时器文本的跳动问题解决方案

    上面的gif图会发现在显示验证码计数时出现跳动和闪烁的问题。目前大多数用来实现定时器显示的控件都是UILabel。

    欧阳大哥2013
  • Flutter + MVP +Kotlin 实战!

    Kotlin,由 JetBrains 于 2011.07 推出,一款面向 JVM 在 Java 虚拟机上运行的静态类型编程语言。

    CCCruch
  • 《挑战30天C++入门极限》C++的iostream标准库介绍(1)

      我们所熟悉的输入输出操作分别是由istream(输入流)和ostream(输出流)这两个类提供的,为了允许双向的输入/输出,由istream和ostre...

    landv
  • iOS开发:苹果2018最新款手机(iPhone XS Max、iPhone XR等)如何查看并获取手机的UDID

    随着苹果设备的不断更新,作为苹果开发者来说既是好事也是坏事,好事是因为苹果设备的更新换代,淘汰了一部分旧设备更新了新设备,坏事就是要不断学习应对新的设备带来的...

    三掌柜
  • 纯代码实现matlabのGUI界面搭建

    图形用户界面 (Graphical User Interface,简称 GUI),是有别于纯代码执行,GUI能够繁琐的代码浓缩到一块简洁的界面上,用户只需要输输...

    艾木樨
  • 浅谈几种设计模式

    策略模式、模板方法模式、观察者模式、迭代子模式、责任链模式、命令模式、备忘录模式、状态模式、访问者模式、中介者模式、解释器模式。

    用户4143945
  • 腾讯社招iOS面试记录

    毕业好几年了,上周发送了简历给腾讯,参加了腾讯面试。具体部门这边就不说了。这次面试还是收获到了很多。

    iOSSir

扫码关注云+社区

领取腾讯云代金券