用 Keras 搭建 GAN:图像去模糊中的应用(附代码)

本文为雷锋字幕组编译的技术博客,原标题GAN with Keras: Application to Image Deblurring,作者为Raphaël Meudec。 翻译 | 廖颖 陈俊雅 整理 | 凡江

2014年 Ian Goodfellow 提出了生成对抗网络(GAN)。这篇文章主要介绍在Keras中搭建GAN实现图像去模糊。所有的Keras代码可点击这里。

可点击查看原始出版文章和Pytorch实现。

快速回忆生成对抗网络

GAN中两个网络的训练相互竞争。生成器( generator) 合成具有说服力的假输入来误导判别器(discriminator ),而判别器则是来识别这个输入是真的还是假的

生成对抗网络训练过程— 来源

训练过程主要有三步

  • 根据噪声,生成器合成假的输入
  • 用真的输入和假的输入共同训练判别器
  • 训练整个模型:整个模型中判别器与生成器连接

注意:在第三步中,判别器的权重是固定的

将这两个网络连接起来是由于生成器的输出没有可用的反馈。我们唯一的准则就是看判别器是否接受生成器的合成的例子。

这些只是对生成对抗网络的一个简单回顾,如果还是不够明白的话,可以参考完整介绍。

数据

Ian Goodfellow首次使用GAN模型是生成MNIST数据。 而本篇文章是使用生成对抗网络进行图像去模糊。因此生成器的输入不是噪声,而是模糊图像。

