前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >图像训练样本量少时的数据增强技术

图像训练样本量少时的数据增强技术

作者头像
Cloudox
发布2021-11-23 16:53:06
1.3K0
发布2021-11-23 16:53:06
举报
文章被收录于专栏:月亮与二进制月亮与二进制

在深度学习训练过程中,训练数据是很重要的,在样本量方便,一是要有尽量多的训练样本,二是要保证样本的分布够均匀,也就是各个类别下的样本量都要足够,不能有的很多,有的特别少。但是实际采集数据的过程中,可能经常会遇到样本量不够的情况,这就很容易导致训练出的模型过拟合,泛化能力不足,这时候该怎么办呢?

一种方法是利用预训练好的模型,也就是使用另一个在大量样本下获得足够训练的模型,只要这个模型的训练数据集足够大,而且够通用,那么可以理解为其学到的特征空间层次结构能够有效地作为视觉世界的通用模型基础。就好像你如果看过很多东西了,那么也就拥有了对于世界万物的一个基本判断能力,在此基础上可以进一步做集中突破训练来学会判断特殊的物体,这时候需要的样本量就不需要那么多了。比如说,如果已经在ImageNet下训练了一个网络,可以识别动物及日常用品等,这时候你需要得到一个能够区分猫狗的模型,那么在其基础上进行训练是很有效的,比你单纯在小样本的猫狗图像上重头做训练效果要好。当然,在实际操作中,我们需要保留网络除了分类器部分的前置层(卷积基)及其权重不变,只训练我们新的分类器,这也很好理解,毕竟要利用它的基础嘛。

但本文要讲的不是这个方法,而是另一种思路,即强行增加训练样本数量,生生在已有的样本下再造出一批来,这叫做数据增强

所谓数据增强,就是从已有的图像样本中生造出更多的样本数据,这些图像怎么来呢?方法是使用一些方法,来随机变换生成一些可信图像,这些通过随机变换生成的图像,要保证从逻辑上不会给模型辨认带来困扰,也就是从分类的角度应该依然属于其原本图像同一类,但是又要与原本的图像有一些区别,这样模型在训练时就不会两次看到完全相同的图像,这样就能够观察到更多的内容,也就提升了泛化能力。

产生新图像的随机变换方法大致包括:

  • 随机旋转一些角度
  • 水平横移一定距离
  • 竖直横移一定距离
  • 随机缩放一定范围
  • 进行水平翻转
  • 进行竖直翻转
  • 等等

这些变换方式都是可以考虑的,同时这些变换的组合也是可以的,但是要注意不能产生逻辑上的问题。比如你要训练一个分类猫狗的模型,那可以对某些猫狗图像进行旋转角度、横移、水平翻转,但是竖直翻转可能就不太合适了。又比如你要训练一个文字识别模型,那么可以进行随机缩放、横移,但是无论水平翻转还是竖直翻转可能都不太合适。

通过这些变换及其组合,我们就能得到很多同等类别下又有所区别的图像了。

那怎么实现呢?

当然,最简单的可以自己写代码来加入这些随机扰动,但Keras有更方便的函数来直接进行操作。

ImageDataGenerator是keras.preprocessing.image包下的一个类,可以设置图像的这些随机扰动来生成新的图像数据,简单的代码如下所示:

代码语言:javascript
复制
# -- coding: utf-8 --
import numpy as np
from keras.preprocessing import Image
from keras.preprocessing.image import ImageDataGenerator
import cv2 as cv
import os

img = cv.imread("./photo.png")
img = Image.img_to_array(img)
img = img.reshape((1,) + img.shape)
datagen = ImageDataGenerator(width_shift_range=0.2, height_shift_range=0.2, fill_mode='wrap')
i = 0
for batch in datagen.flow(img, batch_size=1):
    img_name = "./image_%d.jpg" % (i)
    cv.imwrite(img_name, batch[0])
    i += 1
    if i % 4 == 0:
        break

上面代码所实现的就是将一张图像进行随机变换,我设置的变换形式只有在水平和竖直方向进行横移,且横移的范围最多占整个宽、高的20%,另外对于横移空出来的区域,填充方式为“wrap”,这是什么意思待会再解释。

