Keras 图像分类实践:从爬取图片到构建卷积神经网络

网上有很多使用卷积神经网络来对图片分类的教程,但很多都是直接使用开放数据集来简单地调用模型。

但是通常在实际的需求中,需要自己去收集训练图片集,所以本文将从如何从百度图片抓取图片生成自己的图片数据集说起,到使用 Keras 搭建卷积神经网络,再到使用预训练模型提取特征值来改善我们的模型。

文末有完整的代码和数据集链接。

抓取百度图片

我这里使用 Chrome 的一个叫 Header Hacker 的扩展工具:

添加一个手机的 User-Agent:

可以抓到这样的链接:

可以看到翻页其实就是 pn 的变化,其他参数都可以固定。

返回的 json 内容较多,为方便我们直接取 thumburl ,即缩略图的地址下载:

抓取图片的脚本基本就是这样了,大概 40 来行搞定。

这脚本可以从百度图片爬小姐姐美图。

不过这里模仿 Kaggle 一个非常著名的数据集 cats-vs-dogs ( 猫狗大战 ),抓取了饺子和汤圆的图片,来弄个 jiaozi-vs-tangyuan 的数据集。( 额,过年的时候吃饺子和汤圆吃撑了 )

爬完后大概可以分别得到 1500+ 多张图片。

如果你信得过百度搜索质量和深度学习惊人的表征能力的话,可以直接将下载好的图片用于训练,但是有点洁癖的我还是人肉剔除了一些比较离谱的图片。 ( 比如七龙珠里面的饺子... )

搭建卷积神经网络

图片整理

整理成如下结构:

数据分为三个部分:train 是用于训练的数据集,validation 是训练过程中用来对模型进行验证的数据集,test 则是在模型生成后再对模型进行评估的数据集。

定义部分参数

图片增强

图片增强, 对于小数据量训练非常有用。实际上我们这里训练集只有 1400 张图片,而我们上面设置的样本数为 3200。通过 Keras 的图片增强功能实现生成新的训练图片。

https://zhuanlan.zhihu.com/p/30197320 这篇文章非常详细地说明了各种图片增强方式。

Callbacks

设置训练时的 callback 函数, 可以在训练过程中方便处理自定义的操作:

参考 https://keras-cn.readthedocs.io/en/latest/other/callbacks/

搭建卷积神经网络

经典的 conv+relu -> pooling,随便先弄个三层卷积看下效果:

训练了 25 epochs 之后的结果:

可以看到 val_acc 达到了 85% 左右。loss 下降也比较明显。

模型验证

准确率达到了 86.1%:

查看了分类错误的图片,可能对计算机来说稍微有点难度,比如这张:

看起来还是不错的,毕竟我们只是简单地搭个三层卷积网络训练了 25 epochs,就教会了计算机(大致)分辨饺子和汤圆。

有兴趣的话可以自行调参或者修改模型提升正确率。

使用预训练模型

keras.applications 提供了带有预训练权重的 Keras 模型,可以使用预训练模型来抽取图片的特征,然后再接上一个全连接层实现对饺子和汤圆的迁移学习。( 在一个通用的大数据集,比如 ImageNet 上进行一定量的训练后,再用针对性的小数据集进一步强化训练。)

加载预训练权重模型

这里使用 Keras 作者自己提出的 Xception V3 模型。

去掉最后的输出层,因为最后我们需要接上个全连接层。这样 Xception 提取的最后一层现在为 (batch_size, 7, 7, 2048):

第一次执行的时候会从 github 下载预训练权重文件。

特征值提取

训练集和验证集分别提取特征值,然后摊平 ( 其实就是 Flatten() ) 方便接下来的全连接层:

全连接层模型

这部分就相当简单了。

训练结果:

因为只是一个全连接层,训练起来速度比自己搭的卷积网络要快多了。

第一个 epoch 验证集的准确率就可以达到 90% 以上,不过之后有点抖动,最后 25 epochs 最好的 val_acc 达到了 95.75%:

最终的测试集准确率也达到了 97%:

这样,我们只花了不到十分钟就训练好了一个准确率达到 97% 的饺子汤圆分类器,效率可观!

资源地址

Jupyter Notebook 和图片数据集地址:https://github.com/jackhuntcn/notebooks

不定期更新 入门级及不靠谱的 数据抓取、数据分析、深度学习以及其他有趣脚本的原创文章,欢迎长按下面的二维码关注

  • 发表于:
  • 原文链接:http://kuaibao.qq.com/s/20180301G10MUB00?refer=cp_1026
  • 腾讯「云+社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。

扫码关注云+社区

领取腾讯云代金券