首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

听说GAN很高大上,其实就这么简单

本文使用的tensorflow版本:1.4

tensorflow安装:pip install tensorflow

1、先来目睹一下效果吧

这篇文章讲解了如何使用GAN来生成我们的手写数字,我们首先来看看生成的效果吧:

10轮:

50轮:

100轮:

200轮:

可以看到,在10轮的时候,我们的Generator生成的图片非常模糊,几乎是无法用肉眼来分别数字的,到了第50轮的时候,已经初见雏形了,有一些数字比如1、4、5这些都可以很清楚的分辨出来,不过还并不是十分完美。到了100轮的时候,像数字8和9这些也基本能准确的生成了,而到200轮的时候,除去个别的以外,基本上都能正确的手写出来了,由于时间的原因,没有继续训练下去,如果大家感兴趣,可以训练更多轮,看看更好的效果。

2、思路解析

设定参数

本文设定的参数是,图片的大小是28*28,这是mnist图片的标准大小,后面是一些保存模型的设定。我们总共的训练轮数是500轮。在我们的Generator和Discriminator中,我们设定的是一个简单的有两层隐藏层的全链接神经网络。对于Generator来说,输入的的大小是[batch_size,z_size],第一个隐藏层有150个神经元,第二个隐藏层有300个神经元,输出的大小就是图片的size28*28。而对于Discriminator来说,我们的输入的大小是[batch_size * 2,img_size],因为我们要掺杂真实的img和Generator生成的img。第一个隐藏层有300个神经元,第二个隐藏层有150个神经元,输出层只有1个数,表示该图片为真实图片的概率。关于神经网络的结构我们会在后面详细讲解。

创建Generator

刚才也讲到了,对于Generator来说,输入的的大小是[batch_size,z_size],第一个隐藏层有150个神经元,第二个隐藏层有300个神经元,输出的大小就是图片的size28*28。总的来说,经过Generator,由[batch_size,z_size] 变为 [batch_size,img_size]

创建Discrminator

而对于Discriminator来说,我们的输入的大小是[batch_size * 2,img_size],因为我们要掺杂真实的img和Generator生成的img。第一个隐藏层有300个神经元,第二个隐藏层有150个神经元,输出层只有1个数,表示该图片为真实图片的概率。要注意,我们的输入和输出是严格对应的,所以对于输出成的输出h3来说,前batch_size个代表着对真实图片的判别概率,而后batch_size代表着对Generator生成的图片的判别概率。这里是用了一个tf.slice()函数,之前没有接触过,故在这里做一下记录:

1,函数原型 tf.slice(inputs,begin,size,name='')

2,用途:从inputs中抽取部分内容

inputs:可以是list,array,tensor

begin:n维列表,begin[i] 表示从inputs中第i维抽取数据时,相对0的起始偏移量,也就是从第i维的begin[i]开始抽取数据

size:n维列表,size[i]表示要抽取的第i维元素的数目

所以可以看到,最终y_data保存的是真实图片的判别概率,这些值要越接近于1越好,而y_generated保存的是Generator生成的图片的判别概率,这些值要越接近于0越好。

保存图片

关于保存图片的代码,我们这里就不讲了,这也不是重点,大家有兴趣的话可以研究下:

设定训练目标

对于Discriminator来说,他希望能够使二分类的结果越准确越好,即越能准确判别真实图片和Generator生成的图片越好,所以我们这里使用类似于逻辑回归中的损失函数,而对于Generator来说,它希望的是Discriminator无法分辨它生成的图片,所以它希望Discriminator能将它生成的图片越多的分类为真实图片,所以我们设定的训练目标如下:

训练

训练其实很简单,每次得到batch_size大小的样本,先训练一次Discriminator,然后再训练我们的Generator

3、完整代码

本文涉及到的完整代码如下:

  • 发表于:
  • 原文链接http://kuaibao.qq.com/s/20180124G0TE2Y00?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券