前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >一看就懂的Tensorflow实战(GAN)

一看就懂的Tensorflow实战(GAN)

作者头像
AI异构
发布2020-07-29 11:25:23
6920
发布2020-07-29 11:25:23
举报
文章被收录于专栏:AI异构AI异构

生成对抗网络简介

生成对抗网络(GAN)启发自博弈论中的二人零和博弈(two-player game),类似于周伯通的绝学——“左右互搏”GAN 模型中的两位博弈方分别由生成式模型(generative model)和判别式模型(discriminative model)充当。生成模型 G 捕捉样本数据的分布,用服从某一分布(均匀分布,高斯分布等)的噪声 z 生成一个类似真实训练数据的样本,追求效果是越像真实样本越好;判别模型 D 是一个二分类器,估计一个样本来自于训练数据(而非生成数据)的概率,如果样本来自于真实的训练数据,D 输出大概率,否则,D 输出小概率。可以做如下类比:生成网络 G 好比假币制造团伙,专门制造假币,判别网络 D 好比警察,专门检测使用的货币是真币还是假币,G 的目标是想方设法生成和真币一样的货币,使得 D 判别不出来,D 的目标是想方设法检测出来 G 生成的假币。随着训练时间的增加,判别模型与生成模型的能力都相应的提升!

具体生成网络的示意图如下所示:

生成对抗网络结构示意图

Tensorflow生成对抗网络实现

代码语言:javascript
复制
from __future__ import division, print_function, absolute_import

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
导入数据集
代码语言:javascript
复制
# 导入mnist数据集
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("./data/", one_hot=True)

代码语言:javascript
复制
Extracting ./data/train-images-idx3-ubyte.gz
Extracting ./data/train-labels-idx1-ubyte.gz
Extracting ./data/t10k-images-idx3-ubyte.gz
Extracting ./data/t10k-labels-idx1-ubyte.gz
参数设置
代码语言:javascript
复制
# Training Params
num_steps = 70000 #总迭代次数
batch_size = 128  # 批量大小
learning_rate = 0.0002 #学习率

# Network Params
image_dim = 784 # 28*28 pixels,生成器的输出层节点数,也是判别器的输入
gen_hidden_dim = 256 # 生成器隐藏层节点数
disc_hidden_dim = 256 # 判别器隐藏层节点数
noise_dim = 100 # Noise data points 生成器输入节点数

# Xavier 初始化方式(更适合有ReLU的网络训练)
def glorot_init(shape):
    return tf.random_normal(shape=shape, stddev=1. / tf.sqrt(shape[0] / 2.))

Xavier 初始化方式方差:

这里的参数是标准差。

设置每一层的权重与偏置
代码语言:javascript
复制
# 设置每一层的权重(Xavier初始化)与偏置(初始化为零)
weights = {
    'gen_hidden1': tf.Variable(glorot_init([noise_dim, gen_hidden_dim])),#(100 - 256)
    'gen_out': tf.Variable(glorot_init([gen_hidden_dim, image_dim])), #(256 - 784)
    'disc_hidden1': tf.Variable(glorot_init([image_dim, disc_hidden_dim])),#(784 - 256)
    'disc_out': tf.Variable(glorot_init([disc_hidden_dim, 1])),#(256 - 1)
}
biases = {
    'gen_hidden1': tf.Variable(tf.zeros([gen_hidden_dim])),
    'gen_out': tf.Variable(tf.zeros([image_dim])),
    'disc_hidden1': tf.Variable(tf.zeros([disc_hidden_dim])),
    'disc_out': tf.Variable(tf.zeros([1])),
}
定义生成对抗网络
代码语言:javascript
复制
# 定义生成器函数
def generator(x):
    # 输入x是1x100的矩阵,weights['gen_hidden1']是100x256的矩阵,矩阵相乘结果是1x256的矩阵,生成器隐藏层含256个节点
    hidden_layer = tf.matmul(x, weights['gen_hidden1'])
    # biases['gen_hidden1']是1x256的矩阵,生成器隐藏层含256个节点
    hidden_layer = tf.add(hidden_layer, biases['gen_hidden1'])
    # 激活函数 relu
    hidden_layer = tf.nn.relu(hidden_layer)
    # hidden_layer是1x256的矩阵,weights['gen_out']是256x784的矩阵,矩阵相乘结果是1x784的矩阵,生成器输出层含784个节点
    out_layer = tf.matmul(hidden_layer, weights['gen_out'])
    # biases['gen_out']是1x784的矩阵,生成器输出层含784个节点
    out_layer = tf.add(out_layer, biases['gen_out'])
    # 激活函数 sigmoid
    out_layer = tf.nn.sigmoid(out_layer)
    return out_layer


