前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >使用VGG模型自定义图像分类任务

使用VGG模型自定义图像分类任务

作者头像
chaibubble
发布2019-05-26 20:04:46
1.6K0
发布2019-05-26 20:04:46
举报

前言

网上关于VGG模型的文章有很多,有介绍算法本身的,也有代码实现,但是很多代码只给出了模型的结构实现,并不包含数据准备的部分,这让人很难愉快的将代码迁移自己的任务中。为此,这篇博客接下来围绕着如何使用VGG实现自己的图像分类任务,从数据准备到实验验证。代码基于Python与TensorFlow实现,模型结构采用VGG-16,并且将很少的出现算法和理论相关的东西。

数据准备

下载数据和转换代码

大多数人自己的训练数据,一般都是传统的图片形式,如.jpg,.png等等,而图像分类任务的话,这些图片的天然组织形式就是一个类别放在一个文件夹里,那么有啥大众化的数据集是这样的组织形式呢?TensorFlow的FlowersData,它下载下来是这个样子:

这里写图片描述
这里写图片描述

一共有五类,每一类中都有几百张图,我们把这些数据组织成TFrecord形式,对应的博客在这里,源码的github在这里,FlowersData数据集在这里。 有上面这三个东西之后,就可以生成TFrecord文件了。

组织图片数据

首先将FlowersData文件夹下的数据分成两个部分,训练数据和测试数据,我把原文件五个类别中都拿出大概100张图左右,数据的构成和路径如下:

这里写图片描述
这里写图片描述

生成训练TFrecord

代码语言:javascript
复制
#图片路径
cwd = 'F:\\flowersdata\\trainimages\\'
代码语言:javascript
复制
#文件路径
filepath = 'F:\\flowersdata\\tfrecord\\train\\'
代码语言:javascript
复制
classes=['daisy',
         'dandelion',
         'roses',
         'sunflowers',
         'tulips']
代码语言:javascript
复制
#tfrecords格式文件名
ftrecordfilename = ("traindata.tfrecords-%.3d" % recordfilenum)
代码语言:javascript
复制
#tfrecords格式文件名
ftrecordfilename = ("traindata.tfrecords-%.3d" % recordfilenum)

生成效果:

这里写图片描述
这里写图片描述

生成预测TFrecord

代码语言:javascript
复制
#图片路径
cwd = 'F:\\flowersdata\\testimages\\'
代码语言:javascript
复制
#文件路径
filepath = 'F:\\flowersdata\\tfrecord\\test\\'
代码语言:javascript
复制
classes=['daisy',
         'dandelion',
         'roses',
         'sunflowers',
         'tulips']
代码语言:javascript
复制
#tfrecords格式文件名
ftrecordfilename = ("testdata.tfrecords-%.3d" % recordfilenum)
代码语言:javascript
复制
#tfrecords格式文件名
ftrecordfilename = ("testdata.tfrecords-%.3d" % recordfilenum)

生成效果:

这里写图片描述
这里写图片描述

训练模型

初始权重与源码下载

VGG-16的初始权重我上传到了百度云,在这里下载; VGG-16源码我上传到了github,在这里下载;

在源码中: train_and_val.py文件是最终要执行的文件,它定了训练和预测的过程; input_data.py是将上一步中生成的TFRecord文件组织成batch的过程; VGG.py定义了VGG-16的网络结构; tool.py是最底层,定义了一些卷积池化等操作。

训练模型

train_and_val.py文件修改:

代码语言:javascript
复制
if __name__=="__main__":
    train()
    #evaluate()

根据自己的路径修改:

代码语言:javascript
复制
#初始权重路径
pre_trained_weights = 'vgg16_pretrain/vgg16.npy'
#训练数据路径
train_data_dir = 'F:\\flowersdata\\tfrecord\\train\\traindata.tfrecords*'
    test_data_dir = 
#预测数据路径
'F:\\flowersdata\\tfrecord\\test\\testdata.tfrecords*'
#训练生成文件路径
train_log_dir = 'logs/train/'
#预测生成文件路径
val_log_dir = 'logs/val/'

根据自己的显存容量修改:

代码语言:javascript
复制
IMG_W = 224
IMG_H = 224
BATCH_SIZE = 8

训练过程每50个step打印loss; 每200个step计算一个batch中的准确率; 每1000个step保存一次权重。

预测

train_and_val.py文件修改:

代码语言:javascript
复制
if __name__=="__main__":
    #train()
    evaluate()
代码语言:javascript
复制
#训练过程中生成的权重
log_dir = 'logs/train/'
#预测数据集路径
test_data_dir = 'F:\\flowersdata\\tfrecord\\test\\testdata.tfrecords*'
#用于生成tf文件的图片数量
n_test = 502

打印测试样本总数; 打印正确预测的样本总数; 打印top_1。

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2018年05月27日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 前言
  • 数据准备
    • 下载数据和转换代码
      • 组织图片数据
        • 生成训练TFrecord
          • 生成预测TFrecord
          • 训练模型
            • 初始权重与源码下载
              • 训练模型
              • 预测
              相关产品与服务
              批量计算
              批量计算(BatchCompute,Batch)是为有大数据计算业务的企业、科研单位等提供高性价比且易用的计算服务。批量计算 Batch 可以根据用户提供的批处理规模,智能地管理作业和调动其所需的最佳资源。有了 Batch 的帮助,您可以将精力集中在如何分析和处理数据结果上。
              领券
              问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档