使用生成式对抗网络进行图像去模糊

AiTechYun

编辑:yuxiangyu

本文主要讨论使用生成式对抗网络实现图像去模糊。

代码:https://github.com/RaphaelMeudec/deblur-gan

生成对抗网络

在生成对抗网络中,两个网络进行对抗训练。生成器通过创建逼真的假输入来误导鉴别器。鉴别器鉴别输入是真实的还是伪造的。

GAN训练过程

训练主要分为3个步骤:

– – 使用生成器根据噪声创建假输入。 – 根据真的输入和假的输入训练鉴别器 – 训练整个模型:模型被构建成用鉴别器限制生成器。

注意鉴别器的权重在第三步中要进行冻结。

之所以链接两个网络,是因为对生成器的输出没有合适的反馈。我们唯一的衡量标准是鉴别器是否接受生成的样本。

数据

在本教程中,我们使用GAN进行图像去模糊。因此,生成器的输入不是噪声而是模糊的图像。

数据集是GOPRO数据集。您可以下载一个轻量版(9GB)或完整版(35GB)。它包含来自多个街景的模糊图像。数据集在子文件夹中按场景分类。

轻量版:https://drive.google.com/file/d/1H0PIXvJH4c40pk7ou6nAwoxuR4Qh_Sa2/view?usp=sharing

完整版:https://drive.google.com/file/d/1SlURvdQsokgsoyTosAaELc4zRjQz9T2U/view?usp=sharing

我们首先将图像分配到两个文件夹A(模糊)和B(清晰)。

模型

训练过程保持不变。首先,让我们看看神经网络架构!

生成器

生成器旨在重现清晰的图像。网络基于ResNet模块。它跟踪应用于原始模糊图像的演变。

DeblurGAN生成网络的结构

核心是用于对原始图像进行重新采样的9个ResNet模块。让我们看看Keras的实现。

from keras.layersimport Input, Conv2D, Activation, BatchNormalization
from keras.layers.mergeimport Add
from keras.layers.coreimport Dropout

def res_block(input, filters, kernel_size=(3,3), strides=(1,1), use_dropout=False):
    """
    Instanciate a Keras Resnet Block using sequential API.
    :param input: Input tensor
    :param filters: Number of filters to use
    :param kernel_size: Shape of the kernel for the convolution
    :param strides: Shape of the strides for the convolution
    :param use_dropout: Boolean value to determine the use of dropout
    :return: Keras Model
    """
    x= ReflectionPadding2D((1,1))(input)
    x= Conv2D(filters=filters,
               kernel_size=kernel_size,
               strides=strides,)(x)
    x= BatchNormalization()(x)
    x= Activation('relu')(x)

    if use_dropout:
        x= Dropout(0.5)(x)

    x= ReflectionPadding2D((1,1))(x)
    x= Conv2D(filters=filters,
                kernel_size=kernel_size,
                strides=strides,)(x)
    x= BatchNormalization()(x)

    # Two convolution layers followed by a direct connection between input and output
    merged= Add()([input, x])
    return merged

该ResNet层本质上是一个卷积层,通过添加输入和输出形成最终输出。

from keras.layersimport Input, Activation, Add
from keras.layers.advanced_activationsimport LeakyReLU
from keras.layers.convolutionalimport Conv2D, Conv2DTranspose
from keras.layers.coreimport Lambda
from keras.layers.normalizationimport BatchNormalization
from keras.modelsimport Model

from layer_utilsimport ReflectionPadding2D, res_block

ngf= 64
input_nc= 3
output_nc= 3
input_shape_generator= (256,256, input_nc)
n_blocks_gen= 9


def generator_model():
    """Build generator architecture."""
    # Current version : ResNet block
    inputs= Input(shape=image_shape)

    x= ReflectionPadding2D((3,3))(inputs)
    x= Conv2D(filters=ngf, kernel_size=(7,7), padding='valid')(x)
    x= BatchNormalization()(x)
    x= Activation('relu')(x)

    # Increase filter number
    n_downsampling= 2
    for iin range(n_downsampling):
        mult= 2**i
        x= Conv2D(filters=ngf*mult*2, kernel_size=(3,3), strides=2, padding='same')(x)
        x= BatchNormalization()(x)
        x= Activation('relu')(x)

    # Apply 9 ResNet blocks
    mult= 2**n_downsampling
    for iin range(n_blocks_gen):
        x= res_block(x, ngf*mult, use_dropout=True)

    # Decrease filter number to 3 (RGB)
    for iin range(n_downsampling):
        mult= 2**(n_downsampling- i)
        x= Conv2DTranspose(filters=int(ngf* mult/ 2), kernel_size=(3,3), strides=2, padding='same')(x)
        x= BatchNormalization()(x)
        x= Activation('relu')(x)

    x= ReflectionPadding2D((3,3))(x)
    x= Conv2D(filters=output_nc, kernel_size=(7,7), padding='valid')(x)
    x= Activation('tanh')(x)

    # Add direct connection from input to output and recenter to [-1, 1]
    outputs= Add()([x, inputs])
    outputs= Lambda(lambda z: z/2)(outputs)

    model= Model(inputs=inputs, outputs=outputs, name='Generator')
    return model