# 定义判别器函数
def discriminator(x):
    # 输入x是生成器生成的1x784的矩阵(生成的图片),weights['disc_hidden1']是784x256的矩阵,矩阵相乘结果是1x256的矩阵,判别器隐藏层含256个节点
    hidden_layer = tf.matmul(x, weights['disc_hidden1'])
    # biases['disc_hidden1']是1x256的矩阵,生成器隐藏层含256个节点
    hidden_layer = tf.add(hidden_layer, biases['disc_hidden1'])
    # 激活函数 relu
    hidden_layer = tf.nn.relu(hidden_layer)
    # hidden_layer是1x256的矩阵,weights['disc_out']是256x1的矩阵,矩阵相乘结果是一个数,判别器输出层含1个节点
    out_layer = tf.matmul(hidden_layer, weights['disc_out'])
    # biases['disc_out']是一个数,判别器输出层含1个节点
    out_layer = tf.add(out_layer, biases['disc_out'])
    # 激活函数 sigmoid
    out_layer = tf.nn.sigmoid(out_layer)
    return out_layer

# 构建网络
# 网络输入
gen_input = tf.placeholder(tf.float32, shape=[None, noise_dim], name='input_noise') # 生成器 输入噪点 batch*100,none是一个空值,后面赋值batch_size
disc_input = tf.placeholder(tf.float32, shape=[None, image_dim], name='disc_input') # 判别器 输入真实图像 batch*784

# 构建生成器(generator)
gen_sample = generator(gen_input)

# 构建两个判别器(一个是真实图像输入,一个是生成图像)
disc_real = discriminator(disc_input) # 真实图像
disc_fake = discriminator(gen_sample) # 通过生成器生成的图像

# 创建损失函数
# 关于GAN的理论推导,可参见 <a href="#footnote-1"><sup>[1]</sup></a>
gen_loss = -tf.reduce_mean(tf.log(disc_fake)) # 生成器损失函数
disc_loss = -tf.reduce_mean(tf.log(disc_real) + tf.log(1. - disc_fake)) # 判别器损失函数

# 创建优化器(采用Adam方法),可参见 <a href="#footnote-2"><sup>[2]</sup></a>
optimizer_gen = tf.train.AdamOptimizer(learning_rate=learning_rate)
optimizer_disc = tf.train.AdamOptimizer(learning_rate=learning_rate)

# Training Variables for each optimizer
# By default in TensorFlow, all variables are updated by each optimizer, so we
# need to precise for each one of them the specific variables to update.
# 生成网络的变量
gen_vars = [weights['gen_hidden1'], weights['gen_out'],
            biases['gen_hidden1'], biases['gen_out']]
# 判别网络的变量
disc_vars = [weights['disc_hidden1'], weights['disc_out'],
            biases['disc_hidden1'], biases['disc_out']]

# 创建训练操作
train_gen = optimizer_gen.minimize(gen_loss, var_list=gen_vars)
train_disc = optimizer_disc.minimize(disc_loss, var_list=disc_vars)

# 变量全局初始化
init = tf.global_variables_initializer()

GAN的网络结构类似于多层感知机:

训练生成对抗网络
代码语言:javascript
复制
# 开始训练
# 创建一个会话
sess = tf.Session()

# 初始化
sess.run(init)

# 训练
for i in range(1, num_steps+1):
    # 准备数据
    # 拿到下一批次的 MNIST 数据 (仅需要图像, 不需要标签)
    batch_x, _ = mnist.train.next_batch(batch_size) # 判别器输入 真实图像,batch_*784
    # 给生成器生成噪点数据
    z = np.random.uniform(-1., 1., size=[batch_size, noise_dim]) # 生成器输入 噪声,batch*100

    # 训练
    feed_dict = {disc_input: batch_x, gen_input: z} #给placeholder填入值
    _, _, gl, dl = sess.run([train_gen, train_disc, gen_loss, disc_loss],
                            feed_dict=feed_dict)
    if i % 2000 == 0 or i == 1:
        print('Step %i: Generator Loss: %f, Discriminator Loss: %f' % (i, gl, dl))

