首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

5分钟入门GANS:原理解释和keras代码实现

本篇文章包含以下内容

介绍

历史

直观解释

训练过程

GAN在MNIST数据集上的KERAS实现

介绍

生成式对抗网络通常也称为GANs,用于生成图像而不需要很少或没有输入。GANs允许我们生成由神经网络生成的图像。在我们深入讨论这个理论之前,我想向您展示GANs构建您兴奋感的能力。把马变成斑马(反之亦然)。

历史

生成式对抗网络(GANs)是由Ian Goodfellow (GANs的GAN Father)等人于2014年在其题为“生成式对抗网络”的论文中提出的。它是一种可替代的自适应变分编码器(VAEs)学习图像的潜在空间,以生成合成图像。它的目的是创造逼真的人工图像,几乎无法与真实的图像区分。

GAN的直观解释

生成器和鉴别器网络:

生成器网络的目的是将随机图像初始化并解码成一个合成图像。

鉴别器网络的目的是获取这个输入,并预测这个图像是来自真实的数据集还是合成的。

正如我们刚才看到的,这实际上就是GANs,两个相互竞争的对抗网络。

GAN的训练过程

GANS的训练是出了名的困难。在CNN中,我们使用梯度下降来改变权重以减少损失。

然而,在GANs中,每一次重量的变化都会改变整个动态系统的平衡。

在GAN的网络中,我们不是在寻求将损失最小化,而是在我们对立的两个网络之间找到一种平衡。

我们将过程总结如下

输入随机生成的噪声图像到我们的生成器网络中生成样本图像。

我们从真实数据中提取一些样本图像,并将其与一些生成的图像混合在一起。

将这些混合图像输入到我们的鉴别器中,鉴别器将对这个混合集进行训练并相应地更新它的权重。

然后我们制作更多的假图像,并将它们输入到鉴别器中,但是我们将它们标记为真实的。这样做是为了训练生成器。我们在这个阶段冻结了鉴别器的权值(鉴别器学习停止),并且我们使用来自鉴别器的反馈来更新生成器的权值。这就是我们如何教我们的生成器(制作更好的合成图像)和鉴别器更好地识别赝品的方法。

流程图如下

对于本文,我们将使用MNIST数据集生成手写数字。GAN的架构是:

使用KERAS实现GANS

首先,我们加载所有必要的库

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import numpy as np

from tqdm import tqdm

import matplotlib.pyplot as plt

from keras.layers import Input

from keras.models import Model, Sequential

from keras.layers.core import Reshape, Dense, Dropout, Flatten

from keras.layers.advanced_activations import LeakyReLU

from keras.layers.convolutional import Convolution2D, UpSampling2D

from keras.layers.normalization import BatchNormalization

from keras.datasets import mnist

from keras.optimizers import Adam

from keras import backend as K

from keras import initializers

K.set_image_dim_ordering('th')

# Deterministic output.

# Tired of seeing the same results every time? Remove the line below.

np.random.seed(1000)

# The results are a little better when the dimensionality of the random vector is only 10.

# The dimensionality has been left at 100 for consistency with other GAN implementations.

randomDim = 100

现在我们加载数据集。这里使用MNIST数据集,所以不需要单独下载和处理。

(X_train, y_train), (X_test, y_test) = mnist.load_data()

X_train = (X_train.astype(np.float32) - 127.5)/127.5

X_train = X_train.reshape(60000, 784)

接下来,我们定义生成器和鉴别器的结构

# Optimizer

adam = Adam(lr=0.0002, beta_1=0.5)#generator

generator = Sequential()

generator.add(Dense(256, input_dim=randomDim, kernel_initializer=initializers.RandomNormal(stddev=0.02)))

generator.add(LeakyReLU(0.2))

generator.add(Dense(512))

generator.add(LeakyReLU(0.2))

generator.add(Dense(1024))

generator.add(LeakyReLU(0.2))

generator.add(Dense(784, activation='tanh'))

generator.compile(loss='binary_crossentropy', optimizer=adam)#discriminator

discriminator = Sequential()

