tensorflow 1.01中GAN(生成对抗网络)手写字体生成例子(MINST)的测试

为了更好地掌握GAN的例子,从网上找了段代码进行跑了下,测试了效果。具体过程如下:

代码文件如下:

import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data import numpy as np from skimage.io import imsave import os import shutil img_height = 28 img_width = 28 img_size = img_height * img_width to_train = True to_restore = False output_path = "output" # 总迭代次数500 max_epoch = 500 h1_size = 150 h2_size = 300 z_size = 100 batch_size = 256 # generate (model 1) def build_generator(z_prior):     w1 = tf.Variable(tf.truncated_normal([z_size, h1_size], stddev=0.1), name="g_w1", dtype=tf.float32)     b1 = tf.Variable(tf.zeros([h1_size]), name="g_b1", dtype=tf.float32)     h1 = tf.nn.relu(tf.matmul(z_prior, w1) + b1)     w2 = tf.Variable(tf.truncated_normal([h1_size, h2_size], stddev=0.1), name="g_w2", dtype=tf.float32)     b2 = tf.Variable(tf.zeros([h2_size]), name="g_b2", dtype=tf.float32)     h2 = tf.nn.relu(tf.matmul(h1, w2) + b2)     w3 = tf.Variable(tf.truncated_normal([h2_size, img_size], stddev=0.1), name="g_w3", dtype=tf.float32)     b3 = tf.Variable(tf.zeros([img_size]), name="g_b3", dtype=tf.float32)     h3 = tf.matmul(h2, w3) + b3     x_generate = tf.nn.tanh(h3)     g_params = [w1, b1, w2, b2, w3, b3]     return x_generate, g_params # discriminator (model 2) def build_discriminator(x_data, x_generated, keep_prob):     # tf.concat     x_in = tf.concat([x_data, x_generated],0)     w1 = tf.Variable(tf.truncated_normal([img_size, h2_size], stddev=0.1), name="d_w1", dtype=tf.float32)     b1 = tf.Variable(tf.zeros([h2_size]), name="d_b1", dtype=tf.float32)     h1 = tf.nn.dropout(tf.nn.relu(tf.matmul(x_in, w1) + b1), keep_prob)     w2 = tf.Variable(tf.truncated_normal([h2_size, h1_size], stddev=0.1), name="d_w2", dtype=tf.float32)     b2 = tf.Variable(tf.zeros([h1_size]), name="d_b2", dtype=tf.float32)     h2 = tf.nn.dropout(tf.nn.relu(tf.matmul(h1, w2) + b2), keep_prob)     w3 = tf.Variable(tf.truncated_normal([h1_size, 1], stddev=0.1), name="d_w3", dtype=tf.float32)     b3 = tf.Variable(tf.zeros([1]), name="d_b3", dtype=tf.float32)     h3 = tf.matmul(h2, w3) + b3     y_data = tf.nn.sigmoid(tf.slice(h3, [0, 0], [batch_size, -1], name=None))     y_generated = tf.nn.sigmoid(tf.slice(h3, [batch_size, 0], [-1, -1], name=None))     d_params = [w1, b1, w2, b2, w3, b3]     return y_data, y_generated, d_params #  def show_result(batch_res, fname, grid_size=(8, 8), grid_pad=5):     batch_res = 0.5 * batch_res.reshape((batch_res.shape[0], img_height, img_width)) + 0.5     img_h, img_w = batch_res.shape[1], batch_res.shape[2]     grid_h = img_h * grid_size[0] + grid_pad * (grid_size[0] - 1)     grid_w = img_w * grid_size[1] + grid_pad * (grid_size[1] - 1)     img_grid = np.zeros((grid_h, grid_w), dtype=np.uint8)     for i, res in enumerate(batch_res):         if i >= grid_size[0] * grid_size[1]:             break         img = (res) * 255         img = img.astype(np.uint8)         row = (i // grid_size[0]) * (img_h + grid_pad)         col = (i % grid_size[1]) * (img_w + grid_pad)         img_grid[row:row + img_h, col:col + img_w] = img     imsave(fname, img_grid) def train():     # load data(mnist手写数据集)     mnist = input_data.read_data_sets('MNIST_data', one_hot=True)     x_data = tf.placeholder(tf.float32, [batch_size, img_size], name="x_data")     z_prior = tf.placeholder(tf.float32, [batch_size, z_size], name="z_prior")     keep_prob = tf.placeholder(tf.float32, name="keep_prob")     global_step = tf.Variable(0, name="global_step", trainable=False)     # 创建生成模型     x_generated, g_params = build_generator(z_prior)     # 创建判别模型     y_data, y_generated, d_params = build_discriminator(x_data, x_generated, keep_prob)     # 损失函数的设置     d_loss = - (tf.log(y_data) + tf.log(1 - y_generated))     g_loss = - tf.log(y_generated)     optimizer = tf.train.AdamOptimizer(0.0001)     # 两个模型的优化函数     d_trainer = optimizer.minimize(d_loss, var_list=d_params)     g_trainer = optimizer.minimize(g_loss, var_list=g_params)     init = tf.initialize_all_variables()     saver = tf.train.Saver()     # 启动默认图     sess = tf.Session()     # 初始化     sess.run(init)     if to_restore:         chkpt_fname = tf.train.latest_checkpoint(output_path)         saver.restore(sess, chkpt_fname)     else:         if os.path.exists(output_path):             shutil.rmtree(output_path)         os.mkdir(output_path)     z_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)     steps = 60000 / batch_size     for i in range(sess.run(global_step), max_epoch):         for j in np.arange(steps): #         for j in range(steps):             print("epoch:%s, iter:%s" % (i, j))             # 每一步迭代,我们都会加载256个训练样本,然后执行一次train_step             x_value, _ = mnist.train.next_batch(batch_size)             x_value = 2 * x_value.astype(np.float32) - 1             z_value = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)             # 执行生成             sess.run(d_trainer,                      feed_dict={x_data: x_value, z_prior: z_value, keep_prob: np.sum(0.7).astype(np.float32)})             # 执行判别             if j % 1 == 0:                 sess.run(g_trainer,                          feed_dict={x_data: x_value, z_prior: z_value, keep_prob: np.sum(0.7).astype(np.float32)})         x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_sample_val})         show_result(x_gen_val, "output/sample{0}.jpg".format(i))         z_random_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)         x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_random_sample_val})         show_result(x_gen_val, "output/random_sample{0}.jpg".format(i))         sess.run(tf.assign(global_step, i + 1))         saver.save(sess, os.path.join(output_path, "model"), global_step=global_step) def test():     z_prior = tf.placeholder(tf.float32, [batch_size, z_size], name="z_prior")     x_generated, _ = build_generator(z_prior)     chkpt_fname = tf.train.latest_checkpoint(output_path)     init = tf.initialize_all_variables()     sess = tf.Session()     saver = tf.train.Saver()     sess.run(init)     saver.restore(sess, chkpt_fname)     z_test_value = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)     x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_test_value})     show_result(x_gen_val, "output/test_result.jpg") if __name__ == '__main__':     if to_train:         train()     else:         test()

