开发 | 用GAN来做图像生成,这是最好的方法

前言

在我们之前的文章中,我们学习了如何构造一个简单的 GAN 来生成 MNIST 手写图片。对于图像问题,卷积神经网络相比于简单地全连接的神经网络更具优势,因此,我们这一节我们将继续深入 GAN,通过融合卷积神经网络来对我们的 GAN 进行改进,实现一个深度卷积 GAN。如果还没有亲手实践过 GAN 的小伙伴可以先去学习一下上一篇专栏:生成对抗网络(GAN)之 MNIST 数据生成。

专栏中的所有代码都在我的 GitHub中,欢迎 star 与 fork。

本次代码在 NELSONZHAO/zhihu/dcgan,里面包含了两个文件:

  • dcgan_mnist:基于 MNIST 手写数据集构造深度卷积 GAN 模型
  • dcgan_cifar:基于 CIFAR 数据集构造深度卷积 GAN 模型

本文主要以 MNIST 为例进行介绍,两者在本质上没有差别,只在细微的参数上有所调整。由于穷学生资源有限,没有对模型增加迭代次数,也没有构造更深的模型。并且也没有选取像素很高的图像,高像素非常消耗计算量。本节只是一个抛砖引玉的作用,让大家了解 DCGAN 的结构,如果有资源的小伙伴可以自己去尝试其他更清晰的图片以及更深的结构,相信会取得很不错的结果。

工具

  • Python3
  • TensorFlow 1.0
  • Jupyter notebook

正文

整个正文部分将包括以下部分:

- 数据加载

- 模型输入

- Generator

- Discriminator

- Loss

- Optimizer

- 训练模型

- 可视化

数据加载

数据加载部分采用 TensorFlow 中的 input_data 接口来进行加载。关于加载细节在前面的文章中已经写了很多次啦,相信看过我文章的小伙伴对 MNIST 加载也非常熟悉,这里不再赘述。

模型输入

在 GAN 中,我们的输入包括两部分,一个是真实图片,它将直接输入给 discriminator 来获得一个判别结果;另一个是随机噪声,随机噪声将作为 generator 来生成图片的材料,generator 再将生成图片传递给 discriminator 获得一个判别结果。

上面的函数定义了输入图片与噪声图片两个 tensor。

Generator

生成器接收一个噪声信号,基于该信号生成一个图片输入给判别器。在上一篇专栏文章生成对抗网络(GAN)之 MNIST 数据生成中,我们的生成器是一个全连接层的神经网络,而本节我们将生成器改造为包含卷积结构的网络,使其更加适合处理图片输入。整个生成器结构如下:

我们采用了 transposed convolution 将我们的噪声图片转换为了一个与输入图片具有相同 shape 的生成图像。我们来看一下具体的实现代码:

上面的代码是整个生成器的实现细节,里面包含了一些 trick,我们来一步步地看一下。

首先我们通过一个全连接层将输入的噪声图像转换成了一个 1 x 4*4*512 的结构,再将其 reshape 成一个 [batch_size, 4, 4, 512] 的形状,至此我们其实完成了第一步的转换。接下来我们使用了一个对加速收敛及提高卷积神经网络性能中非常有效的方法——加入 BN(batch normalization),它的思想是归一化当前层输入,使它们的均值为 0 和方差为 1,类似于我们归一化网络输入的方法。它的好处在于可以加速收敛,并且加入 BN 的卷积神经网络受权重初始化影响非常小,具有非常好的稳定性,对于提升卷积性能有很好的效果。关于 batch normalization,我会在后面专栏中进行一个详细的介绍。

完成 BN 后,我们使用 Leaky ReLU 作为激活函数,在上一篇专栏中我们已经提过这个函数,这里不再赘述。最后加入 dropout 正则化。剩下的 transposed convolution 结构层与之类似,只不过在最后一层中,我们不采用 BN,直接采用 tanh 激活函数输出生成的图片。

在上面的 transposed convolution 中,很多小伙伴肯定会对每一层 size 的变化疑惑,在这里来讲一下在 TensorFlow 中如何来计算每一层 feature map 的 size。首先,在卷积神经网络中,假如我们使用一个 k x k 的 filter 对 m x m x d 的图片进行卷积操作,strides 为 s,在 TensorFlow 中,当我们设置 padding='same'时,卷积以后的每一个 feature map 的 height 和 width 为

;当设置 padding='valid'时,每一个 feature map 的 height 和 width 为

。那么反过来,如果我们想要进行 transposed convolution 操作,比如将 7 x 7 的形状变为 14 x 14,那么此时,我们可以设置 padding='same',strides=2 即可,与 filter 的 size 没有关系;而如果将 4 x 4 变为 7 x 7 的话,当设置 padding='valid'时,即

,此时 s=1,k=4 即可实现我们的目标。

上面的代码中我也标注了每一步 shape 的变化。

Discriminator

Discriminator 接收一个图片,输出一个判别结果(概率)。其实 Discriminator 完全可以看做一个包含卷积神经网络的图片二分类器。结构如下:

