前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >使用keras创建一个简单的生成式对抗网络(GAN)

使用keras创建一个简单的生成式对抗网络(GAN)

作者头像
AiTechYun
发布2018-07-27 10:14:15
2.2K0
发布2018-07-27 10:14:15
举报
文章被收录于专栏:ATYUN订阅号ATYUN订阅号

AiTechYun

编辑:yxy

在本教程中,你将了解什么是生成式对抗网络(GAN),但在这里我不会讲解数学细节。在教程的最后,你会学习如何编写一个可以创建数字的简单生成式对抗网络(GAN)!

用比喻方法理解生成式对抗网络GAN

理解生成式对抗网络GAN的最简单方法是通过一个简单的比喻:

假设有一家商店从顾客那里购买某些种类的葡萄酒,之后进行再销售。

然而,有些恶意的顾客为了获得金钱而出售假酒。在这种情况下,店主必须能够区分假酒和正品葡萄酒。

可以想象,最初,伪造者在尝试出售假酒时可能会犯很多错误,并且店主很容易认定该酒是假的。由于这些失败,伪造者会继续尝试使用不同的技术来模拟真正的葡萄酒,有些最终会成功。现在,伪造者知道某些技术已经可以骗过店主的检查,他可以开始根据这些技术进一步改进假冒葡萄酒。

同时,店主可能会从其他店主或葡萄酒专家那里得到一些反馈,说明她拥有的一些葡萄酒不是原装的。这意味着店主必须改善她鉴定葡萄酒的技术。伪造者的目标是制造与真实葡萄酒无法区分的葡萄酒,而店主的目标是准确地分辨葡萄酒是否真实。

这种来回的竞争是GAN背后的主要思想。

生成式对抗网络的组成部分

使用上面的例子,我们可以想出一个生成式对抗网络GAN的架构。

GAN中有两个主要组件:生成器和鉴别器。这个比喻中的店主被称为鉴别器网络,并且通常是卷积神经网络(因为GAN主要用于图像任务),其指定图像为真实的概率。

伪造者一般称为生成网络,并且通常也是卷积神经网络(具有反卷积层)。该网络需要一些噪声矢量并输出图像。在训练生成网络时,它会学习图像的哪些区域进行改进/更改,以便鉴别器将难以将其生成的图像与真实图像区分开来。

生成网络不断生成更接近真实图像的图像,而鉴别网络试图确定真实图像和假图像之间的差异。最终的目标是建立一个可生成与真实图像无法区分的图像的生成网络。

使用Keras做一个简单的生成式对抗网络GAN

现在你已了解生成式对抗网络GAN是什么以及它们的主要组成部分,现在我们可以开始使用Keras编写一个非常简单的代码。本教程基于这里开发的非常酷且易于理解的GAN 。

你需要做的第一件事是使用pip安装以下包:

代码语言:javascript
复制
-  keras
代码语言:javascript
复制
-  matplotlib
代码语言:javascript
复制
-  tensorflow
代码语言:javascript
复制
-  tqdm

你将matplotlib用于绘制tensorflowKeras后端库,并tqdm为每个周期(迭代)显示一个进度条。

下一步是创建一个Python脚本。在这个脚本中,你首先需要导入你将要使用的所有模块和函数。在使用它们时给出每个解释。

代码语言:javascript
复制
import os
代码语言:javascript
复制
import numpy as np
代码语言:javascript
复制
import matplotlib.pyplot as plt
代码语言:javascript
复制
from tqdmimport tqdm
代码语言:javascript
复制
代码语言:javascript
复制
代码语言:javascript
复制
from keras.layersimport Input
代码语言:javascript
复制
from keras.modelsimport Model, Sequential
代码语言:javascript
复制
from keras.layers.coreimport Dense, Dropout
代码语言:javascript
复制
from keras.layers.advanced_activationsimport LeakyReLU
代码语言:javascript
复制
from keras.datasetsimport mnist
代码语言:javascript
复制
from keras.optimizersimport Adam
代码语言:javascript
复制
from kerasimport initializers

你现在想要设置一些变量:

代码语言:javascript
复制
# Let Keras know that we are using tensorflow as our backend engine
代码语言:javascript
复制
os.environ["KERAS_BACKEND"]= "tensorflow"
代码语言:javascript
复制
# To make sure that we can reproduce the experiment and get the same results
代码语言:javascript
复制
np.random.seed(10)
代码语言:javascript
复制
# The dimension of our random noise vector.
代码语言:javascript
复制
random_dim= 100

在开始构建鉴别器和生成器之前,你应该首先收集并预处理数据。你将使用流行的MNIST数据集,该数据集具有一组从0到9范围内的单个数字的图像。

MINST数字的例子

