前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >用于文本生成的GAN模型

用于文本生成的GAN模型

作者头像
智能生信
发布2021-10-08 16:20:56
3.9K0
发布2021-10-08 16:20:56
举报
文章被收录于专栏:智能生信

生成对抗网络(GAN)包含两个部分:一个是生成器(generator),一个是判别模型(discriminator)。生成器的任务是生成看起来逼真与原始数据相似的样本。判别器的任务是判断生成模型生成的样本是真实的还是伪造的。换句话说,生成器要生成能骗过判别器的实例,而判别器要从真假混合的样本中揪出由生成器生成的伪造样本。生成器和判别器的训练过程是一个对抗博弈的过程,最后博弈的结果是在最理想的状态下,生成器可以生成足以“以假乱真”的样本。

一、什么是GAN

生成对抗网络(GAN)包含两个部分:一个是生成器(generator),一个是判别模型(discriminator)。生成器的任务是生成看起来逼真与原始数据相似的样本。判别器的任务是判断生成模型生成的样本是真实的还是伪造的。换句话说,生成器要生成能骗过判别器的实例,而判别器要从真假混合的样本中揪出由生成器生成的伪造样本。生成器和判别器的训练过程是一个对抗博弈的过程,最后博弈的结果是在最理想的状态下,生成器可以生成足以“以假乱真”的样本。

图1. GAN的基本结构

二、GAN在文本生成中遇到的困境

传统的GAN只适用于连续型数据的生成,对于离散型数据效果不佳。文本数据不同于图像数据,文本数据是典型的离散型数据。图像数据在计算机中被表示为矩阵,矩阵中的数值可微分并且直接反映出图像本身的属性,从图像矩阵到图像不需要采样;而文本数据在计算机中表示为one-hot编码的向量,这个向量中有n项是0,只有一项是1,这一项代表词库中某个词,我们在神经网络中操作时,最后得到的都是一个某个词向量每个维度的概率分布而非标准的one-hot编码的向量,只能将这个输出结果过渡到one-hot向量再从词库中查找对应的词,这个操作被称为采样。

神经网络的优化方法大多是基于梯度的,GAN在面对离散型数据时,判别器无法把梯度反向传播给生成器——判别器得到的是生成器采样后的结果,在判别器参数微调后,可能输出优化了一点点但还不足以改变采样的结果,例如生成器网络的最后结果为[0.1,0.1,0.8]经过采样输出的one-hot词向量为[0,0,1],而参数微调后生成器网络最后的结果变为了[0.1,0.2,0.7]经过采样输出的one-hot词向量依旧为[0,0,1],生成器便会再一次将相同答案输入给判别器,这样判别器给出的评价就会毫无意义,生成器的训练也会失去方向。

为了解决GAN在面对离散型数据无法将梯度反向传播给生成器的问题,人们提出了三种方案:1.判别器直接获取生成器的输出;2.使用Gumbel-softmax代替softmax;3.通过强化学习来绕过采样带来的问题。其中第一种方法虽然可以绕过采样操作,避免采样带来的梯度无法反传的问题,但生成的数据与真实数据差距太大,判别器可以很轻易地分辨出生成的数据与真实的数据,因为此时生成的数据是离散的向量,而真实数据是one-hot向量,判别器可以很容易分辨两者的差异,此时GAN是难以训练的。

三、几种用于生成文本的GAN模型

3.1 Seq-GAN

SeqGAN的核心思想是将GAN与强化学习的Policy Gradient算法结合到一起,出发点是意识到了标准的GAN在处理离散数据时会遇到的困难:生成器难以梯度更新,判别器难以评估非完整序列。对于生成器难以梯度更新问题,作者把整个GAN看作一个强化学习系统,用Policy Gradient算法更新Generator的参数;对于判别器难以评估非完整序列问题,作者借鉴了蒙特卡洛树搜索的思想,对任意时刻的非完整序列都可以进行评估。

图2. SeqGAN结构

SeqGAN结构如图2所示,已经存在的红色圆点称为现在的状态(state),要生成的下一个红色圆点称作动作(action),因为D需要对一个完整的序列评分,所以就是用MCTS(蒙特卡洛树搜索)将每一个动作的各种可能性补全,D对这些完整的序列产生reward,回传给G,通过增强学习更新G。这样就是用Reinforcement learning的方式,训练出一个可以产生下一个最优的action的生成网络。

3.2 LeakGAN

基于GAN生成文本的方法大多数场景是生成短文本,对于长文本来说还是存在很多挑战。先前的GAN中判别器的标量指导信号是稀疏的,只有在完整生成文本后才可用,缺少生成过程中的文本结构的中间信息。当生成的文本样本长度很长时效果不好。LeakGAN通过泄露判别器提取的特征作为引导信号,指导生成器更好地生成长文本。同时,借助分层强化学习从判别器向生成器提供更丰富的信息。

图3. LeakGAN结构

3.3 RelGAN

RelGAN由三个主要组件组成:基于关系记忆的生成器、Gumbel-Softmax用于离散数据上训练GAN、鉴别器中嵌入多个表示为生成器提供更多信息。在样品质量和多样性方面,RelGAN相比于其他GAN模型具有一定优势。并且,RelGAN可以通过单个可调参数控制样本质量和多样性之间的权衡。

图4. RelGAN生成器中的注意力机制

t时刻的记忆单元Mt和矩阵Wq相乘得到Q矩阵,Mt与t时刻的输入的词向量xt拼接后分别于WK、Wv相乘得到K矩阵和V矩阵,Q矩阵和K矩阵的转置相乘后经过sofmax函数得到注意力权重,再将注意力全中与V矩阵相乘得到更新后的记忆单元。

图5. RelGAN判别器

判别器结构如图5所示,为了从多方面捕获输入特征,词向量通过多个词向量表示层输入CNN网络,这样子就输出多个判别器损失,综合多个方面的判别器损失,得到最终的损失输出,这样子,可以从多个方面综合评估词向量的差异,提供多样性和更加丰富的信息指导判别器的训练。


参考文献

https://arxiv.org/abs/1609.05473

https://arxiv.org/abs/1709.08624v2

https://arxiv.org/abs/1908.07269


本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2021-09-12,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 智能生信 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档