实现代码如下:

上面代码其实就是一个简单的卷积神经网络图像识别问题,最终返回 logits(用来计算 loss)与 outputs。这里没有加入池化层的原因在于图片本身经过多层卷积以后已经非常小了,并且我们加入了 batch normalization 加速了训练,并不需要通过 max pooling 来进行特征提取加速训练。

Loss Function

Loss 部分分别计算 Generator 的 loss 与 Discriminator 的 loss,和之前一样,我们加入 label smoothing 防止过拟合,增强泛化能力。

Optimizer

GAN 中实际包含了两个神经网络,因此对于这两个神经网络要分开进行优化。代码如下:

这里的 Optimizer 和我们之前不同,由于我们使用了 TensorFlow 中的 batch normalization 函数,这个函数中有很多 trick 要注意。首先我们要知道,batch normalization 在训练阶段与非训练阶段的计算方式是有差别的,这也是为什么我们在使用 batch normalization 过程中需要指定 training 这个参数。上面使用 tf.control_dependencies 是为了保证在训练阶段能够一直更新 moving averages。具体参考 A Gentle Guide to Using Batch Normalization in Tensorflow - Rui Shu。

训练

到此为止,我们就完成了深度卷积 GAN 的构造,接着我们可以对我们的 GAN 来进行训练,并且定义一些辅助函数来可视化迭代的结果。代码太长就不放上来了,可以直接去我的 GitHub 下载。

我这里只设置了 5 轮 epochs,每隔 100 个 batch 打印一次结果,每一行代表同一个 epoch 下的 25 张图:

我们可以看出仅仅经过了少部分的迭代就已经生成非常清晰的手写数字,并且训练速度是非常快的。

上面的图是最后几次迭代的结果。我们可以回顾一下上一篇的一个简单的全连接层的 GAN,收敛速度明显不如深度卷积 GAN。

总结

到此为止,我们学习了一个深度卷积 GAN,并且看到相比于之前简单的 GAN 来说,深度卷积 GAN 的性能更加优秀。当然除了 MNST 数据集以外,小伙伴儿们还可以尝试很多其他图片,比如我们之前用到过的 CIFAR 数据集,我在这里也实现了一个 CIFAR 数据集的图片生成,我只选取了马的图片进行训练:

刚开始训练时:

训练 50 个 epochs:

这里我只设置了 50 次迭代,可以看到最后已经生成了非常明显的马的图像,可见深度卷积 GAN 的优势。

我的 GitHub:NELSONZHAO (Nelson Zhao)

上面包含了我的专栏中所有的代码实现,欢迎 star,欢迎 fork。

原文发布于微信公众号 - AI科技评论(aitechtalk)

原文发表时间:2017-08-09

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏java闲聊

JDK1.8 ArrayList 源码解析

当运行 ArrayList<Integer> list = new ArrayList<>() ; ,因为它没有指定初始容量,所以它调用的是它的无参构造

1192
来自专栏计算机视觉与深度学习基础

Leetcode 114 Flatten Binary Tree to Linked List

Given a binary tree, flatten it to a linked list in-place. For example, Given...

1938
来自专栏Hongten

ArrayList VS Vector(ArrayList和Vector的区别)_面试的时候经常出现

1692
来自专栏ml

朴素贝叶斯分类器(离散型)算法实现(一)

1. 贝叶斯定理:        (1)   P(A^B) = P(A|B)P(B) = P(B|A)P(A)   由(1)得    P(A|B) = P(B|...

3437
来自专栏Java Edge

AbstractList源码解析1 实现的方法2 两种内部迭代器3 两种内部类3 SubList 源码分析4 RandomAccessSubList 源码:AbstractList 作为 Lis

它实现了 List 的一些位置相关操作(比如 get,set,add,remove),是第一个实现随机访问方法的集合类,但不支持添加和替换

422
来自专栏刘君君

JDK8的HashMap源码学习笔记

3008
来自专栏后端之路

LinkedList源码解读

List中除了ArrayList我们最常用的就是LinkedList了。 LInkedList与ArrayList的最大区别在于元素的插入效率和随机访问效率 ...

19510
来自专栏xingoo, 一个梦想做发明家的程序员

AOE关键路径

这个算法来求关键路径,其实就是利用拓扑排序,首先求出,每个节点最晚开始时间,再倒退求每个最早开始的时间。 从而算出活动最早开始的时间和最晚开始的时间,如果这两个...

2507
来自专栏开发与安全

算法:AOV网(Activity on Vextex Network)与拓扑排序

在一个表示工程的有向图中,用顶点表示活动,用弧表示活动之间的优先关系,这样的有向图为顶点表示活动的网,我们称之为AOV网(Activity on Vextex ...

2517
来自专栏xingoo, 一个梦想做发明家的程序员

Spark踩坑——java.lang.AbstractMethodError

百度了一下说是版本不一致导致的。于是重新检查各个jar包,发现spark-sql-kafka的版本是2.2,而spark的版本是2.3,修改spark-sql-...

1200

扫码关注云+社区