代码语言:javascript
复制
def load_minst_data():
代码语言:javascript
复制
    # load the data
代码语言:javascript
复制
    (x_train, y_train), (x_test, y_test)= mnist.load_data()
代码语言:javascript
复制
    # normalize our inputs to be in the range[-1, 1]
代码语言:javascript
复制
    x_train= (x_train.astype(np.float32)- 127.5)/127.5
代码语言:javascript
复制
    # convert x_train with a shape of (60000, 28, 28) to (60000, 784) so we have
代码语言:javascript
复制
    # 784 columns per row
代码语言:javascript
复制
    x_train= x_train.reshape(60000,784)
代码语言:javascript
复制
    return (x_train, y_train, x_test, y_test)

请注意,这mnist.load_data()是Keras的一部分,并允许你轻松将MNIST数据集导入你的工作区。

现在,你可以创建你的生成器和鉴别器网络。你将为这两个网络使用Adam优化器。对于生成器和鉴别器,你将创建一个带有三个隐藏层的神经网络,激活函数为Leaky Relu。你还应该为鉴别器添加辍学层(dropout layers),以提高其对未见图像的鲁棒性。

代码语言:javascript
复制
def get_optimizer():
代码语言:javascript
复制
    return Adam(lr=0.0002, beta_1=0.5)
代码语言:javascript
复制
代码语言:javascript
复制
def get_generator(optimizer):
代码语言:javascript
复制
    generator= Sequential()