设置好变换方式后,就可以通过datagen.flow来生成数据了,传入的参数包括图像和处理数量,我们这就处理一张图。在循环中这个类会不断地随机组合变换来生成新图像,我们把生成的新图像保存下来,并且设置只生成四张就停,这里比如设置停止条件,否则它会一直生成下去的。

现在我们来具体说一说ImageDataGenerator包含哪些变换方式,从Keras中文手册中我们能看到它包含这些参数:

  • featurewise_center:布尔值,使输入数据集去中心化(均值为0), 按feature执行
  • samplewise_center:布尔值,使输入数据的每个样本均值为0
  • featurewise_std_normalization:布尔值,将输入除以数据集的标准差以完成标准化, 按feature执行
  • samplewise_std_normalization:布尔值,将输入的每个样本除以其自身的标准差
  • zca_whitening:布尔值,对输入数据施加ZCA白化
  • zca_epsilon: ZCA使用的eposilon,默认1e-6
  • rotation_range:整数,数据提升时图片随机转动的角度
  • width_shift_range:浮点数,图片宽度的某个比例,数据提升时图片水平偏移的幅度
  • height_shift_range:浮点数,图片高度的某个比例,数据提升时图片竖直偏移的幅度
  • shear_range:浮点数,剪切强度(逆时针方向的剪切变换角度)
  • zoom_range:浮点数或形如[lower,upper]的列表,随机缩放的幅度,若为浮点数,则相当于[lower,upper] = [1 - zoom_range, 1+zoom_range]
  • channel_shift_range:浮点数,随机通道偏移的幅度
  • fill_mode:;‘constant’,‘nearest’,‘reflect’或‘wrap’之一,当进行变换时超出边界的点将根据本参数给定的方法进行处理:
    • 'constant': kkkkkkkk|abcd|kkkkkkkk (cval=k)
    • 'nearest': aaaaaaaa|abcd|dddddddd
    • 'reflect': abcddcba|abcd|dcbaabcd
    • 'wrap': abcdabcd|abcd|abcdabcd
  • cval:浮点数或整数,当fill_mode=constant时,指定要向超出边界的点填充的值
  • horizontal_flip:布尔值,进行随机水平翻转
  • vertical_flip:布尔值,进行随机竖直翻转
  • rescale: 重放缩因子,默认为None. 如果为None或0则不进行放缩,否则会将该数值乘到数据上(在应用其他变换之前)
  • preprocessing_function: 将被应用于每个输入的函数。该函数将在图片缩放和数据提升之后运行。该函数接受一个参数,为一张图片(秩为3的numpy array),并且输出一个具有相同shape的numpy array
  • data_format:字符串,“channel_first”或“channel_last”之一,代表图像的通道维的位置。该参数是Keras 1.x中的image_dim_ordering,“channel_last”对应原本的“tf”,“channel_first”对应原本的“th”。以128x128的RGB图像为例,“channel_first”应将数据组织为(3,128,128),而“channel_last”应将数据组织为(128,128,3)。该参数的默认值是~/.keras/keras.json中设置的值,若从未设置过,则为“channel_last”

比如我对这张图像使用上面的代码处理:

原图
原图

那么会得到四张经过处理的图:

1
1
2
2
3
3
4
4

可以看到,对同一张图,就得到了四张新的变换后的图,仔细看会发现,这些变换是会组合的。可以看到对于这种图,竖直移动更是几乎看不出操作痕迹。

这样样本量就翻了四倍啦,经过实验,在猫狗分类模型上能将精度从72%提升到82%,提升的效果还是非常明显的,所以有这方面困扰的可以试一下。


本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2018/10/26 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
文字识别
文字识别(Optical Character Recognition,OCR)基于腾讯优图实验室的深度学习技术,将图片上的文字内容,智能识别成为可编辑的文本。OCR 支持身份证、名片等卡证类和票据类的印刷体识别,也支持运单等手写体识别,支持提供定制化服务,可以有效地代替人工录入信息。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档