数据集来自GOPRO数据,你可以下载精简版数据集(https://drive.google.com/file/d/1H0PIXvJH4c40pk7ou6nAwoxuR4Qh_Sa2/view?usp=sharing),也可以下载完整版数据集(https://drive.google.com/file/d/1SlURvdQsokgsoyTosAaELc4zRjQz9T2U/view?usp=sharing)。其中包含了来自不同街道视角的人造模糊图像,根据不同的场景将数据集分在各个子文件夹中。

我们先把图像分到 A(模糊)和 B(清晰)两个文件夹。这个 A&B 结构对应于原始文章pix2pix 。我创建了一个 自定义脚本来执行这个任务。 看看 README 后尝试一下吧。

模型

训练过程还是一样,首先来看一下神经网络结构。

生成器

生成器要生成清晰图像,网络是基于ResNet blocks的,它可以记录对原始模糊图像操作的过程。原文还使用了基于UNet的版本,但我目前还没有实现。这两种结构都可以很好地进行图像去模糊。

DeblurGAN 生成器网络 结构 —  来源

核心是采用9个ResNet blocks对原始图像进行上采样。来看一下Keras上的实现!

ResNet 层就是一个基本的卷积层,其中,输入和输出相加,形成最终输出。

生成器结构的 Keras 实现

按照计划,用9个ResNet blocks对输入进行上采样。我们在输入到输出增加一个连接,然后除以2 来对输出进行归一化。

这就是生成器了! 我们再来看看判别器的结构吧。

判别器

判别器的目标就是要确定一张输入图片是否为合成的。因此判别器的结构采用卷积结构,而且是一个单值输出

判别器结构的 Keras 实现

最后一步就是建立完整的模型。这个GAN的一个特点就是输入的是真实图片而不是噪声 。因此我们就有了一个对生成器输出的直接反馈

接下来看看采用两个损失如何充分利用这个特殊性。

训练

损失

我们提取生成器最后和整个模型最后的损失。

第一个是感知损失,根据生成器输出直接可以计算得到。第一个损失保证 GAN 模型针对的是去模糊任务。它比较了VGG第一次卷积的输出

第二个损失是对整个模型输出计算的 Wasserstein loss,计算了两张图像的平均差值。众所周知,这种损失可以提高生成对抗网络的收敛性。

训练流程

第一步是加载数据并初始化模型。我们使用自定义函数加载数据集,然后对模型使用 Adam 优化器。我们设置 Keras 可训练选项来防止判别器进行训练。

然后我们进行epochs(一个完整的数据集通过了神经网络一次并且返回了一次的过程,称为一个epoch),并将整个数据集分批次(batches)。

最后根据两者的损失,可以相继训练判别器和生成器。用生成器生成假的输入,训练判别器区别真假输入,并对整个模型进行训练。

你可以参考Github来查看完整的循环。

实验

我使用的是在AWS 实例(p2.xlarge)上配置深度学习 AMI (version 3.0)进行的 。对GOPRO 精简版数据集的训练时间大约有 5 个小时(50个epochs)。

图像去模糊结果

从左到右:原始图像,模糊图像,GAN 输出

上面的输出结果都是我们用 Keras 进行 Deblur GAN 的结果。即使是对高度模糊,网络也可以减小模糊,产生一张具有更多信息的图片,使得车灯更加汇聚,树枝更加清晰。

左图: GOPRO 测试图像,右图:GAN 输出结果

因为引入了 VGG 来计算损失,所以会产生图像顶部出现感应特征的局限。

左图: GOPRO 测试图像,右图:GAN 输出结果

希望你们可以喜欢这篇关于生成对抗网络用于图像去模糊的文章。 你可以评论,关注我或者联系我。

如果你对机器视觉感兴趣,我们还写过一篇用Keras实现基于内容的图像复原 。下面是生成对抗网络资源的列表。

左图: GOPRO 测试图像,右图:GAN 输出结果

生成对抗网络资源

NIPS 2016: Generative Adversarial Networks by Ian Goodfellow

ICCV 2017: Tutorials on GAN

GAN Implementations with Keras by Eric Linder-Noren

A List of Generative Adversarial Networks Resources by deeplearning4j

Really-awesome-gan by Holger Caesar

博客原址 https://blog.sicara.com/keras-generative-adversarial-networks-image-deblurring-45e3ab6977b5


原文发布于微信公众号 - AI研习社(okweiwu)

原文发表时间:2018-04-06

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏机器学习、深度学习

人脸检测--Faceness-Net: Face Detection through Deep Facial Part Responses

Faceness-Net: Face Detection through Deep Facial Part Responses PAMI2017 From...

3626
来自专栏云时之间

深度学习与TensorFlow:FCN论文翻译(二)

Each layer of data in a convnet is a three-dimensional array of size h × w × d, ...

2252
来自专栏算法channel

机器学习逻辑回归:算法兑现为python代码

0 回顾 昨天推送了逻辑回归的基本原理:从逻辑回归的目标任务,到二分类模型的构建,再到如何用梯度下降求出二分类模型的权重参数。今天,我们将对这个算法兑现为代码...

3615
来自专栏AI科技评论

开发 | Keras版faster-rcnn算法详解(RPN计算)

AI科技评论按:本文首发于知乎专栏Learning Machine,作者张潇捷, AI科技评论获其授权转载。 前段时间学完Udacity的机器学习和深度学习的课...

74311
来自专栏数据小魔方

机器学习笔记——特征标准化

数据标准化是为了消除不同指标量纲的影响,方便指标之间的可比性,量纲差异会影响某些模型中距离计算的结果。

1223
来自专栏AI科技大本营的专栏

深度学习目标检测指南:如何过滤不感兴趣的分类及添加新分类?

AI 科技大本营按:本文编译自 Adrian Rosebrock 发表在 PyImageSearch 上的一篇博文。该博文缘起于一位网友向原作者请教的两个关于目...

1362
来自专栏AI研习社

用GAN来做图像生成,这是最好的方法

前言 对于图像问题,卷积神经网络相比于简单地全连接的神经网络更具优势。 本文将继续深入 GAN,通过融合卷积神经网络来对我们的 GAN 进行改进,实现一个深...

3374
来自专栏Spark学习技巧

【深度学习】②--细说卷积神经网络

1. 神经网络与卷积神经网络 先来回忆一下神经网络的结构,如下图,由输入层,输出层,隐藏层组成。每一个节点之间都是全连接,即上一层的节点会链接到下一层的每一个节...

4478
来自专栏杨熹的专栏

图解RNN

参考视频 RNN-Recurrent Neural Networks ---- 本文结构: 什么是 Recurrent Neural Networks ? R...

3575
来自专栏AI科技评论

开发 | 用GAN来做图像生成,这是最好的方法

前言 在我们之前的文章中,我们学习了如何构造一个简单的 GAN 来生成 MNIST 手写图片。对于图像问题,卷积神经网络相比于简单地全连接的神经网络更具优势,因...

3905

扫码关注云+社区

领取腾讯云代金券