代码语言:javascript
复制
    generator.add(Dense(256, input_dim=random_dim, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
代码语言:javascript
复制
    generator.add(LeakyReLU(0.2))
代码语言:javascript
复制
代码语言:javascript
复制
    generator.add(Dense(512))
代码语言:javascript
复制
    generator.add(LeakyReLU(0.2))
代码语言:javascript
复制
代码语言:javascript
复制
    generator.add(Dense(1024))
代码语言:javascript
复制
    generator.add(LeakyReLU(0.2))
代码语言:javascript
复制
代码语言:javascript
复制
    generator.add(Dense(784, activation='tanh'))
代码语言:javascript
复制
    generator.compile(loss='binary_crossentropy', optimizer=optimizer)
代码语言:javascript
复制
    return generator
代码语言:javascript
复制
代码语言:javascript
复制
def get_discriminator(optimizer):
代码语言:javascript
复制
    discriminator= Sequential()
代码语言:javascript
复制
    discriminator.add(Dense(1024, input_dim=784, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
代码语言:javascript
复制
    discriminator.add(LeakyReLU(0.2))
代码语言:javascript
复制
    discriminator.add(Dropout(0.3))
代码语言:javascript
复制
代码语言:javascript
复制
    discriminator.add(Dense(512))
代码语言:javascript
复制
    discriminator.add(LeakyReLU(0.2))
代码语言:javascript
复制
    discriminator.add(Dropout(0.3))
代码语言:javascript
复制
代码语言:javascript
复制
    discriminator.add(Dense(256))
代码语言:javascript
复制
    discriminator.add(LeakyReLU(0.2))
代码语言:javascript
复制
    discriminator.add(Dropout(0.3))
代码语言:javascript
复制
代码语言:javascript
复制
    discriminator.add(Dense(1, activation='sigmoid'))
代码语言:javascript
复制
    discriminator.compile(loss='binary_crossentropy', optimizer=optimizer)
代码语言:javascript
复制
    return discriminator

终于到了将生成器和鉴别器放在一起的时候了!

代码语言:javascript
复制
def get_gan_network(discriminator, random_dim, generator, optimizer):
代码语言:javascript
复制
    # We initially set trainable to False since we only want to train either the
代码语言:javascript
复制
    # generator or discriminator at a time
代码语言:javascript
复制
    discriminator.trainable= False
代码语言:javascript
复制
    # gan input (noise) will be 100-dimensional vectors
代码语言:javascript
复制
    gan_input= Input(shape=(random_dim,))
代码语言:javascript
复制
    # the output of the generator (an image)
代码语言:javascript
复制
    x= generator(gan_input)
代码语言:javascript
复制
    # get the output of the discriminator (probability if the image is real or not)
代码语言:javascript
复制
    gan_output= discriminator(x)
代码语言:javascript
复制
    gan= Model(inputs=gan_input, outputs=gan_output)
代码语言:javascript
复制
    gan.compile(loss='binary_crossentropy', optimizer=optimizer)
代码语言:javascript
复制
    return gan

为了保持完备性,你可以创建一个功能,每20个周期保存你生成的图像。但由于这不是教程的核心,所以你不需要完全理解它。

代码语言:javascript
复制
def plot_generated_images(epoch, generator, examples=100, dim=(10,10), figsize=(10,10)):
代码语言:javascript
复制
    noise= np.random.normal(0,1, size=[examples, random_dim])
代码语言:javascript
复制
    generated_images= generator.predict(noise)
代码语言:javascript
复制
    generated_images= generated_images.reshape(examples,28,28)
代码语言:javascript
复制
代码语言:javascript
复制
    plt.figure(figsize=figsize)
代码语言:javascript
复制
    for iin range(generated_images.shape[0]):
代码语言:javascript
复制
        plt.subplot(dim[0], dim[1], i+1)
代码语言:javascript
复制
        plt.imshow(generated_images[i], interpolation='nearest', cmap='gray_r')
代码语言:javascript
复制
        plt.axis('off')
代码语言:javascript
复制
    plt.tight_layout()
代码语言:javascript
复制
    plt.savefig('gan_generated_image_epoch_%d.png' % epoch)

你现在已经编码了大部分网络。剩下的就是训练这个网络,并看看你创建的图像。

代码语言:javascript
复制
def train(epochs=1, batch_size=128):
代码语言:javascript
复制
    # Get the training and testing data
代码语言:javascript
复制
    x_train, y_train, x_test, y_test= load_minst_data()
代码语言:javascript
复制
    # Split the training data into batches of size 128
代码语言:javascript
复制
    batch_count= x_train.shape[0]/ batch_size
代码语言:javascript
复制
代码语言:javascript
复制
    # Build our GAN netowrk
代码语言:javascript
复制
    adam= get_optimizer()
代码语言:javascript
复制
    generator= get_generator(adam)
代码语言:javascript
复制
    discriminator= get_discriminator(adam)
代码语言:javascript
复制
    gan= get_gan_network(discriminator, random_dim, generator, adam)
代码语言:javascript
复制
代码语言:javascript
复制
    for ein xrange(1, epochs+1):
代码语言:javascript
复制
        print '-'*15,'Epoch %d' % e,'-'*15
代码语言:javascript
复制
        for _in tqdm(xrange(batch_count)):
代码语言:javascript
复制
            # Get a random set of input noise and images
代码语言:javascript
复制
            noise= np.random.normal(0,1, size=[batch_size, random_dim])
代码语言:javascript
复制
            image_batch= x_train[np.random.randint(0, x_train.shape[0], size=batch_size)]

代码语言:javascript
复制
            # Generate fake MNIST images
代码语言:javascript
复制
            generated_images= generator.predict(noise)
代码语言:javascript
复制
            X= np.concatenate([image_batch, generated_images])
代码语言:javascript
复制
代码语言:javascript
复制
            # Labels for generated and real data
代码语言:javascript
复制
            y_dis= np.zeros(2*batch_size)
代码语言:javascript
复制
            # One-sided label smoothing
代码语言:javascript
复制
            y_dis[:batch_size]= 0.9
代码语言:javascript
复制
代码语言:javascript
复制
            # Train discriminator
代码语言:javascript
复制
            discriminator.trainable= True
代码语言:javascript
复制
            discriminator.train_on_batch(X, y_dis)
代码语言:javascript
复制
代码语言:javascript
复制
            # Train generator
代码语言:javascript
复制
            noise= np.random.normal(0,1, size=[batch_size, random_dim])
代码语言:javascript
复制
            y_gen= np.ones(batch_size)
代码语言:javascript
复制
            discriminator.trainable= False
代码语言:javascript
复制
            gan.train_on_batch(noise, y_gen)
代码语言:javascript
复制
代码语言:javascript
复制
        if e== 1 or e% 20 == 0:
代码语言:javascript
复制
            plot_generated_images(e, generator)
代码语言:javascript
复制
代码语言:javascript
复制
if __name__== '__main__':
代码语言:javascript
复制
    train(400,128)

训练400个周期后,你可以查看生成的图像。查看第一个周期后产生的图像,可以看到它没有任何真实的结构,在40个周期后查看图像,数字开始成形,最后,400个周期后产生的图像显示出清晰的数字,尽管有一些仍然无法辨认。

此代码在CPU上每个周期大约需要2分钟,这是选择此代码的主要原因。你可以尝试训练更多的周期,并通过向生成器和鉴别器添加更多层。但是,当使用更复杂和更深的架构时,如果仅使用CPU,则运行时也会增加。但是,不要让这阻止你尝试!

结论

恭喜,你已经完成了本教程的最后部分,你将以直观的方式学习生成式对抗网络(GAN)的基础知识!另外,你在Keras库的帮助下实现了这个模型。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2018-05-15,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 ATYUN订阅号 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 用比喻方法理解生成式对抗网络GAN
  • 生成式对抗网络的组成部分
  • 使用Keras做一个简单的生成式对抗网络GAN
  • 结论
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档