【机器学习实践】水果图片分类

作者参加了gitchat图文课 机器学习极简入门课(扫描下面二维码可见)的学习。

学习理论之外,自己寻找资源动手实践,在实际做项目中巩固了习得的理论知识,并进一步体会到了日常积累的重要性。

下文是作者的一次实践总结。

实践项目

功能:

水果图片分类——构建一个分类模型,对水果图片进行分类。

数据来源:

https://www.kaggle.com/moltean/fruits

总共55244张100x100的图片,和81个分类(水果种类)

按照4:1的比例分为训练集和测试集

解决方案:

构建一个基于深度学习的多分类模型。

本次实践参考kaggle参赛者Esther S. Weon的Identifying Fruits in Images:

https://www.kaggle.com/esthersweon/identifying-fruits-in-images

搭建运行环境

作者的运行环境:

计算机内存:16G

操作系统:Windows10

显卡(GPU):NVIDIA GeForce GTX 1080

编程语言:Python3.5rc1

IDE:PyCharm

深度学习框架:Tensorflow 1.3.0-rc0,Keras 2.0.8

其他工具包:vs2015,cuda8,cudnn6

搭建过程提示:

1. 针对自己的系统安装好cuda与cudnn (搭建方法请自行百度)

2. 通过到python官网下载并安装好python3.5(我下载的是python-3.5.4rc1-amd64.exe)

3. 在安装GPU版本的tensorflow以及keras深度学习框架。

4. 在运行程序的过程中如果遇到缺少依赖包(dependency),可以通过命令行运行

pip(pip3) install ${缺少的安装包名}

来完成安装。

参考资料:

1. win10下基于python(anaconda)安装gpu版本的TensorFlow以及kears深度学习框架

https://blog.csdn.net/colourful_sky/article/details/78524382

2. win10 x64 系统 python2 和python3 共存

https://blog.csdn.net/yuanyuan95/article/details/64920497

3. win10下安装tendorflow注意事项

https://blog.csdn.net/bianjun1075/article/details/60478487

4. windows下安装tensorflowGPU版本报错:

OSError:[WinError 126] 找不到指定的模块/Could not find 'cudart64_90.dll'.

https://blog.csdn.net/wobeatit/article/details/79207196

模型的训练、测试和应用

训练模型

下载数据集Fruits360

https://www.kaggle.com/moltean/fruits/downloads/fruits.zip/30

然后运行模型训练代码

https://github.com/jianghongsun/Fruit-classification/blob/master/train.py

训练中可能出现的问题:在模仿作者进行多分类时,可能在加载图片代码运行到

X_train, y_train, train_labels = get_data(train_dir) X_test, y_test, test_labels = get_data(test_dir)

时,可能会报内存不足。

解决办法为:

1. 修改下面这句中的35为10(或者更小):

for idx, folder_name in enumerate( os.listdir(folder_path)[:35] ):

2. 开设虚拟内存。具体方法参见:

https://jingyan.baidu.com/article/6f2f55a1b834b6b5b83e6c48.html

3. 购买大的内存条。

测试模型

运行模型测试代码,用原装测试集测试模型:

https://github.com/jianghongsun/Fruit-classification/blob/master/test.py

下面这段代码用来适应作者的GPU配置:

#设置GPU的使用 os.environ["CUDA_VISIBLE_DEVICES"] = "0" config = tf.ConfigProto() config.gpu_options.per_process_gpu_memory_fraction = 0.9 session = tf.Session(config=config) KTF.set_session(session)

运行结果确实和Esther S. Weon公布的相符,在原装测试集上,准确率达到98%。

应用模型

从互联网上下载一些水果图片,用模型来对其进行判断(inference)。

下面这9张图片是作者从互联网上下载的,resize到100x100后进行判断,结果列在图片下:

上述的识别结果可以看出不是十分理想,需要继续优化。

增大训练数据集

如果想得到一个通用性更好的模型进行分类,需要加大数据集。增加数据集的方法有如下:

