前言
生成式对抗网络(GAN),是14年GoodfellowIan在论文Generative Adversarial Nets中提出来的。Yann LeCun曾评价GAN是“20年来机器学习领域最酷的想法。本文以在图像领域表现效果较好的GAN衍生模型DCGAN为基础,带大家全面了解一下GAN的算法。
一、GAN算法原理
在介绍GAN算法之前需要先了解两个基本基本名称,生成模型和判别模型。
生成模型:就是基于根据给定观测数据,找出观测数据内部的统计规律,并且能够依据基于所得的概率分布模型,产生全新的,与观测数据类似的数据。
判别模型:与生成模型相对应,判别模型的作用就是通过学习数据的内部规律,识别出传入模型的数据是真实的观测数据,还是由生成模型生成的数据。
简单说来,就是一个是作假的,一个是鉴别真伪的。通过不断的训练,作假的生成模型生成的数据越来越像真的,以此同时,鉴别真伪的判别模型的鉴定能力也越来越强。通过不断大量数据的反复迭代训练,最终,生成模型生成的数据可以超过人类的判定能力,同时,判别模型的鉴别能力也将超过人类水平。
如上图所示,通过观测的图片信息,也就是X信息,Z为高斯噪声数据,利用生成器G可以从噪声数据生成为模拟的真实数据,然后利用判别器D进行判定。通过不断的迭代优化,就可以训练出能够生成以假乱真数据的生成器G,和能够有火眼金睛能力的鉴别器D。
二、DCGAN算法原理
随着深度学习的快速发展,以新一代深度学习技术为先导的图像识别技术极大的影响着计算机视觉技术发展,这其中CNN(卷积神经网络)技术是非常具有代表意义的一项技术,广泛应用于计算机视觉处理的各个领域,并取得了显著性的成就。
DCGAN全称 DeepConvolutional Generative Adversarial Networks,深度卷积生成对抗网络,通过结合CNN技术和GAN技术的优点,训练出一个可以生成图片的生成器,通过对测试Demo的反复迭代,最后生成的图片与真实图片之间的差距越来越小。
论文(UnsupervisedRepresentation Learning with Deep Convolutional Generative Adversarial Networks,https://arxiv.org/abs/1511.06434)中通过利用GAN的原理和CNN技术在图像技术上的优势,构建了表现良好的DCGAN架构。
具体来说,DCGAN算法的基本原理如上图所示:
1. 在生成器中,通过不断的reshape和反卷积操作,将高斯噪声数据,映射为64*64*3的数据维度,通过RGB的原理,可以得到一个64*64的图片。
2. 在判别器中,通过常规的卷积操作和线性操作,最终映射到一维向量,最后用sigmoid函数作为激活函数作为输出的激活函数。
3. 通过不断的迭代,当生成器和判别器的损失函数收敛后,即可得出稳定的生成器和判别器。
相对于其他CNN图像识别算法,DCGAN算法做了一些明显的修改,具体如下:
a. 取消了pooling层,在判别器中使用strided convolutions,在生成器中使用了部分strided convolutions.
b. 在生成器和判别器中都是用了batchnorm,主要可以防止每一层的梯度消失和爆炸的情况。
c. 在深度框架中,移除了全链接层。
d. 在生成器中,除了在输出层的激活函数使用Tanh外,其他隐藏层的激活函数使用ReLU
e. 在判别器中,多所有层的激活函数使用了LeakyReLU。
基于以上的设置,保证了GAN模型的训练的效果,克服了GAN模型存在训练不稳定性的问题。
三、DCGAN实现Demo
在本节中,我们以github上的DCGAN开源代码为例(https://github.com/carpedm20/DCGAN-tensorflow),讲解一下基于Tensorflow架构的DCGAN实现Demo.
1. 数据集
中国香港中文大学汤晓鸥教授实验室公布的大型人脸识别数据集: Large-scale CelebFaces Attributes (CelebA) Dataset 10K 名人,202K 脸部图像,每个图像40余标注属性。
2. 生成器
生成器的核心代码如下:
l 具体为通过对100的高斯噪声数据,先通过线性变化和reshape转化,转化[-1, s_h16, s_w16, self.gf_dim * 8]的向量。
l 然后通过多次的deconv2d反卷积运算,使其反卷积为[64,64,64,3] 根据RGB的原理,即可生成64张64*64的图片。
3. 判别器
通过利用CNN卷积操作,将64*64*3通过卷积和线性变换转化为1维数,然后通激活函数sigmoid来作为输出层的激活输出。
4. 训练调优
如下所示,通过随机噪声数据z,可以通过生成器生成G,通过输入inputs图片,利用判别器可以得到D和D_logits.
如上所示,通过判别器将真实图片判定为真时的误差,可以得d_loss_real.通过生成器生成的图片G,然后利用判别器进行判定后,可以得出将假图片判定为假图片时的误差是d_loss_fake, 通过利用两者的和可以获得真实判别器的误差d_loss. 利用生成器生成的图片,通过判别器判定后的记过D_logits_,可以得出生成器生成的图片与真实图片之间的误差g_loss.
然后通过不断的基于epch和batch循环的迭代更新,不断优化d_loss和g_loss。使整体判别器和生成器达到一个较好的结果。
参考资料
文章来源:德塔大数据研究院
文章编辑:田芳