discriminator.add(Dense(1024, input_dim=784, kernel_initializer=initializers.RandomNormal(stddev=0.02)))

discriminator.add(LeakyReLU(0.2))

discriminator.add(Dropout(0.3))

discriminator.add(Dense(512))

discriminator.add(LeakyReLU(0.2))

discriminator.add(Dropout(0.3))

discriminator.add(Dense(256))

discriminator.add(LeakyReLU(0.2))

discriminator.add(Dropout(0.3))

discriminator.add(Dense(1, activation='sigmoid'))

discriminator.compile(loss='binary_crossentropy', optimizer=adam)

现在我们把发生器和鉴别器结合起来同时训练。

# Combined network

discriminator.trainable = False

ganInput = Input(shape=(randomDim,))

x = generator(ganInput)

ganOutput = discriminator(x)

gan = Model(inputs=ganInput, outputs=ganOutput)

gan.compile(loss='binary_crossentropy', optimizer=adam)

dLosses = []

gLosses = []

三个函数,每20个epoch绘制并保存结果,并保存模型。

# Plot the loss from each batch

def plotLoss(epoch):

plt.figure(figsize=(10, 8))

plt.plot(dLosses, label='Discriminitive loss')

plt.plot(gLosses, label='Generative loss')

plt.xlabel('Epoch')

plt.ylabel('Loss')

plt.legend()

plt.savefig('images/gan_loss_epoch_%d.png' % epoch)

# Create a wall of generated MNIST images

def plotGeneratedImages(epoch, examples=100, dim=(10, 10), figsize=(10, 10)):

noise = np.random.normal(0, 1, size=[examples, randomDim])

generatedImages = generator.predict(noise)

generatedImages = generatedImages.reshape(examples, 28, 28)

plt.figure(figsize=figsize)

for i in range(generatedImages.shape[0]):

    plt.subplot(dim[0], dim[1], i+1)

    plt.imshow(generatedImages[i], interpolation='nearest', cmap='gray_r')

    plt.axis('off')

plt.tight_layout()

plt.savefig('images/gan_generated_image_epoch_%d.png' % epoch)

# Save the generator and discriminator networks (and weights) for later use

def saveModels(epoch):

generator.save('models/gan_generator_epoch_%d.h5' % epoch)

discriminator.save('models/gan_discriminator_epoch_%d.h5' % epoch)

训练函数

def train(epochs=1, batchSize=128):

batchCount = X_train.shape[0] / batchSize

print 'Epochs:', epochs

print 'Batch size:', batchSize

print 'Batches per epoch:', batchCount

for e in xrange(1, epochs+1):

    print '-'*15, 'Epoch %d' % e, '-'*15

    for _ in tqdm(xrange(batchCount)):

        # Get a random set of input noise and images

        noise = np.random.normal(0, 1, size=[batchSize, randomDim])

        imageBatch = X_train[np.random.randint(0, X_train.shape[0], size=batchSize)]

        # Generate fake MNIST images

        generatedImages = generator.predict(noise)

        # print np.shape(imageBatch), np.shape(generatedImages)

        X = np.concatenate([imageBatch, generatedImages])

        # Labels for generated and real data

        yDis = np.zeros(2*batchSize)

        # One-sided label smoothing

        yDis[:batchSize] = 0.9

        # Train discriminator

        discriminator.trainable = True

        dloss = discriminator.train_on_batch(X, yDis)

        # Train generator

        noise = np.random.normal(0, 1, size=[batchSize, randomDim])

        yGen = np.ones(batchSize)

        discriminator.trainable = False

        gloss = gan.train_on_batch(noise, yGen)

    # Store loss of most recent batch from this epoch

    dLosses.append(dloss)

    gLosses.append(gloss)

    if e == 1 or e % 20 == 0:

        plotGeneratedImages(e)

        saveModels(e)

# Plot losses from every epoch

plotLoss(e)

至此一个简单的GAN已经完成了,完整的代码在这里找到

https://github.com/bhaveshgoyal27/mediumblogs/blob/master/Keras_MNIST_GAN.py

作者:Bhavesh Goyal

deephub翻译组

DeepHub

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20200823A02YNN00?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券