首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

GAN生成对抗网络代码解析

上一次我们写了生成对抗网络的工作原理,俗话说,学而不编则罔,编而不学则殆,跑起来才能加深对算法的理解。

在跑之前,首先需要先装上python(废话),tensorflow (https://www.tensorflow.org/install/),装完之后试一下

看能不能出现“Hello, TensorFlow!”成功了就可以正式进入GAN了!

首先第一步是导入数据包,说明此前你应该已经预先安装好了numpy和matplotlib,此外,我们训练所用的数据库MNIST (http://yann.lecun.com/exdb/mnist/)也能通过调用一个tensorflow的函数read_data_sets直接导入(不用手动下载)。

这一步如果成功的话,运行会出现:

就是MNIST被提取出来了,这4个文件依次是,训练图、训练标签、测试图、测试标签。

如果你想看看MNIST里面的数据图长啥样,可以用next_batch来调用一个batch用PyPlot来看看:

由于随机调用,所以每次运行会出现的数字都不同:

接下来我们就可以开始构建判别器Discriminator和生成器Generator了。

在这里有一个小tip要注意就是最开头那段if语句,因为小姐姐没加它的时候出现报错,意思是参数不能重复使用,可是怎么可能,这些权重参数下面都要反复用来计算,所以需要把你要重复使用的参数都包括进去才行继续往下跑(生成器同此)。

生成器就可以看作是反卷积的过程,判别器输入2维或者3维的像素矩阵,输出一个概率,生成器则是反过来——将一个多维的噪音向量输出为一张28*28的像素图(但其实是28*28*1,因为数字只有灰度1一个维度,但是tensorflow通常能处理3通道的RGB像素图),可以看到最后一层加了tf.sigmoid激活函数,它的作用是将灰度转化成黑或者白来输出图像。

生成器构造完了,我们来看看未经训练之前的噪音图像什么样子的:

看一下输出,噪音真的就是噪音本尊:

好,终于要开始训练了,论文告诉我们说,判别器和生成器各自有自己的损失函数,我们同时训练它们,使得生成器能生成更像正样本的图片,并使得判别器能更准确地判断出正负样本图片。

然后构造优化器,这里使用的是Adam梯度下降,同时让判别器对真实图和噪音图分别训练,以方便分别调整步长:

最后,GAN初版不好训练(常常崩),可以用tensorboard同时监控损失函数和训练图的变化,甚至还能画出神经网络的拓扑结构:

使用方法就是先在终端激活tensorflow

然后运行命令

再然后打开网页

http://localhost:6000

就能看到了。

好,万事俱备,只欠训练。友情提示,训练要很久(GPU至少3小时,CPU。。。30个小时吧),而且容易崩(GAN本身的结构决定的),祝福你成功。

开始训练的界面是这样子的:

Tensorboard里面也能看到损失在逐步下降:

最后的最后,其实可以下载训练好的权重来测试 (https://github.com/jonbruner/generative-adversarial-networks/blob/master/pretrained-model/pretrained_gan.ckpt),把它下载放在你的本地文件夹里,然后运行:

就能看到美丽的小数字们啦,啦啦啦啦。

  • 发表于:
  • 原文链接http://kuaibao.qq.com/s/20180202G01HBS00?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券