首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

动手实现会写数字的神经网络—半监督学习和生成式对抗网络介绍

AiTechYun

编辑:yuxiangyu

在1889年,梵高画了这个美丽的艺术品:星月夜。如今,我的GAN模型只使用20%的标签数据,学会了画MNIST数字!它是怎么实现的?让我们动手做做看。

半监督学习

大多数深度学习分类器需要大量的标签样本才能很好地泛化,但获取这些数据是的过程往往很艰难。为了解决这个限制,半监督学习被提出,它是利用少量标记数据和大量未标记数据的分类技术。许多机器学习研究人员发现,将未标记数据与少量标记数据结合使用时,可以显著提高学习准确性。在半监督学习中,GAN(生成式对抗网络)表现出了很大的潜力,其中分类器可以用很少的标签数据取得良好的表现。

GAN的背景

GAN是深度生成模型的一种。它们特别有趣,因为它们没有明确表示数据所在空间的概率分布。而是通过从中抽取样本,提供了一些不直接与这种概率分布不直接相关的方法。

普通GAN架构

GAN的基本原理是在两个“玩家”之间建立一场比赛:

生成器(G):取随机噪声z作为输入并输出图像x。它的参数被调整以让它产生的假图像从判别器中获得高分。

判别器(D):获取图像X作为输入,并输出一个反映了它对于这是否是真实图像的信心得分。它的参数被调整为:当有真实图像馈送时反馈高分,并且发生器馈送假图像时会反馈低分。

现在,让我们来稍微讨论一下GAN最重要的应用之一,半监督学习。

直觉

普通判别器架构只有一个输出神经元用于分类R / F概率(对/错)。我们同时训练两个网络并在训练完成后丢弃判别器,因为它仅用于改进发生器。

对于半监督任务,除了R / F神经元之外,判别器现在将具有10个用于MNIST数字分类的神经元。而且,这次他们的角色会改变,我们可以在训练后丢弃生成器,其唯一目标是生成未标记的数据以提高判别器的性能。

现在判别器成为了11个类的分类器,其中1个神经元(R / F神经元)代表假数据输出,另外10个代表具有类的实际数据。你必须牢记以下几点:

当来自数据集的真的无监督(或者说标签)数据被馈送时,要保证R / F神经元输出标签= 0

当来自发生器的假的无监督数据被馈送时,要保证R / F神经元输出标签= 1

当真实有监督数据被馈送时,要保证R / F输出标签= 0并且相应的标签输出= 1

不同数据来源的组合将有助于判别器的分类更精确。

架构

现在我们动手进行编码。

判别器

下面的架构与DCGAN论文中提出的架构类似。我们使用跨卷积(strided convolutions)来减少特征向量的维度,而不是任何池化层,并且为所有层应用一系列的leaky_relu,dropout和BN来稳定学习。输入层和最后一层中BN被舍弃(为了特征匹配)。最后,我们执行全局平均池化(Global Average Pooling)以取得特征向量空间维度上的平均值。这可以将张量维度压缩为单个值。在扁平化了特征之后,为了多类输出增加一个11个类的稠密层和softmax激活函数。

发生器

发生器架构旨在模仿判别器的空间输出。使用部分跨卷积来增加表示的空间维度。噪声的四维张量的输入z被馈送,它经过转置卷积,relu,BN(输出层除外)和dropout操作。最后,tanh激活将输出图像映射到(-1,1)范围内。

模型损失

我们首先通过将实际标签附加为零来准备整个批次的扩展标签。这样做是为了在标记数据馈送时,R / F神经元的输出为0。未标记数据的判别器损失可以被认为是一个二元sigmoid损失,通过将R / F神经元输出为1声明假图像,而真实图像输出为0。

发生器损失是fake_image损失与特征匹配损失的组合,前者错误的将R / F神经元输出断言为0,后者惩罚训练数据上一组特征的平均值与生成样本中这组特征的平均值之间的平均绝对误差。

训练

训练图像从[batch_size,28,28,1]调整为[batch_size,64,64,1]以适应发生器和判别器架构。计算损失,准确性和生成样本,并观察每个周期的改进。

结论

由于GPU的限制,训练已完成5个周期和20%的 labeled_rate。想要获得更好的结果,建议使用较小的label_rate的训练更多周期。

完整代码:https://github.com/raghav64/SemiSuper_GAN/blob/master/SSGAN.py

训练结果

  • 发表于:
  • 原文链接http://kuaibao.qq.com/s/20180408A0YMUJ00?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券