【CVPR2018】如何增强Attention Model的推理能力

【导读】目前Attention Model已经被用到了机器视觉,自然语言理解,语音识别,机器翻译等等各行各业。各种各样的Attention Model也被各种Task使用。如何增强Attention Model的推理能力、在使用Attention Model的时候提升模型性能,成为了一个关键的问题。在本文中,我们介绍一种在CVPR 2018大会提出的方法,可以通过极为简单的改进有效的提升Attention Model的性能。

论文题目:Stacked Latent Attention for Multimodal Reasoning

什么是Attention Model

首先我们用下图的例子简单的重温Attention Model:

给定Hidden State,Attention Model可以学到对输入(图示中为图像)Tensor最相关的Mask,并使用Mask对输入Tensor进行加权和,并将加权和后得到的Content Vector作为Attention Model的输出。换而言之,Attention Model可以学到给定输入中最重要的部分,从而对输入进行“总结”。

增强Attention Model的性能的方法——Stacked Attention Model

接下来我们介绍一种非常常用的增强Attention Model的性能的方法:Stacked Attention Model。顾名思义,就是简单的拼接(Stack)多个Attention Model,将前一个AttentionModel的输出作为下一个Attention Model的输入。具体实现如下图所示:

在今年刚刚召开的CVPR大会中,研究者对这种常用的增强Attention Model的方法进行了探索,提出了上图中方法的缺陷,并通过极为简单的改进有效地增强了Attention Model的推理性能:

研究者发现,在Attention Model“总结”输入Tensor的同时,造成了信息瓶颈(Information Bottleneck),该信息瓶颈会导致模型性能下降。同时因Attention Model的SoftMax集中在Pathway上而造成了梯度弥散,进而导致在使用多层Attention Model时模型难以优化(Optimize)。

研究者提出,通过简单将多层Attention Model的隐变量(Latent State)连接(Concat)起来(上图绿色虚线),就可以解决信息瓶颈和梯度弥散问题。如上图所示,在没有绿色虚线的情况下,模型仅仅将多层Attention Model叠加起来,此方法不但1)在每两个Attention Model之间造成了信息瓶颈,同时2)因主要Pathway中有多个SoftMax,而造成梯度弥散。

文章提出,仅仅通过增加上图中的绿色虚线,将前一层Attention Model中的隐变量(LatentState) 连接(Concat)到下一个Attention Model中,就可以1)打破信息瓶颈,同时2)通过提供了新的Pathway避开原Pathway中的多个SoftMax,从而缓解梯度弥散,进而3)提升模型性能。

实验表明,当将多层Attention Models的隐变量连接起来,随着简单增加所连接的Attention Model的数量,整体模型性能得到了显著的提升。同时梯度弥散问题得到了明显的缓解:

该文章的更多细节可以参考:

http://openaccess.thecvf.com/content_cvpr_2018/papers/Fan_Stacked_Latent_Attention_CVPR_2018_paper.pdf

-END-

原文发布于微信公众号 - 专知(Quan_Zhuanzhi)

原文发表时间:2018-07-02

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏大数据挖掘DT机器学习

NLP真实项目:利用这个模型能够通过商品评论去预测一个商品的销量

前言 由于是日语项目,用到的分词软件等,在中文任务中需要替换为相应的中文分词软件。例如结巴分词 : https://github.com/fxsjy/jieb...

64811
来自专栏专知

【论文推荐】最新6篇行人重识别相关论文—深度空间特征重构、生成对抗网络、图像生成、系列实战、图像-图像域自适应方法、行人检索

【导读】专知内容组整理了最近六篇行人重识别(Person Re-identification)相关文章,为大家进行介绍,欢迎查看! 1. Deep Spatia...

6926
来自专栏AI科技大本营的专栏

技能 | 三次简化一张图: 一招理解LSTM/GRU门控机制

作者 | 张皓 引言 RNN是深度学习中用于处理时序数据的关键技术, 目前已在自然语言处理, 语音识别, 视频识别等领域取得重要突破, 然而梯度消失现象制约着R...

3468
来自专栏深度学习自然语言处理

梯度下降法理论与实践

理论基础 现在比如有两个参数的损失函数 ? 我们的目的是使之最小也就是得到能够使J函数最小的 ? , ? ,公式表示为: ? 我们画出当 ? ?...

3079
来自专栏专知

【论文推荐】最新六篇行人再识别(ReID)相关论文—和谐注意力网络、时序残差学习、评估和基准、图像生成、三元组、对抗属性-图像

【导读】专知内容组整理了最近六篇行人再识别(Person Re-Identification)相关文章,为大家进行介绍,欢迎查看! 1. Harmonious ...

7095
来自专栏用户2442861的专栏

Scene Text Localization & Recognition Resources

https://github.com/chongyangtao/Awesome-Scene-Text-Recognition

1552
来自专栏杨熹的专栏

5 分钟入门 Google 最强NLP模型:BERT

BERT (Bidirectional Encoder Representations from Transformers)

4943
来自专栏杨熹的专栏

双向 LSTM

本文结构: 为什么用双向 LSTM 什么是双向 LSTM 例子 ---- 为什么用双向 LSTM? 单向的 RNN,是根据前面的信息推出后面的,但有时候只看前面...

6696
来自专栏大学生计算机视觉学习DeepLearning

python实现gabor滤波器提取纹理特征 提取指静脉纹理特征 指静脉切割代码

2515
来自专栏机器学习原理

机器学习篇(2)——最小二乘法概念最小二乘法

前言:主要介绍了从最小二乘法到 概念 顾名思义,线性模型就是可以用线性组合进行预测的函数,如图: ? image.png 公式如下: ? i...

6025

扫码关注云+社区

领取腾讯云代金券