前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >GAN对抗网络入门教程

GAN对抗网络入门教程

作者头像
致Great
发布2019-09-18 16:15:29
1.5K0
发布2019-09-18 16:15:29
举报
文章被收录于专栏:程序生活程序生活

译:A Beginner's Guide to Generative Adversarial Networks (GANs) https://skymind.ai/wiki/generative-adversarial-network-gan

1 GAN简介

生成对抗网络(英语:Generative Adversarial Network,简称GAN)是非监督式学习的一种方法,通过让两个神经网络相互博弈的方式进行学习。该方法由伊恩·古德费洛等人于2014年提出。生成对抗网络由一个生成网络与一个判别网络组成。生成网络从潜在空间(latent space)中随机取样作为输入,其输出结果需要尽量模仿训练集中的真实样本。判别网络的输入则为真实样本或生成网络的输出,其目的是将生成网络的输出从真实样本中尽可能分辨出来。而生成网络则要尽可能地欺骗判别网络。两个网络相互对抗、不断调整参数,最终目的是使判别网络无法判断生成网络的输出结果是否真实。

生成对抗网络常用于生成以假乱真的图片。此外,该方法还被用于生成影片、三维物体模型等。

虽然生成对抗网络原先是为了无监督学习提出的,它也被证明对半监督学习、完全监督学习 、强化学习是有用的。

image

2 生成与判别算法

要理解GAN,你应该知道生成算法是如何工作的,但是在理解生成算法之前,将它们与判别算法进行对比可以加深理解。我们先看下什么事判别算法?

判别算法试图对输入数据进行分类; 也就是说,给定数据实例的特征,它们预测该数据所属的标签或类别。

例如,给定电子邮件中的所有单词(数据实例),判别算法可以预测该消息是spam(垃圾邮件)还是not_spam(非垃圾邮件)。 其中spam是标签之一,从电子邮件收集的单词包是构成输入数据的特征。 当以数学方式表达此问题时,标签称为y,并且要素称为x。公式p(y|x)用于表示“给定x条件下y发生的概率”,在这种情况下,它将转换为“在给定邮件所包含的字词情况下,电子邮件是垃圾邮件的概率”。

因此,判别算法是将特征映射到标签,而生成算法恰恰在做相反的事情。生成算法试图预测给定某个标签下的特征,而不是预测给定某些特征的标签。

生成算法试图回答的问题是:假设这封电子邮件是垃圾邮件,特征的分布或者概率是怎么样的? 虽然判别模型关注y和x之间的关系,但是生成模型关心“你如何得到x。”生成算法是为了计算出(x | y),给出y条件下x发生的概率,或者说给出标签时,特征的概率。 (也就是说,生成算法也可以用作分类器。恰好它们不是对输入数据进行分类。)

下面两句话将判别与生成区分开来:

  • 判别模型学习了类之间的界限
  • 生成模型模拟各个类的分布

3 GANs原理

GAN的基本原理其实非常简单,这里以生成图片为例进行说明。假设我们有两个网络,G(Generator)和D(Discriminator)。正如它的名字所暗示的那样,它们的功能分别是:一个神经网络,称为生成器,生成新的数据实例,而另一个神经网络,判别器,评估它们的真实性; 即判别器决定它所评测的每个数据实例是否属于实际训练数据集。

G是一个生成图片的网络,它接收一个随机的噪声z,通过这个噪声生成图片,记做G(z)。 D是一个判别网络,判别一张图片是不是“真实的”。它的输入参数是x,x代表一张图片,输出D(x)代表x为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表不可能是真实的图片。 在训练过程中,生成网络G的目标就是尽量生成真实的图片去欺骗判别网络D。而D的目标就是尽量把G生成的图片和真实的图片分别开来。这样,G和D构成了一个动态的“博弈过程”。

最后博弈的结果是什么?在最理想的状态下,G可以生成足以“以假乱真”的图片G(z)。对于D来说,它难以判定G生成的图片究竟是不是真实的,因此D(G(z)) = 0.5。

reference:https://zhuanlan.zhihu.com/p/24767059

以下是GAN大致步骤:

  • 生成器接收随机数并返回图像。
  • 将生成的图像与从真实数据集中获取的图像流一起馈送到判别器中。
  • 判别器接收真实和假图像并返回概率,0到1之间的数字,1表示真实性的预测,0表示假。

image

您可以将GAN视为诈骗者和警察在猫与老鼠游戏中的反对,其中诈骗者正在学习传递虚假信息,并且警察正在学习如何检测它们。 两者都是动态的; 也就是说,警察也在接受培训,每一方都在不断升级中学习对方的方法。