Keras实现生成器的架构

按计划,9个ResNet模块应用于之前的输入采样版本。我们添加从输入到输出的连接,然后除以2以保持归一化输出。

这样生成器就完成了,让我们来看看鉴别器的架构。

鉴别器

鉴别器的目标是确定输入图像是否是伪造的。因此,鉴别器的架构是卷积的并输出单一值。

from keras.layersimport Input
from keras.layers.advanced_activationsimport LeakyReLU
from keras.layers.convolutionalimport Conv2D
from keras.layers.coreimport Dense, Flatten
from keras.layers.normalizationimport BatchNormalization
from keras.modelsimport Model

ndf= 64
output_nc= 3
input_shape_discriminator= (256,256, output_nc)


def discriminator_model():
    """Build discriminator architecture."""
    n_layers, use_sigmoid= 3,False
    inputs= Input(shape=input_shape_discriminator)

    x= Conv2D(filters=ndf, kernel_size=(4,4), strides=2, padding='same')(inputs)
    x= LeakyReLU(0.2)(x)

    nf_mult, nf_mult_prev= 1,1
    for nin range(n_layers):
        nf_mult_prev, nf_mult= nf_mult,min(2**n,8)
        x= Conv2D(filters=ndf*nf_mult, kernel_size=(4,4), strides=2, padding='same')(x)
        x= BatchNormalization()(x)
        x= LeakyReLU(0.2)(x)

    nf_mult_prev, nf_mult= nf_mult,min(2**n_layers,8)
    x= Conv2D(filters=ndf*nf_mult, kernel_size=(4,4), strides=1, padding='same')(x)
    x= BatchNormalization()(x)
    x= LeakyReLU(0.2)(x)

    x= Conv2D(filters=1, kernel_size=(4,4), strides=1, padding='same')(x)
    if use_sigmoid:
        x= Activation('sigmoid')(x)

    x= Flatten()(x)
    x= Dense(1024, activation='tanh')(x)
    x= Dense(1, activation='sigmoid')(x)

    model= Model(inputs=inputs, outputs=x, name='Discriminator')
    return model

使用Keras进行鉴别者架构的实现

最后一步是构建完整模型。这个GAN的一个特点是输入是真实的图像而不是噪音。因此,我们对生成机的输出有直接反馈。

from keras.layersimport Input
from keras.modelsimport Model

def generator_containing_discriminator_multiple_outputs(generator, discriminator):
    inputs= Input(shape=image_shape)
    generated_images= generator(inputs)
    outputs= discriminator(generated_images)
    model= Model(inputs=inputs, outputs=[generated_images, outputs])
    return model

让我们看看如何使用这两个损失来充分利用来充分利用这个特点。

训练

损失

我们分别在两处提取损失,在生成器的末尾和整个模型的末尾。

第一个是直接根据生成输出计算的感知损失。这个损失确保了GAN模型进行去模糊的任务。它比较VGG的第一个卷积的输出。

import keras.backend as K
from keras.applications.vgg16import VGG16
from keras.modelsimport Model

image_shape= (256,256,3)

def perceptual_loss(y_true, y_pred):
    vgg= VGG16(include_top=False, weights='imagenet', input_shape=image_shape)
    loss_model= Model(inputs=vgg.input, outputs=vgg.get_layer('block3_conv3').output)
    loss_model.trainable= False
    return K.mean(K.square(loss_model(y_true)- loss_model(y_pred)))

第二个损失是对整个模型的输出执行的Wasserstein损失。它取两个图像之间的差异的均值。这可以改善生成对抗网络的收敛性。

import keras.backend as K

def wasserstein_loss(y_true, y_pred):
    return K.mean(y_true*y_pred)

训练

第一步,加载数据并初始化所有模型。我们使用我们的自定义函数来加载数据集,并为我们的模型添加Adam优化。我们设置Keras可训练选项,防止鉴别器进行训练。

# Load dataset
data= load_images('./images/train', n_images)
y_train, x_train= data['B'], data['A']

# Initialize models
g= generator_model()
d= discriminator_model()
d_on_g= generator_containing_discriminator_multiple_outputs(g, d)

# Initialize optimizers
g_opt= Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
d_opt= Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
d_on_g_opt= Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

# Compile models
d.trainable= True
d.compile(optimizer=d_opt, loss=wasserstein_loss)
d.trainable= False
loss= [perceptual_loss, wasserstein_loss]
loss_weights= [100,1]
d_on_g.compile(optimizer=d_on_g_opt, loss=loss, loss_weights=loss_weights)
d.trainable= True

