前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >深度学习样本生成data augmentation

深度学习样本生成data augmentation

作者头像
bear_fish
发布2018-09-14 10:05:06
1.1K0
发布2018-09-14 10:05:06
举报

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://cloud.tencent.com/developer/article/1338384

在做深度学习图片分类的时候,很多是有些样本不足这个时候我们就会自己生成样本,如opencv对图片旋转,扭曲等等操作。google了一下deep learning data augmentation 发现了github几种开源的的方法主要是使用opencv结合python的PIL库。最终发现Augmentor好用

本文内容如下:



先上几张生成的图片看下效果:

原始图片

旋转生成:

Augmentor 生成

下面贴出代码,应该比较好懂,Augmentor使用的话看链接主要是使用pipeline对图片以一定的概率做变换。

代码语言:javascript
复制
# _*_ coding:utf-8 _*_

"""
Deep learning image augmentation
cited from https://scottontechnology.com/flip-image-opencv-python/
http://augmentor.readthedocs.io/en/master/userguide/mainfeatures.html
"""

import cv2
import glob
import random
import os
from multiprocessing import Pool as ProcessPool
from multiprocessing.dummy import Pool as ThreadPool
import Augmentor
import path_var


def img_flip():
    path = "F:/ad_samples/download_sample/14/8DB54D749B1D4A2D5FD3441C681D9A2C522453AC_s.jpg"
    img = cv2.imread(path)

    horizontal_img = img.copy()
    vertical_img = img.copy()
    both_img = img.copy()

    horizontal_img = cv2.flip(img, 0)
    vertical_img = cv2.flip(img, 1)
    both_img = cv2.flip(img, -1)

    cv2.imshow("original img", img)
    cv2.imshow("horizontal img", horizontal_img)
    cv2.imshow("vertical img", vertical_img)
    cv2.imshow("both flip", both_img)

    cv2.waitKey(0)
    cv2.destroyAllWindows()

def flip_img_save2dir(file):
    img = cv2.imread(file)

    dst_dir = path_var.g_dst_dir

    h_img = img.copy()
    v_img = img.copy()
    b_img = img.copy()

    h_img = cv2.flip(img, 0)
    v_img = cv2.flip(img, 1)
    b_img = cv2.flip(img, -1)

    # file like F:/ad_samples/train_samples/ad_text_artifact/base_type/type_10.jpg
    # get file name "type_10"
    # type_10.jpg
    base_name = os.path.basename(file)
    # type_10
    base_name = os.path.splitext(base_name)[0]

    file_name = dst_dir + base_name + "_h" + ".jpg"
    cv2.imwrite(file_name, h_img)

    file_name = dst_dir + base_name + "_v" + ".jpg"
    cv2.imwrite(file_name, v_img)

    file_name = dst_dir + base_name + '_b' + ".jpg"
    cv2.imwrite(file_name, b_img)


def do_all_flip(base_dir="F:/ad_samples/train_samples/ad_web_2/"):
    """
    flip all the images in dir, and then save them
     to another dir
    :return:
    """
    # get all files
    files = glob.glob(base_dir + "/*.png")
    # like ['E:/img\\1.jpg', 'E:/img\\10.jpg']

    # start 3 process
    # pool = ProcessPool(3)
    pool = ThreadPool(3)
    rets = pool.map(flip_img_save2dir, files)
    pool.close()
    pool.join()
    print 'all images accomplish flip and save to dir'


def flip_all_in_dir():
    base_dir = 'F:/ad_samples/train_samples/others/'
    sub_dir_lst = glob.glob(base_dir + "*")
    # ['F:/dir1', 'F:/dir2']

    # print sub_dir_lst
    new_sub_dir = [os.path.join(base_dir, item + '_flip/') for item in os.listdir(base_dir)]
    # ['F:/dir1_flip', 'F:/dir2_flip']

    for dir_item, new_item in zip(sub_dir_lst[10:], new_sub_dir[10:]):
        global g_dst_dir
        if not os.path.exists(new_item):
            os.makedirs(new_item)
        # g_dst_dir = new_item
        # Path.g_dst_dir = new_item
        path_var.g_dst_dir = new_item
        print 'flip %s, flip dir %s' % (dir_item, new_item)
        do_all_flip(base_dir=dir_item)



def augmentation():
    path = 'F:/augment'
    # path = 'F:/ad_samples/train_samples/ad_text'

    # output_path = 'F:/ad_samples/train_samples/ad_text_artifact/augmentation'
    output_path = 'output'

    p = Augmentor.Pipeline(path, output_directory=output_path)

    p.zoom(probability=0.1, min_factor=1.1, max_factor=1.3)
    p.flip_left_right(probability=0.1)
    p.rotate(probability=0.2, max_left_rotation=15, max_right_rotation=16)
    p.shear(probability=0.2, max_shear_left=10, max_shear_right=10)
    p.skew(probability=0.1, magnitude=0.6)
    p.skew_tilt(probability=0.2, magnitude=0.6)
    p.random_distortion(probability=0.3, grid_height=4, grid_width=4, magnitude=4)

    # p.random_distortion(probability=0.2, grid_height=4, grid_width=4, magnitude=4)
    # p.rotate90(probability=1)
    # SIZE = 4164 * 4
    SIZE = 5 * 4
    p.sample(SIZE)


if __name__ == '__main__':
    # img_flip()
    # flip_all_in_dir()
    # do_all_flip()
    augmentation()
    # test single image flip and save
    # file = 'F:/ad_samples/train_samples/ad_text_artifact/base_type/type_10.jpg'
    # flip_img_save2dir(file=file)
    pass

工作中使用的语言比较多写过C++,Java, 部分html+js, python的.由于用到语言的间歇性,比如还几个月没有使用python了许多技巧就忘记了,于是我把一些常用的python代码分类项目在本人的github中,当实际中用到某一方法的时候就把常用的方法放到一个文件中方便查询。

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2017年08月29日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档