Generative Adversarial Networks

Generative Adversarial Networks

上一篇讲述了VAEs(变分自编码器),那么这次继续学习一下另一个生成模型——GANs。这里建议如果没有看VAEs的请点击传送门:,因为有所关联,所以如果直接看这篇的话,开头会有点奇怪。

从VAEs继续。

如果我们并不想明确模型的密度分布,我们只把注意点放在抽样生成一个新的实例呢?

但问题是,我们无法直接从复杂的分布中抽样。解决办法就是,我们从一个简单的噪声分布中抽样,然后基于这个噪声去训练一个分布。

那如何去使得生成的实例跟我们输入的实例尽可能相近呢?

加入一个鉴别器!鉴别器的功能就是来区分输入判别网络的实例是真实的样本,还是由生成网络生成的样本。因此,很显然两个网络之间形成了一种很特殊的彼此竞争的关系,生成网络尽可能想生成“骗过”鉴别器的实例,而鉴别器又尽可能想“不被骗”,识别出哪些实例是被生成网络生成的(假的),哪些实例是真实的。

来看个例子

生成器:尽力提高数据的虚假性。

鉴别器:尽力辨别真假。

首先,最初的生成器是由噪声抽样形成的,因此他的点的分布是随机的;而判别器有一个很明确的判别标准,真实点的分布被鉴别为1,生成点被判别为0。

下一步,生成器发现并没有很好地欺骗判别器,因此在下一次迭代的时候,生成器生成的点逐步向判别器判别为真实的点靠近。

下一步,轮到了判别器,由于生成器所生成的点逐步接近真实点,因此判别器可能有点难度去区分。

再下一步,生成器生成的点更接近真实点,甚至已经有几个生成点与真实点几乎重合,这对接下来判别器的判别会带来极大的难度。

如此迭代下去。。。

通过这个过程,我们可以发现,训练这样的两个模型的大方法就是:单独交替迭代训练。什么意思?因为是2个网络,不好一起训练,所以才去交替迭代训练。

如何训练一个GANs

训练GANs需要连续经过一个最大最小优化问题,其实也就是对应上面描述的迭代优化的过程。

鉴别器希望最大化目标,且D(x)趋向于1,D(G(x))趋向于0;而生成器希望最小化目标,且D(G(x))趋向于1。

优化D:

可以看到,优化D的时候,也就是判别网络,其实没有生成网络什么事,后面的G(z)这里就相当于已经得到的假样本。优化D的公式的第一项,使得真样本x输入的时候,得到的结果越大越好,可以理解,因为需要真样本的预测结果越接近于1越好嘛。对于假样本,需要优化是的其结果越小越好,也就是D(G(z))越小越好,因为它的标签为0。但是呢第一项是越大,第二项是越小,这不矛盾了,所以呢把第二项改成1-D(G(z)),这样就是越大越好,两者合起来就是越大越好。

优化G:

那么同样在优化G的时候,这个时候没有真样本什么事,所以把第一项直接去掉了。这个时候只有假样本,但是我们说这个时候是希望假样本的标签是1的,所以是D(G(z))越大越好,但是呢为了统一成1-D(G(z))的形式,那么只能是最小化1-D(G(z)),本质上没有区别,只是为了形式的统一。之后这两个优化模型可以合并起来写,就变成了最开始的那个最大最小目标函数了。

GANs强大的功能

这张图表明的是GAN的生成网络如何一步步从均匀分布学习到高斯分布的。原始数据x服从正太分布,这个过程你也没告诉生成网络说你得用高斯分布来学习,但是生成网络学习到了。假设你改一下x的分布,不管什么分布,生成网络可能也能学到。这就是GAN可以自动学习真实数据的分布的强大之处。

再来看一个,下面是一个改变人脸的过程,最上面一行是输入的图像,通过GANs,我们可以看到对应每一列生成的图像。

代码

from __future__ import absolute_import, division, print_function, unicode_literals
from tensorflow.keras import layers
from IPython import display
import tensorflow as tf
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
import time

#数据数量
BUFFER_SIZE = 2656
#批大小
BATCH_SIZE = 16
#迭代次数
EPOCHS = 800
#G的测试噪音大小
noise_dim = 100
#G的测试噪音集大小
num_example_to_G = 4
#dropout_rate
DROPOUT = 0.3
#G的测试集
seed = tf.random.normal([num_example_to_G, noise_dim])



#导入数据集
data = np.load('./paints200.npy')
data = data.astype('float32')
data = (data - 127.5) / 127.5

train_dataset = tf.data.Dataset.from_tensor_slices(data).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)




#生成器与判别器
def make_generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(25*25*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((25, 25, 256)))
    assert model.output_shape == (None, 25, 25, 256) 

    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 50, 50, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 100, 100, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(3, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 200, 200, 3)

    return model

