1. 概述
生成对抗网络GAN(Generative adversarial nets)[1]是由Goodfellow等人于2014年提出的基于深度学习模型的生成框架,可用于多种生成任务。从名称也不难看出,在GAN中包括了两个部分,分别为”生成”和“对抗”,整两个部分也分别对应了两个网络,即生成网络(Generator)G和判别网络(Discriminator)D ,为描述简单,以图像生成为例:
在GAN框架的训练过程中,希望生成网络G生成的图片尽量真实,能够欺骗过判别网络D ;而希望判别网络D 能够把G 生成的图片从真实图片中区分开。这样的一个过程就构成了一个动态的“博弈”。最终,GAN希望能够使得训练好的生成网络G生成的图片能够以假乱真,即对于判别网络D来说,无法判断G 生成的网络是不是真实的。
综上,训练好的生成网络G 便可以用于生成“以假乱真”的图片。
2. 算法原理
2.1. GAN的框架结构
GAN的框架是由生成网络G 和判别网络D这两种网络结构组成,通过两种网络的“对抗”过程完成两个网络的训练,GAN框架由下图所示:
由生成网络G 生成一张“Fake image”,判别网络D 判断这张图片是否来自真实图片。
2.2. GAN框架的训练过程
在GAN的训练过程中,其最终的目标是使得训练出来的生成模型G 生成的图片与真实图片具有相同的分布,其过程可通过下图描述[2]:
2.3. 价值函数
简化后为:
\underset{G}{min}\; \underset{D}{max}\; V\left ( D,G \right )\approx \sum_{i=1}^{m}\left [ log\; D\left ( \boldsymbol{x}^{\left ( i \right )} \right ) + log\; \left ( 1-D\left ( G\left ( \boldsymbol{z}^{\left ( i \right )} \right ) \right ) \right ) \right ]上述的交替训练过程如下流程所示:
3. GAN背后的数学原理
3.1. 为什么会有这样的价值函数
3.2. KL散度
需要刻画两个分布是否相似,需要用到KL散度(KL divergence)。KL散度是统计学中的一个基本概念,用于衡量两个分布的相似程度,数值越小,表示两种概率分布越接近。对于离散的概率分布,定义如下:
D_{KL}\left ( P\parallel Q \right )=\sum_{i}P\left ( i \right )log\frac{P\left ( i \right )}{Q\left ( i \right )}对于连续的概率分布,定义如下:
D_{KL}\left ( P\parallel Q \right )=\int_{-\infty }^{+\infty }p\left ( x \right )log\frac{p\left ( x \right )}{q\left ( x \right )}dx3.3. 极大似然估计
3.4. 收敛性分析
当生成网络G 确定后,价值函数可以表示为:
\begin{aligned} \underset{D}{max}\; V\left ( D \right )&=\int _{\boldsymbol{x}}p_{data}\left ( \boldsymbol{x} \right )log\; \left ( D\left ( \boldsymbol{x} \right ) \right ) d\boldsymbol{x}+\int _{\boldsymbol{z}} p_{\boldsymbol{z}}\left ( \boldsymbol{z} \right )log\; \left ( 1-D\left ( G\left ( \boldsymbol{z} \right ) \right ) \right )d\boldsymbol{z} \\ &= \int _{\boldsymbol{x}}p_{data}\left ( \boldsymbol{x} \right )log\; \left ( D\left ( \boldsymbol{x} \right ) \right )+p_{g}\left ( \boldsymbol{x} \right )log\; \left ( 1-D\left ( \boldsymbol{x} \right ) \right ) d\boldsymbol{x} \end{aligned}由于上述的积分与D 无关,上述可以简化成求解:
\underset{D}{max}\left [ p_{data}\left ( \boldsymbol{x} \right )log\; \left ( D\left ( \boldsymbol{x} \right ) \right )+p_{g}\left ( \boldsymbol{x} \right )log\; \left ( 1-D\left ( \boldsymbol{x} \right ) \right )\right ]求导数并令其为0 ,便可以得到最大的D:
D^{\ast }=\frac{p_{data}\left ( \boldsymbol{x} \right )}{p_{data}\left ( \boldsymbol{x} \right )+p_{g}\left ( \boldsymbol{x} \right )}且D\left ( \boldsymbol{x} \right )\in \left [ 0,1 \right ],将其带入到价值函数中,可得
V\left ( D^{\ast },G \right )=\mathbb{E}_{\boldsymbol{x}\sim p_{data}\left ( \boldsymbol{x} \right )}\left [ log\; \frac{p_{data}\left ( \boldsymbol{x} \right )}{p_{data}\left ( \boldsymbol{x} \right )+p_{g}\left ( \boldsymbol{x} \right )} \right ]+\mathbb{E}_{\boldsymbol{x}\sim p_g\left ( \boldsymbol{x} \right )}\left [ log\; \left ( 1-\frac{p_{data}\left ( \boldsymbol{x} \right )}{p_{data}\left ( \boldsymbol{x} \right )+p_{g}\left ( \boldsymbol{x} \right )} \right ) \right ]对上式简化,可得:
\begin{aligned} V\left ( D^{\ast },G \right ) &= \mathbb{E}_{\boldsymbol{x}\sim p_{data}\left ( \boldsymbol{x} \right )}\left [ log\; \frac{p_{data}\left ( \boldsymbol{x} \right )}{p_{data}\left ( \boldsymbol{x} \right )+p_{g}\left ( \boldsymbol{x} \right )} \right ]+\mathbb{E}_{\boldsymbol{x}\sim p_g\left ( \boldsymbol{x} \right )}\left [ log\; \frac{p_g\left ( \boldsymbol{x} \right )}{p_{data}\left ( \boldsymbol{x} \right )+p_{g}\left ( \boldsymbol{x} \right )} \right ]\\ &= \int _{\boldsymbol{x}}p_{data}\left ( \boldsymbol{x} \right )log\; \frac{p_{data}\left ( \boldsymbol{x} \right )}{p_{data}\left ( \boldsymbol{x} \right )+p_{g}\left ( \boldsymbol{x} \right )}d\boldsymbol{x}+\int _{\boldsymbol{x}}p_g\left ( \boldsymbol{x} \right )log\; \frac{p_g\left ( \boldsymbol{x} \right )}{p_{data}\left ( \boldsymbol{x} \right )+p_{g}\left ( \boldsymbol{x} \right )}d\boldsymbol{x} \end{aligned}通过对分子和分母分别除2,可得:
\begin{aligned} V\left ( D^{\ast },G \right ) &= \int _{\boldsymbol{x}}p_{data}\left ( \boldsymbol{x} \right )log\; \frac{\frac{1}{2}p_{data}\left ( \boldsymbol{x} \right )}{\frac{p_{data}\left ( \boldsymbol{x} \right )+p_{g}\left ( \boldsymbol{x} \right )}{2}}d\boldsymbol{x}+\int _{\boldsymbol{x}}p_g\left ( \boldsymbol{x} \right )log\; \frac{\frac{1}{2}p_g\left ( \boldsymbol{x} \right )}{\frac{p_{data}\left ( \boldsymbol{x} \right )+p_{g}\left ( \boldsymbol{x} \right )}{2}}d\boldsymbol{x}\\ &= -2log\; 2+KL\left ( p_{data}\left ( \boldsymbol{x} \right )\parallel \frac{p_{data}\left ( \boldsymbol{x} \right )+p_{g}\left ( \boldsymbol{x} \right )}{2} \right )+KL\left ( p_g\left ( \boldsymbol{x} \right )\parallel \frac{p_{data}\left ( \boldsymbol{x} \right )+p_{g}\left ( \boldsymbol{x} \right )}{2} \right ) \end{aligned}这里引入另一个符号:JS散度(Jensen-Shannon Divergence)
JSD\left ( P\parallel Q \right )=\frac{1}{2}\left [ KL\left ( P\parallel M \right ) + KL\left ( Q\parallel M \right )\right ]其中,M=\frac{P+Q}{2} 。因此V\left ( D^{\ast },G \right )可以表示为:
V\left ( D^{\ast },G \right )=-2log\; 2+2JSD\left ( p_{data}\left ( \boldsymbol{x} \right )\parallel p_g\left ( \boldsymbol{x} \right ) \right )3. 总结
生成对抗网络GAN中通过生成网络G和判别网络D之间的“生成”和“对抗”过程,通过多次的迭代,最终达到平衡,使得训练出来的生成网络G能够生成“以假乱真”的数据,判别网络D不能将其从真实数据中区分开。
参考文献
[1] Goodfellow I, Pouget-Abadie J, Mirza M, et al. Generative adversarial nets[J]. Advances in neural information processing systems, 2014, 27.
[2] Generative Models
[2] PyTorch 学习笔记(十):初识生成对抗网络(GANs)
[3] 通俗理解生成对抗网络GAN