对于MNIST数据集,判别器网络是标准卷积网络,可以对馈送给它的图像进行分类,二项分类器将图像标记为真实或伪造。 在某种意义上,生成器是反卷积网络:当标准卷积分类器采用图像并对其进行下采样以产生概率时,生成器采用随机噪声矢量并将其上采样到图像。 第一个通过下采样技术(如maxpooling)丢弃数据,第二个生成新数据。

image

4 GANs, Autoencoders and VAEs

下面对生成性对抗网络与其他神经网络(例如自动编码器和变分自动编码器)进行比较。

自动编码器将输入数据编码为矢量。它们创建原始数据的隐藏或压缩表示,在减少维数方面很有用; 也就是说,用作隐藏表示的向量将原始数据压缩为较少数量的突出维度。 自动编码器可以与所谓的解码器配对,允许您根据其隐藏的表示重建输入数据,就像使用受限制的Boltzmann机器一样。

image

变分自动编码器是生成算法,其为编码输入数据添加额外约束,即隐藏表示被标准化。 变分自动编码器能够像自动编码器一样压缩数据并像GAN一样合成数据。 然而GAN可以更精细、细粒度的生成数据,VAE生成的图像往往更加模糊。 Deeplearning4j的例子包括自动编码器和变分自动编码器。(https://github.com/deeplearning4j/dl4j-examples/tree/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/unsupervised

5 Keras 实现GAN

https://github.com/eriklindernoren/Keras-GAN

代码语言:javascript
复制
from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt

import sys

import numpy as np
代码语言:javascript
复制
Using TensorFlow backend.
代码语言:javascript
复制
class GAN():
    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
            optimizer=optimizer,
            metrics=['accuracy'])

        # Build and compile the generator
        self.generator = self.build_generator()
        self.generator.compile(loss='binary_crossentropy', optimizer=optimizer)

        # The generator takes noise as input and generated imgs
        z = Input(shape=(100,))
        img = self.generator(z)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The valid takes generated images as input and determines validity
        valid = self.discriminator(img)

        # The combined model  (stacked generator and discriminator) takes
        # noise as input => generates images => determines validity
        self.combined = Model(z, valid)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

    def build_generator(self):

        noise_shape = (100,)

        model = Sequential()

        model.add(Dense(256, input_shape=noise_shape))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))

        model.summary()

        noise = Input(shape=noise_shape)
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):

        img_shape = (self.img_rows, self.img_cols, self.channels)

        model = Sequential()

        model.add(Flatten(input_shape=img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        img = Input(shape=img_shape)
        validity = model(img)

        return Model(img, validity)

    def train(self, epochs, batch_size=128, save_interval=50):

        # Load the dataset
        (X_train, _), (_, _) = mnist.load_data()

        # Rescale -1 to 1
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        X_train = np.expand_dims(X_train, axis=3)

        half_batch = int(batch_size / 2)

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random half batch of images
            idx = np.random.randint(0, X_train.shape[0], half_batch)
            imgs = X_train[idx]

            noise = np.random.normal(0, 1, (half_batch, 100))

            # Generate a half batch of new images
            gen_imgs = self.generator.predict(noise)

            # Train the discriminator
            d_loss_real = self.discriminator.train_on_batch(imgs, np.ones((half_batch, 1)))
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, np.zeros((half_batch, 1)))
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)


            # ---------------------
            #  Train Generator
            # ---------------------

            noise = np.random.normal(0, 1, (batch_size, 100))

            # The generator wants the discriminator to label the generated samples
            # as valid (ones)
            valid_y = np.array([1] * batch_size)

            # Train the generator
            g_loss = self.combined.train_on_batch(noise, valid_y)

            # Plot the progress
            if epoch%1000==0:
                print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            # If at save interval => save generated image samples
            if epoch % save_interval == 0:
                self.save_imgs(epoch)

    def save_imgs(self, epoch):
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, 100))
        gen_imgs = self.generator.predict(noise)

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("data/gan/images/mnist_%d.png" % epoch)
        plt.close()


if __name__ == '__main__':
    gan = GAN()
    gan.train(epochs=30000, batch_size=32, save_interval=200)
代码语言:javascript
复制
WARNING:tensorflow:From D:\ProgramData\Anaconda3\lib\site-packages\keras\backend\tensorflow_backend.py:66: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.

WARNING:tensorflow:From D:\ProgramData\Anaconda3\lib\site-packages\keras\backend\tensorflow_backend.py:541: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

WARNING:tensorflow:From D:\ProgramData\Anaconda3\lib\site-packages\keras\backend\tensorflow_backend.py:4432: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead.

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
flatten_1 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 512)               401920    
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 256)               131328    
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 256)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 1)                 257       
=================================================================
Total params: 533,505
Trainable params: 533,505
Non-trainable params: 0
_________________________________________________________________
WARNING:tensorflow:From D:\ProgramData\Anaconda3\lib\site-packages\keras\optimizers.py:793: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.

