数据预处理-对图片扩展的处理方法

Keras非常便捷的提供了图片预处理的类--ImageDataGenerator 可以用这个图片生成器生成一个batch周期内的数据,它支持实时的数据扩展,训练的时候会无限生成数据,一直到达设定的epoch次数才停止。

  • ImageDataGeneoator()常用参数:
    • rotation_range:整数,数据扩展时图片随机转动的角度
    • width_shift_range:浮点数,图片宽度的某个比例,数据扩展时图片水 平偏移的幅度
    • height_shift_range:浮点数,图片高度的某个比例,数据扩展时图片 竖直偏移的幅度
    • shear_range:浮点数,剪切强度(按逆时针方向的剪切变换角度)
    • zoom_range:浮点数或形如[lower,upper]的列表,随机缩放的幅度。若为浮点数,则相当于[lower,upper] = [1 - zoom_range, 1+zoom_range]
    • horizontal_flip:布尔值,进行随机水平翻转
    • vertical_flip:布尔值,进行随机竖直翻转
    • rescale: 重放缩因子,默认为None. 如果为None或0则不进行放缩,否则会将该数值乘到数据上(在应用其他变换之前)
    • fill_mode:‘constant’,‘nearest’,‘reflect’或‘wrap’之一,当进行变换时超出边界的点将根据本参数给定的方法进行处理
  • ImageDataGeneoator()的方法:
    • fit():计算依赖于数据的变换所需要的统计信息(均值方差等),只有使用featurewise_center,featurewise_std_normalization或zca_whitening参数时需要此函数。
    • flow(): 接收numpy数组和标签为参数,生成经过数据扩展或标准化后的batch数据,并在一个无限循环中不断的返回数据
    • flow_from_directory() :以文件夹路径为参数,生成经过数据提升/归一化后的数据,在一个无限循环中无限产生数据
  • 图片数据扩展举例: 在数据集不够多的情况下,可以使用ImageDataGenerator()来扩大数据集防止搭建的网络出现过拟合现象。
from keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
from glob import glob
import matplotlib.pyplot as plt      
import cv2
import numpy as np

# 图片生成器
datagen = ImageDataGenerator(
            rotation_range=40,
            width_shift_range=0.2,
            height_shift_range=0.2,
            rescale=1./255,
            shear_range=0.2,
            zoom_range=0.2,
            horizontal_flip=True,
            fill_mode='nearest')

# 打印转换前的图片
img = load_img('dog-project-master/kaggle/train/cat/cat.4.jpg')
plt.imshow(img)
plt.title('Before:')
plt.show()

# 将图片转换为数组,并重新设定形状
x = img_to_array(img) 
x = x.reshape((1,) + x.shape) 
# x的形状重组为(1,width,height,channels),第一个参数为batch_size

# 这里人工设置停止生成, 并保存图片用于可视化
i = 0
for batch in datagen.flow(x,batch_size=1,save_to_dir='E://udacity-dl/transform_cat',save_prefix='cat',save_format='png'):
    i +=1
    if i > 3 :
        break  # 如果不设置,生成器会不断的生成图片

# 打印转换后的图片
cat_images =  np.array(glob("transform_cat/*"))
cat_image = cv2.imread(cat_images[1])
plt.imshow(cat_image)
plt.title("After1")
plt.show()
cat_image2 = cv2.imread(cat_images[2])
plt.imshow(cat_image2)
plt.title("After2")
plt.show()
cat_image3 = cv2.imread(cat_images[3])
plt.imshow(cat_image3)
plt.title("After3")
plt.show()
  • 实时生成数据集用于训练 经常会和函数式模型Model()的API---.fit_generator()配合使用,在训练中实时地生成数据。

1)和.flow()配合使用

(x_train, y_train), (x_test, y_test) = cifar10.load_data()
y_train = np_utils.to_categorical(y_train, num_classes)
y_test = np_utils.to_categorical(y_test, num_classes)

datagen = ImageDataGenerator(
    featurewise_center=True,
    featurewise_std_normalization=True,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True)

# compute quantities required for featurewise normalization
# (std, mean, and principal components if ZCA whitening is applied)
datagen.fit(x_train)

# fits the model on batches with real-time data augmentation:
model.fit_generator(datagen.flow(x_train, y_train, batch_size=32),
                    steps_per_epoch=len(x_train), epochs=epochs)

# here's a more "manual" example
for e in range(epochs):
    print 'Epoch', e
    batches = 0
    for x_batch, y_batch in datagen.flow(x_train, y_train, batch_size=32):
        loss = model.train(x_batch, y_batch)
        batches += 1
        if batches >= len(x_train) / 32:
            # we need to break the loop by hand because
            # the generator loops indefinitely
            break

2)和.flow_from_directory() 配合使用

train_datagen = ImageDataGenerator(
        rescale=1./255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)

test_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
        'data/train',
        target_size=(150, 150),
        batch_size=32,
        class_mode='binary')

validation_generator = test_datagen.flow_from_directory(
        'data/validation',
        target_size=(150, 150),
        batch_size=32,
        class_mode='binary')

model.fit_generator(
        train_generator,
        steps_per_epoch=2000,
        epochs=50,
        validation_data=validation_generator,
        validation_steps=800)

参考资料:

  1. Keras中文文档--图片预处理
  2. glob文件路径查找模块
  3. Building powerful image classification models using very little data By Francois Chollet

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

扫码关注云+社区

领取腾讯云代金券