1.旋转图像的角度;

2.亮度调节;

3.色度调节;

4.对比增强;

5.锐度变化;

6. 使用生成对抗网络(Generative Adversarial Nets)根据现有样本图片生成新图片。

收获和感想

通过这次动手实践,首先收获的是自信:只要自己动手,还是能倒腾出点东西的!

当然,距离项目的实际应用还需要继续努力。

就本次的实践而言,按照上面的步骤训练完的模型,虽然在测试上取得了很好的结果。训练的模型对随机下载的图片分类不太好,也就是通用性不太好。

Fruits360 图像集中所有的样本都是白色背景,与我们实际测试环境有差别。后续的训练中要考虑这方面带来的影响。需要增加训练样本或者别的其他策略来完成。

当我们遇到问题时,首先冷静,然后动手动脑,一方面自己尝试,另一方面从网络查找相关解决方法或请教大牛。

当然,更重要的是增大知识储备,平时多查阅资料,多问,多思考。积累充分才能在实际项目中得心应手。

另外,感谢:

  • 李烨老师开设的极简机器学习教程;
  • kaggle参赛者Esther S. Weon分享图片中水果的分类的代码;
  • 以及在学习上一路陪伴的小伙伴们。

最后,祝大家在机器学习的道路上越来越开心!

原文发布于微信公众号 - 悦思悦读(yuesiyuedu)

原文发表时间:2018-10-10

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏机器之心

教程 | 如何使用TensorFlow实现音频分类任务

7347
来自专栏数据派THU

怎样构建中文文本标注工具?(附工具、代码、论文等资源)

来源:Paperweekly 本文长度为2218字,建议阅读4分钟 本文为你介绍中文文本标注工具的构建方法,并提供多个开源文本标注工具。 项目地址: https...

1.2K7
来自专栏目标检测和深度学习

英伟达开源数据增强和数据解码库,解决计算机视觉性能瓶颈

【新智元导读】在CVPR 2018大会上,英伟达开源了数据增强库DALI和数据解码库nvJPEG。

1344
来自专栏应兆康的专栏

19. 总结:基本错误分析

1351
来自专栏应兆康的专栏

19. 总结:基本错误分析

• 不要一开始就尝试设计和构建完美的系统,而是尽可能快的建立和训练一个基础的系统(几天之内),然后使用错误分析。帮助你找到最优的方向,并迭代改进你的算法。

3269
来自专栏机器之心

资源 | 一个基于PyTorch的目标检测工具箱,商汤联合港中文开源mmdetection

项目地址:https://github.com/open-mmlab/mmdetection

5322
来自专栏专知

【AlphaGo Zero 核心技术-深度强化学习教程代码实战03】编写通用的格子世界环境类

【导读】Google DeepMind在Nature上发表最新论文,介绍了迄今最强最新的版本AlphaGo Zero,不使用人类先验知识,使用纯强化学习,将价值...

2994
来自专栏新智元

【10大深度学习框架实验对比】Caffe2最优,TensorFlow排第6

【新智元导读】微软数据科学家Ilia Karmanov做了一个项目,使用高级API测试8种常用深度学习框架的性能(因为Keras有TF,CNTK和Theano,...

4157
来自专栏量子位

十分钟,我搞定了一个人物检测模型

人物检测确实是个老生常谈的话题了,自动驾驶中的道路行人检测、无人零售中的行为检测、时尚界的虚拟穿搭、安防界的人员监控、手机应用中的人脸检测……人物检测不易察觉,...

1445
来自专栏石瞳禅的互联网实验室

【TensorFlow实战——笔记】第2章:TensorFlow和其他深度学习框架的对比

可以看到各大主流框架基本都支持Python,目前Python在科学计算和数据挖掘领域可以说是独领风骚。虽然有来自R、Julia等语言的竞争压力,但是Python...

1231

扫码关注云+社区

领取腾讯云代金券