WARNING:tensorflow:From D:\ProgramData\Anaconda3\lib\site-packages\keras\backend\tensorflow_backend.py:3657: The name tf.log is deprecated. Please use tf.math.log instead.

WARNING:tensorflow:From D:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\ops\nn_impl.py:180: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
WARNING:tensorflow:From D:\ProgramData\Anaconda3\lib\site-packages\keras\backend\tensorflow_backend.py:148: The name tf.placeholder_with_default is deprecated. Please use tf.compat.v1.placeholder_with_default instead.

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_4 (Dense)              (None, 256)               25856     
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 256)               0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 256)               1024      
_________________________________________________________________
dense_5 (Dense)              (None, 512)               131584    
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
batch_normalization_2 (Batch (None, 512)               2048      
_________________________________________________________________
dense_6 (Dense)              (None, 1024)              525312    
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 1024)              0         
_________________________________________________________________
batch_normalization_3 (Batch (None, 1024)              4096      
_________________________________________________________________
dense_7 (Dense)              (None, 784)               803600    
_________________________________________________________________
reshape_1 (Reshape)          (None, 28, 28, 1)         0         
=================================================================
Total params: 1,493,520
Trainable params: 1,489,936
Non-trainable params: 3,584
_________________________________________________________________


D:\ProgramData\Anaconda3\lib\site-packages\keras\engine\training.py:493: UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?
  'Discrepancy between trainable weights and collected trainable'


0 [D loss: 0.735185, acc.: 46.88%] [G loss: 0.829077]


D:\ProgramData\Anaconda3\lib\site-packages\keras\engine\training.py:493: UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?
  'Discrepancy between trainable weights and collected trainable'


1000 [D loss: 0.590758, acc.: 71.88%] [G loss: 0.793450]
2000 [D loss: 0.587990, acc.: 62.50%] [G loss: 0.956186]
3000 [D loss: 0.644352, acc.: 59.38%] [G loss: 0.914777]
4000 [D loss: 0.673936, acc.: 62.50%] [G loss: 0.971460]
5000 [D loss: 0.759974, acc.: 53.12%] [G loss: 0.904706]
6000 [D loss: 0.555306, acc.: 81.25%] [G loss: 0.835633]
7000 [D loss: 0.674409, acc.: 62.50%] [G loss: 0.823623]
8000 [D loss: 0.672854, acc.: 53.12%] [G loss: 0.863680]
9000 [D loss: 0.743683, acc.: 46.88%] [G loss: 0.868321]
10000 [D loss: 0.635190, acc.: 59.38%] [G loss: 0.854181]
11000 [D loss: 0.700397, acc.: 56.25%] [G loss: 0.778778]
12000 [D loss: 0.741978, acc.: 46.88%] [G loss: 0.813542]
13000 [D loss: 0.760614, acc.: 46.88%] [G loss: 0.833507]
14000 [D loss: 0.671199, acc.: 68.75%] [G loss: 0.853395]
15000 [D loss: 0.676217, acc.: 62.50%] [G loss: 0.920993]
16000 [D loss: 0.593898, acc.: 68.75%] [G loss: 0.889001]
17000 [D loss: 0.724363, acc.: 50.00%] [G loss: 0.893431]
18000 [D loss: 0.779740, acc.: 43.75%] [G loss: 0.853765]
19000 [D loss: 0.642237, acc.: 59.38%] [G loss: 0.830348]
20000 [D loss: 0.587237, acc.: 62.50%] [G loss: 0.876839]
21000 [D loss: 0.645381, acc.: 62.50%] [G loss: 0.827465]
22000 [D loss: 0.723597, acc.: 46.88%] [G loss: 0.862281]
23000 [D loss: 0.671319, acc.: 65.62%] [G loss: 0.903444]
24000 [D loss: 0.684801, acc.: 62.50%] [G loss: 0.807403]
25000 [D loss: 0.737355, acc.: 43.75%] [G loss: 0.813877]
26000 [D loss: 0.606201, acc.: 68.75%] [G loss: 0.802509]
27000 [D loss: 0.711020, acc.: 56.25%] [G loss: 0.894887]
28000 [D loss: 0.641023, acc.: 56.25%] [G loss: 0.856079]
29000 [D loss: 0.696889, acc.: 46.88%] [G loss: 0.728626]

可以看到D的判别准确率最终在46%-56%之间,也就是说G网络生成的图片已经真假难分

6 参考资料

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2019.09.17 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1 GAN简介
  • 2 生成与判别算法
  • 3 GANs原理
  • 4 GANs, Autoencoders and VAEs
  • 5 Keras 实现GAN
  • 6 参考资料
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档