然后,我们开始开始训练,并将数据集分成几个批次。

for epochin range(epoch_num):
  print('epoch: {}/{}'.format(epoch, epoch_num))
  print('batches: {}'.format(x_train.shape[0]/ batch_size))

  # Randomize images into batches
  permutated_indexes= np.random.permutation(x_train.shape[0])

  for indexin range(int(x_train.shape[0]/ batch_size)):
      batch_indexes= permutated_indexes[index*batch_size:(index+1)*batch_size]
      image_blur_batch= x_train[batch_indexes]
      image_full_batch= y_train[batch_indexes]

最后,我们根据两种损失分别训练鉴别器和生成器。我们用生成器产生假输入,然后训练辨别器来区分输入的真假,最后我们训练整个模型。

for epochin range(epoch_num):
  for indexin range(batches):
    # [Batch Preparation]

    # Generate fake inputs
    generated_images= g.predict(x=image_blur_batch, batch_size=batch_size)

    # Train multiple times discriminator on real and fake inputs
    for _in range(critic_updates):
        d_loss_real= d.train_on_batch(image_full_batch, output_true_batch)
        d_loss_fake= d.train_on_batch(generated_images, output_false_batch)
        d_loss= 0.5 * np.add(d_loss_fake, d_loss_real)

    d.trainable= False
    # Train generator only on discriminator's decision and generated images
    d_on_g_loss= d_on_g.train_on_batch(image_blur_batch, [image_full_batch, output_true_batch])

    d.trainable= True

Github:https://www.github.com/raphaelmeudec/deblur-gan

材料

我使用了AWS实例(p2.xlarge)和Deep Learning AMI(版本3.0)。使用GOPRO数据集,训练时间约为5小时(50个周期)。

图像去模糊结果

从左到右:原始图像,模糊图像,GAN输出

上图是我们Keras去模糊GAN的结果。即使在模糊很重的情况下,网络也能够减少模糊并生成令人信服的图像。我们能够看到车灯和树枝更清晰了。

左:GOPRO测试图像,右:GAN输出

我们能看到图像顶部的缺陷(条纹状),这可能是因为使用VGG作为损失引起的。

左:GOPRO测试图像,右:GAN输出

左:GOPRO测试图像,右:GAN输出

原文发布于微信公众号 - ATYUN订阅号(atyun_com)

原文发表时间:2018-03-25

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏专知

【论文推荐】最新5篇目标检测相关论文——显著目标检测、弱监督One-Shot检测、多框检测器、携带物体检测、假彩色图像检测

【导读】专知内容组整理了最近目标检测相关文章,为大家进行介绍,欢迎查看! 1. MSDNN: Multi-Scale Deep Neural Network f...

49170
来自专栏机器学习和数学

[机智的机器在学习] TensorFlow实现Kmeans聚类

对于机器学习算法来说,主要分为有监督学习和无监督学习,前面有篇文章介绍过机器学习算法的分类,不知道的童鞋可以去看看。然后今天要讲的Kmeans算法属于无监督算法...

1.3K130
来自专栏智能算法

深度学习三人行(第3期)---- TensorFlow从DNN入手

21620

如何在Python中从零开始实现随机森林

决策树可能会受到高度变化的影响,使得结果对所使用的特定训练数据而言变得脆弱。

25180
来自专栏云时之间

深度学习与神经网络:单层感知机

今天这个文章让我们一起来学习下感知机: ? 一个传统的单层感知机如上图所示,其实理解起来很简单,我们可以直接理解为输入节点接受信号之后直接传输到输出节点,然后得...

49950
来自专栏机器学习算法与理论

几种距离的集中比较

提到检索的方法,比如KNN算法,这些都需要用到“距离”这个尺度去度量两者的近似程度。但是,距离也有很多种,除了我们熟悉的欧氏距离之外,其实还有很多。。。 余弦距...

36570
来自专栏深度学习那些事儿

浅谈深度学习:如何计算模型以及中间变量的显存占用大小

博客原文:https://oldpan.me/archives/how-to-calculate-gpu-memory

60480
来自专栏云时之间

深度学习与神经网络:单层感知机

一个传统的单层感知机如上图所示,其实理解起来很简单,我们可以直接理解为输入节点接受信号之后直接传输到输出节点,然后得到结果y.

38290
来自专栏奇点大数据

Pytorch神器(6)

作者介绍:高扬,奇点大数据创始人。技术畅销书《白话大数据与机器学习》、《白话深度学习与Tensorflow》、《数据科学家养成手册》著书人。重庆工商大学研究生导...

20630
来自专栏小小挖掘机

推荐系统遇上深度学习(二)--FFM模型理论和实践

推荐系统遇上深度学习系列: 推荐系统遇上深度学习(一)--FM模型理论和实践 1、FFM理论 在CTR预估中,经常会遇到one-hot类型的变量,one-ho...

1K40

扫码关注云+社区

领取腾讯云代金券