def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=[200, 200, 3]))

    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(DROPOUT))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(DROPOUT))

    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(DROPOUT))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model
generator = tf.keras.models.load_model('epoch7_Gmodel.h5')
discriminator = tf.keras.models.load_model('epoch7_Dmodel.h5')




#损失值
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)




#优化器
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)





#训练
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

def train(dataset, epochs):
    for epoch in range(epochs):
        start = time.time()

        for image_batch in dataset:
            train_step(image_batch)

            display.clear_output(wait=True)
            generate_and_save_images(generator, epoch+1, seed)
            """
            if (epoch + 1) % 100 == 0:
                generator.save('epoch'+epoch+'_Gmodel.h5')
                discriminator.save('epoch'+epoch+'_Dmodel.h5')
                """

            print('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))

        generator.save('epoch'+str(epoch+1)+'_Gmodel.h5')
        discriminator.save('epoch'+str(epoch+1)+'_Dmodel.h5')

        display.clear_output(wait=True)
        generate_and_save_images(generator, epochs, seed)

def generate_and_save_images(model, epoch, test_input):
    predictions = model(test_input, training=False)

    fig = plt.figure(figsize=(2,2))

    for i in range(predictions.shape[0]):
        plt.subplot(2, 2, i+1)
        plt.imshow((predictions[i, :, :, :]*127.5+127.5)/255)
        plt.axis('off')

    if not os.path.exists('picture'):
        os.mkdir('picture')

    plt.savefig('./picture/image_at_epoch_{:04d}.png'.format(epoch))
    #plt.show()



train(train_dataset, EPOCHS)
generator.save('newGmodel.h5')
discriminator.save('newDmodel.h5')

输出结果由gif展现:

本文分享自微信公众号 - 卡尼慕(gh_40138f7dc7d3)

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2019-10-01

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏数据猿

AlphaGo之父戴密斯·哈萨比斯:是天才,也是生活里的普通人

Deepmind创始人戴密斯·哈萨比斯是这个时代公认的天才,他不仅是世界AI领域的第一人,还是世界国际象棋大师、电脑游戏设计师、企业家和神经学家,世界范围内再无...

11020
来自专栏计算机视觉战队

选择性细化网络用于高性能人脸检测

人脸检测是自动人脸识别系统中的一个关键环节。早期的人脸识别研究主要针对具有较强约束条件的人脸图象(如无背景的图象),往往假设人脸位置一直或者容易获得,因此人脸检...

8020
来自专栏AI科技大本营的专栏

值得收藏!基于激光雷达数据的深度学习目标检测方法大合集(上)

【导读】上周,我们在《激光雷达,马斯克看不上,却又无可替代?》一文中对自动驾驶中广泛使用的激光雷达进行了简单的科普,今天,这篇文章将各大公司和机构基于激光雷达的...

24830
来自专栏机器之心

80页笔记看遍机器学习基本概念、算法、模型,帮新手少走弯路

本文要介绍的是一份长约 80 页的学习笔记,旨在总结机器学习的一系列基本概念(如梯度下降、反向传播等),不同的机器学习算法和流行模型,以及一些作者在实践中学到的...

9510
来自专栏算法channel

推荐一个神级工具:能缓解梯度消失问题&提升训练速度

深度神经网络一直以来就有一个特点:随着网络加深,模型会越来越难以训练。所以深度学习有一个非常本质性的问题:为什么随着网络加深,训练会越来越困难?为了解...

9430
来自专栏数据分析1480

机器学习中常用的5种回归损失函数,你都用过吗?

“损失函数”是机器学习优化中至关重要的一部分。L1、L2损失函数相信大多数人都早已不陌生。那你了解Huber损失、Log-Cosh损失、以及常用于计算预测区间的...

10640
来自专栏大数据成神之路

大数据开发体系,进来了解一下?

掌握Linux必备知识,熟悉Python的使用与爬虫程序的编写,搭建Hadoop(CDH)集群,为大数据技术学习打好基础。

14560
来自专栏不仅仅是python

机器学习-理解Logistic Regression

本文讨论了Logistic回归的基础知识及其在Python中的实现。逻辑回归基本上是一种监督分类算法。在分类问题中,目标变量(或输出)y对于给定的一组特征(或输...

8220
来自专栏计算机视觉战队

Xgboost - A scalable tree boosting system Chiang

XGBoost(eXtreme Gradient Boosting)其核心是对决策树(Decision Tree)的增强(Boosting)方法,属于集成学习(...

12030
来自专栏程序生活

GAN对抗网络入门教程

译:A Beginner's Guide to Generative Adversarial Networks (GANs) https://skymind.a...

17130

扫码关注云+社区

领取腾讯云代金券

年度创作总结 领取年终奖励