按照500次迭代,每次迭代产生一张手写体图片,然后进行判别反馈,这样持续下去,可以看到不同迭代次数的效果。

(第1张)

(第2张)

第10张

第24张

第50张

第140张

第256张

第500张

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏Petrichor的专栏

深度学习: 经典 数据集 汇总

官网:www.cs.toronto.edu/~kriz/cifar 介绍:CIFAR-10数据集说明、TensorFlow CNN 测试CIFAR-10数据...

4313
来自专栏企鹅号快讯

用keras搭建3D卷积神经网络

资源: 3D卷积神经网络相关博客:http://blog.csdn.net/lengxiaomo123/article/details/68926778 ker...

7377
来自专栏数值分析与有限元编程

坐标映射

建立等参单元,需要另外一个自然坐标系下的参考单元。对于物理坐标系下的任意一点,在自然坐标系下的参考单元中,有唯一的一个点与之对应;反过来对于自然坐标系下参考单元...

3044
来自专栏YoungGy

MMD_6b_DecisionTree

overview ? construct 构建决策树的时候需要考虑以下问题: 什么时候停止 如果不停止,那么以什么变量的什么特征构建二叉树 如果停止,那么预测的...

2047
来自专栏目标检测和深度学习

手把手教你搭建目标检测器-附代码

1272
来自专栏Pytorch实践

Tensorflow实现部分参数梯度更新

在深度学习中,迁移学习经常被使用,在大数据集上预训练的模型迁移到特定的任务,往往需要保持模型参数不变,而微调与任务相关的模型层。本文主要介绍,使用tensorf...

5112
来自专栏chenjx85的技术专栏

超参数搜索——网格搜索和随机搜索

我们在搜索超参数的时候,如果超参数个数较少(三四个或者更少),那么我们可以采用网格搜素,一种穷尽式的搜索方法。

1313
来自专栏GAN&CV

手把手教你创建自己的object detector

本文译自:http://www.hackevolve.com/create-your-own-object-detector/

1771
来自专栏小鹏的专栏

人脸检测——fcn

在上一篇的基础上修改即可:人脸检测——滑动窗口篇(训练和实现) !!!注意:这些是我的调试版本,最优版本不方便公开,但是自己可以查看论文,自行在此基础上修改,...

2128
来自专栏CreateAMind

DeepMind时间序列生成模型

使用vae推理;lstm 时间相关建模;external memory加强记忆;和神经图灵机 可微分计算机等进行比较。

1302

扫码关注云+社区

领取腾讯云代金券