代码语言:javascript
复制
Step 1: Generator Loss: 0.223592, Discriminator Loss: 2.090910
Step 2000: Generator Loss: 4.678916, Discriminator Loss: 0.041115
Step 4000: Generator Loss: 3.605874, Discriminator Loss: 0.068698
Step 6000: Generator Loss: 3.845584, Discriminator Loss: 0.190420
Step 8000: Generator Loss: 4.470613, Discriminator Loss: 0.117488
Step 10000: Generator Loss: 3.813103, Discriminator Loss: 0.146255
Step 12000: Generator Loss: 2.991248, Discriminator Loss: 0.392258
Step 14000: Generator Loss: 3.769275, Discriminator Loss: 0.153639
Step 16000: Generator Loss: 4.366917, Discriminator Loss: 0.206618
Step 18000: Generator Loss: 4.052875, Discriminator Loss: 0.225112
Step 20000: Generator Loss: 3.574747, Discriminator Loss: 0.362798
Step 22000: Generator Loss: 3.760236, Discriminator Loss: 0.188211
Step 24000: Generator Loss: 3.055995, Discriminator Loss: 0.354645
Step 26000: Generator Loss: 3.619049, Discriminator Loss: 0.211489
Step 28000: Generator Loss: 3.523777, Discriminator Loss: 0.273607
Step 30000: Generator Loss: 3.889854, Discriminator Loss: 0.286803
Step 32000: Generator Loss: 3.106094, Discriminator Loss: 0.298111
Step 34000: Generator Loss: 3.548391, Discriminator Loss: 0.343262
Step 36000: Generator Loss: 3.081174, Discriminator Loss: 0.332788
Step 38000: Generator Loss: 2.946176, Discriminator Loss: 0.335102
Step 40000: Generator Loss: 3.078653, Discriminator Loss: 0.465524
Step 42000: Generator Loss: 2.601799, Discriminator Loss: 0.409574
Step 44000: Generator Loss: 3.168177, Discriminator Loss: 0.325075
Step 46000: Generator Loss: 2.601811, Discriminator Loss: 0.428143
Step 48000: Generator Loss: 2.853810, Discriminator Loss: 0.403768
Step 50000: Generator Loss: 2.690175, Discriminator Loss: 0.483180
Step 52000: Generator Loss: 3.278867, Discriminator Loss: 0.375016
Step 54000: Generator Loss: 2.869437, Discriminator Loss: 0.477840
Step 56000: Generator Loss: 2.561056, Discriminator Loss: 0.449300
Step 58000: Generator Loss: 2.814199, Discriminator Loss: 0.484522
Step 60000: Generator Loss: 2.469474, Discriminator Loss: 0.428359
Step 62000: Generator Loss: 2.721684, Discriminator Loss: 0.494090
Step 64000: Generator Loss: 2.491284, Discriminator Loss: 0.654795
Step 66000: Generator Loss: 2.725388, Discriminator Loss: 0.423149
Step 68000: Generator Loss: 2.758215, Discriminator Loss: 0.513224
Step 70000: Generator Loss: 3.072056, Discriminator Loss: 0.481437
测试
代码语言:javascript
复制
# 测试
# 通过训练出的生成网络输入噪点,生成图像
n = 6
canvas = np.empty((28 * n, 28 * n))
for i in range(n):
    # 噪点输入
    z = np.random.uniform(-1., 1., size=[n, noise_dim])
    # 生成图像
    g = sess.run(gen_sample, feed_dict={gen_input: z})
    # 颜色反转便于显示
    g = -1 * (g - 1)
    for j in range(n):
        # 绘制生成的手写体数字
        canvas[i * 28:(i + 1) * 28, j * 28:(j + 1) * 28] = g[j].reshape([28, 28])

plt.figure(figsize=(n, n))
plt.imshow(canvas, origin="upper", cmap="gray")
plt.show()

参考

[1] [机器之心GitHub项目:GAN完整理论推导与实现,Perfect!]https://www.jiqizhixin.com/articles/2017-10-1-1

[2] [深度学习最全优化方法总结比较(SGD,Adagrad,Adadelta,Adam,Adamax,Nadam)]https://zhuanlan.zhihu.com/p/22252270

[3] [不要怂,就是GAN (生成式对抗网络) (一): GAN 简介]https://www.cnblogs.com/Charles-Wan/p/6238033.html

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2018-04-24,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI异构 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 生成对抗网络简介
  • Tensorflow生成对抗网络实现
    • 导入数据集
      • 参数设置
        • 设置每一层的权重与偏置
          • 定义生成对抗网络
            • 训练生成对抗网络
              • 测试
              • 参考
              领券
              问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档