元学习可以被定义为一种序列到序列的问题, 在现存的方法中,元学习器的瓶颈是如何去吸收同化利用过去的经验。 注意力机制可以允许在历史中精准摘取某段具体的信息。
Simple Neural Attentive Learner (SNAIL) 组合时序卷积和 soft-attention, 前者从过去的经验整合信息,后者精确查找到某些特殊的信息。
时序卷积 (TCN) 是有因果前后关系的,即在下一时间步生成的值仅仅受之前的时间步影响。 TCN 可以提供更直接,高带宽的传递信息的方法,这允许它们基于一个固定大小的时序内容进行更复杂的计算。 但是,随着序列长度的增加,卷积膨胀的尺度会随之指数增加,需要的层数也会随之对数增加。 因此这种方法对于之前输入的访问更粗略,且他们的有限的能力和位置依赖并不适合元学习器, 因为元学习器应该能够利用增长数量的经验,而不是随着经验的增加,性能会被受限。
soft-attention 可以实现从超长的序列内容中获取准确的特殊信息。 它将上下文作为一种无序的关键值存储,这样就可以基于每个元素的内容进行查询。 但是,位置依赖的缺乏(因为是无序的)也是一个缺点。
TCN 和 soft-attention 可以实现功能互补: 前者提供高带宽的方法,代价是受限于上下文的大小,后者可以基于不确定的可能无限大的上下文提供精准的提取。 因此,SNAIL 的构建使用二者的组合:使用时序卷积去处理用注意力机制提取过的内容。 通过整合 TCN 和 attention,SNAIL 可以基于它过去的经验产出高带宽的处理方法且不再有经验数量的限制。 通过在多个阶段使用注意力机制,端到端训练的 SNAIL 可以学习从收集到的信息中如何摘取自己需要的信息并学习一个恰当的表示。
在元学习中每个任务
都是独立的, 其输入为
,输出为
,损失函数是
, 一个转移分布
,和一个输出长度
。 一个元学习器(由
参数化)建模分布:
给定一个任务的分布
, 元学习器的目标是最小化它的期待损失:
元学习器被训练去针对从
中抽样出来的任务 (或一个 mini-batches 的任务) 优化这个期望损失。 在测试阶段,元学习器在新任务分布
上被评估。
两个时序卷积层(橙色)和一个因果关系层(绿色)的组合是 SNAIL 的基础结构, 如图1所示。 在监督学习设置中, SNAIL 接收标注样本
和末标注的
, 然后基于标注样本对
进行预测。
图1 SNAIL 基础结构示意图。
对于构建 SNAIL 使用了两个主要模块: Dense Block 和 Attention Block。
图1 SNAIL 中的 Dense Block 和 Attention Block。(a) Dense Block 应用因果一维卷积,然后将输出连接到输入。TC Block 应用一系列膨胀率呈指数增长的 Dense Block。(b) Attention Block 执行(因果)键值查找,并将输出连接到输入。
Densen Block 用了一个简单的因果一维卷积(空洞卷积), 其中膨胀率 (dilation)为
和卷积核数量
([1] 对于所有的实验中设置卷积核的大小为2), 最后合并结果和输入。 在计算结果的时候使用了一个门激活函数。 具体算法如下:
, number of filers
):
,
), CausalConv (inputs,
,
)
TC Block 由一系列 dense block 组成,这些 dense block 的膨胀率
呈指数级增长,直到它们的接受域超过所需的序列长度。具体代码实现时,对序列是需要填充的为了保持序列长度不变。具体算法如下:
, number of filers
):
do 1. inputs = DenseBlock (inputs,
,
)
Attention Block [1] 中设计成 soft-attention 机制, 公式为:
, value size
):
), affine (inputs,
)
)
)
表1 SNAIL 在 Omniglot 上的分类结果。
Method | 5-way 1-shot | 5-way 5-shot | 20-way 1-shot | 20-way 5-shot |
---|---|---|---|---|
Santoro et al. (2016) | 82.8 % \% % | 94.9 % \% % | – | – |
Koch (2015) | 97.3 % \% % | 98.4 % \% % | 88.2 % \% % | 97.0 % \% % |
Vinyals et al. (2016) | 98.1 % \% % | 98.9 % \% % | 93.8 % \% % | 98.5 % \% % |
Finn et al. (2017) | 98.7 ± \pm ± 0.4 % \% % | 99.9 ± \pm ± 0.3 % \% % | 95.8 ± \pm ± 0.3 % \% % | 98.9 ± \pm ± 0.2 % \% % |
Snell et al. (2017) | 97.4 % \% % | 99.3 % \% % | 96.0 % \% % | 98.9 % \% % |
Munkhdalai & \& & Yu (2017) | 98.9 % \% % | – | 97.0 % \% % | – |
SNAIL | 99.07 ± \pm ± 0.16 % \% % | 99.78 ± \pm ± 0.09 % \% % | 97.64 ± \pm ± 0.30 % \% % | 99.36 ± \pm ± 0.18 % \% % |
94.9
––Koch (2015)97.3
98.4
88.2
97.0
Vinyals et al. (2016)98.1
98.9
93.8
98.5
Finn et al. (2017)98.7
0.4
99.9
0.3
95.8
0.3
98.9
0.2
Snell et al. (2017)97.4
99.3
96.0
98.9
Munkhdalai
Yu (2017)98.9
–97.0
–SNAIL99.07
0.16
99.78
0.09
97.64
0.30
99.36
0.18
表1 SNAIL 在 miniImageNet 上的分类结果。
Method | 5-way 1-shot | 5-way 5-shot |
---|---|---|
Vinyals et al. (2016) | 43.6 % \% % | 55.3 % \% % |
Finn et al. (2017) | 48.7 ± \pm ± 1.84 % \% % | 63.1 ± \pm ± 0.92 % \% % |
Ravi & \& & Larochelle (2017) | 43.4 ± \pm ± 0.77 % \% % | 60.2 ± \pm ± 0.71 % \% % |
Snell et al. (2017) | 46.61 ± \pm ± 0.78 % \% % | 65.77 ± \pm ± 0.70 % \% % |
Munkhdalai & \& & Yu (2017) | 49.21 ± \pm ± 0.96 % \% % | – |
SNAIL | 55.71 ± \pm ± 0.99 % \% % | 68.88 ± \pm ± 0.92 % \% % |
55.3
Finn et al. (2017)48.7
1.84
63.1
0.92
Ravi
Larochelle (2017)43.4
0.77
60.2
0.71
Snell et al. (2017)46.61
0.78
65.77
0.70
Munkhdalai
Yu (2017)49.21
0.96
–SNAIL55.71
0.99
68.88
0.92
[1] A Simple Neural Attentive Meta-Learner
Relation Network (RN) 使用有监督度量学习估计样本点之间的距离, 根据新样本点和过去样本点之间的距离远近,对新样本点进行分类。
RN 包括两个组成部分:嵌入模块和关系模块,且两者都是通过有监督学习得到的。 嵌入模块从输入数据中提取特征,关系模块根据特征计算任务之间的距离, 判断任务之间的相似性,找到过去可借鉴的经验进行加权平均。 RN 结构如图1所示。
图1 RN 结构。
嵌入模块记为
,关系模块记为
, 支持集中的样本记为
, 查询集中的样本记为
。
和
输入
, 产生特征映射
和
。
将
和
结合, 得到
。
输入
, 得到
范围内的标量, 表示
和
之间的相似性,记为关系得分
。
和
相似度越高,
越大。
其中,
用来判断
和
是否属于同一类别。 当
时,
, 当
时,
。
嵌入模块和关系模块的选取有很多种,包括卷积网络、残差网络等。
图2给出了 [1] 中使用的 RN 模型结构。
图2 RN 模型结构。
3 滤波器进行卷积,一个归一化层、一个 ReLU 非线性层。
2 的最大池化层,后边两个卷积块没有池化层。
2 的最大池化层。
。
表1 RN 在 Omniglot 上的分类结果。
Model | Fine Tune | 5-way 1-shot | 5-way 5-shot | 20-way 1-shot | 20-way 5-shot |
---|---|---|---|---|---|
MANN | N | 82.8 % \% % | 94.9 % \% % | – | – |
CONVOLUTIONAL SIAMESE NETS | N | 96.7 % \% % | 98.4 % \% % | 88.0 % \% % | 96.5 % \% % |
CONVOLUTIONAL SIAMESE NETS | Y | 97.3 % \% % | 98.4 % \% % | 88.1 % \% % | 97.0 % \% % |
MATCHING NETS | N | 98.1 % \% % | 98.9 % \% % | 93.8 % \% % | 98.5 % \% % |
MATCHING NETS | Y | 97.9 % \% % | 98.7 % \% % | 93.5 % \% % | 98.7 % \% % |
SIAMESE NETS WITH MEMORY | N | 98.4 % \% % | 99.6 % \% % | 95.0 % \% % | 98.6 % \% % |
NEURAL STATISTICIAN | N | 98.1 % \% % | 99.5 % \% % | 93.2 % \% % | 98.1 % \% % |
META NETS | N | 99.0 % \% % | – | 97.0 % \% % | – |
PROTOTYPICAL NETS | N | 98.8 % \% % | 99.7 % \% % | 96.0 % \% % | 98.9 % \% % |
MAML | Y | 98.7 ± \pm ± 0.4 % \% % | 99.9 ± \pm ± 0.1 % \% % | 95.8 ± \pm ± 0.3 % \% % | 98.9 ± \pm ± 0.2 % \% % |
RELATION NET | N | 99.6 ± \pm ± 0.2 % \% % | 99.8 ± \pm ± 0.1 % \% % | 97.6 ± \pm ± 0.2 % \% % | 99.1 ± \pm ± 0.1 % \% % |
94.9
––CONVOLUTIONAL SIAMESE NETSN96.7
98.4
88.0
96.5
CONVOLUTIONAL SIAMESE NETSY97.3
98.4
88.1
97.0
MATCHING NETSN98.1
98.9
93.8
98.5
MATCHING NETSY97.9
98.7
93.5
98.7
SIAMESE NETS WITH MEMORYN98.4
99.6
95.0
98.6
NEURAL STATISTICIANN98.1
99.5
93.2
98.1
META NETSN99.0
–97.0
–PROTOTYPICAL NETSN98.8
99.7
96.0
98.9
MAMLY98.7
0.4
99.9
0.1
95.8
0.3
98.9
0.2
RELATION NETN99.6
0.2
99.8
0.1
97.6
0.2
99.1
0.1
表1 RN 在 miniImageNet 上的分类结果。
Model | FT | 5-way 1-shot | 5-way 5-shot |
---|---|---|---|
MATCHING NETS | N | 43.56 ± \pm ± 0.84 % \% % | 55.31 ± \pm ± 0.73 % \% % |
META NETS | N | 49.21 ± \pm ± 0.96 % \% % | – |
META-LEARN LSTM | N | 43.44 ± \pm ± 0.77 % \% % | 60.60 ± \pm ± 0.71 % \% % |
MAML | Y | 48.70 ± \pm ± 1.84 % \% % | 63.11 ± \pm ± 0.92 % \% % |
PROTOTYPICAL NETS | N | 49.42 ± \pm ± 0.78 % \% % | 68.20 ± \pm ± 0.66 % \% % |
RELATION NET | N | 50.44 ± \pm ± 0.82 % \% % | 65.32 ± \pm ± 0.70 % \% % |
0.84
55.31
0.73
META NETSN49.21
0.96
–META-LEARN LSTMN43.44
0.77
60.60
0.71
MAMLY48.70
1.84
63.11
0.92
PROTOTYPICAL NETSN49.42
0.78
68.20
0.66
RELATION NETN50.44
0.82
65.32
0.70
[1] Learning to Compare: Relation Network for Few-Shot Learning
Prototypical Network (PN) 利用支持集中每个类别提供的少量样本, 计算它们的嵌入中心,作为每一类样本的原型 (Prototype), 接着基于这些原型学习一个度量空间, 使得新的样本通过计算自身嵌入与这些原型的距离实现最终的分类。
在 few-shot 分类任务中, 假设有
个标记的样本
, 其中,
是
维的样本特征向量,
是相应的标签。
表示第
类样本的集合。
PN 计算每个类的
维原型向量
, 计算的函数为
, 其中
为可学习参数。 原型向量
即为嵌入空间中该类的所有 支持集样本点的均值向量
给定一个距离函数
, 不包含任何可训练的参数, PN 通过在嵌入空间中对距离进行 softmax 计算, 得到一个针对
的样本点的概率分布
新样本点的特征离类别中心点越近, 新样本点属于这个类别的概率越高; 新样本点的特征离类别中心点越远, 新样本点属于这个类别的概率越低。
通过在 SGD 中最小化第
类的负对数似然函数
来推进学习
PN 示意图如图1所示。
图1 PN 示意图。
Input: Training set
, where each
.
denotes the subset of
containing all elements
such that
.
Output: The loss
for a randomly generated training episode.
in
do
in
do
in
do
其中,
是训练集中的样本个数;
是训练集中的类个数;
是每个 episode 选出的类个数;
是每类中 support set 的样本个数;
是每类中 query set 的样本个数;
表示从集合
中随机选出
个元素。
表1 PN 在 Omniglot 上的分类结果。
Model | Dist. | Fine Tune | 5-way 1-shot | 5-way 5-shot | 20-way 1-shot | 20-way 5-shot |
---|---|---|---|---|---|---|
MATCHING NETWORKS | Cosine | N | 98.1 % \% % | 98.9 % \% % | 93.8 % \% % | 98.5 % \% % |
MATCHING NETWORKS | Cosine | Y | 97.9 % \% % | 98.7 % \% % | 93.5 % \% % | 98.7 % \% % |
NEURAL STATISTICIAN | - | N | 98.1 % \% % | 99.5 % \% % | 93.2 % \% % | 98.1 % \% % |
MAML | - | N | 98.7 % \% % | 99.9 % \% % | 95.8 % \% % | 98.9 % \% % |
PROTOTYPICAL NETWORKS | Euclid. | N | 98.8 % \% % | 99.7 % \% % | 96.0 % \% % | 98.9 % \% % |
98.9
93.8
98.5
MATCHING NETWORKSCosineY97.9
98.7
93.5
98.7
NEURAL STATISTICIAN-N98.1
99.5
93.2
98.1
MAML-N98.7
99.9
95.8
98.9
PROTOTYPICAL NETWORKSEuclid.N98.8
99.7
96.0
98.9
表1 PN 在 miniImageNet 上的分类结果。
Model | Dist. | Fine Tune | 5-way 1-shot | 5-way 5-shot |
---|---|---|---|---|
BASELINE NEAREST NEIGHBORS | Cosine | N | 28.86 ± \pm ± 0.54 % \% % | 49.79 ± \pm ± 0.79 % \% % |
MATCHING NETWORKS | Cosine | N | 43.40 ± \pm ± 0.78 % \% % | 51.09 ± \pm ± 0.71 % \% % |
MATCHING NETWORKS (FCE) | Cosine | N | 43.56 ± \pm ± 0.84 % \% % | 55.31 ± \pm ± 0.73 % \% % |
META-LEARNER LSTM | - | N | 43.44 ± \pm ± 0.77 % \% % | 60.60 ± \pm ± 0.71 % \% % |
MAML | - | N | 48.70 ± \pm ± 1.84 % \% % | 63.15 ± \pm ± 0.91 % \% % |
PROTOTYPICAL NETWORKS | Euclid. | N | 49.42 ± \pm ± 0.78 % \% % | 68.20 ± \pm ± 0.66 % \% % |
0.54
49.79
0.79
MATCHING NETWORKSCosineN43.40
0.78
51.09
0.71
MATCHING NETWORKS (FCE)CosineN43.56
0.84
55.31
0.73
META-LEARNER LSTM-N43.44
0.77
60.60
0.71
MAML-N48.70
1.84
63.15
0.91
PROTOTYPICAL NETWORKSEuclid.N49.42
0.78
68.20
0.66
[1] Prototypical Networks for Few-shot Learning
Matching Network (MN) 结合了度量学习 (Metric Learning) 与记忆增强神经网络 (Memory Augment Neural Networks), 并利用注意力机制与记忆机制加速学习,同时提出了 set-to-set 框架, 使得 MN 能够为新类产生合理的测试标签,且不用网络做任何改变。
将支持集
映射到一个分类器
, 给定一个测试样本
,
定义一个关于输出
的概率分布,即
其中,
被网络参数化。 因此,当给定一个新的支持集
进行小样本学习时, 只需使用
定义的网络来预测每个测试示例
的适当标签分布
即可。
模型以最简单的形式计算
上的概率:
上式本质是将一个输入的新类描述为支持集中所有类的一个线性组合, 结合了核密度估计KDE(
可以看做是一种核密度估计)和 KNN 。 其中,
表示支持集中样本类别数,
是注意力机制, 类似 attention 模型中的核函数, 用来度量
和训练样本
的匹配度。
的计算基于新样本数据与支持集中的样本数据的嵌入表示的余弦相似度以及softmax函数:
其中,
表示余弦相似度,
与
表示施加在测试样本与训练样本上的嵌入函数 (Embedding Function)。
如果注意力机制是
上的核, 则上式类似于核密度估计器。 如果选取合适的距离度量以及适当的常数, 从而使得从
到
的注意力机制为 0 , 则上式等价于 KNN 。
图1是 MN 的网络结构示意图。
图1 MN 示意图。
为了增强样本嵌入的匹配度, [1] 提出了 Full Context Embeeding (FCE) 方法: 支持集中每个样本的嵌入应该是相互独立的, 而新样本的嵌入应该受支持集样本数据分布的调控, 其嵌入过程需要放在整个支持集环境下进行, 因此 [1] 采用带有注意力的 LSTM 网络对新样本进行嵌入。
在对余弦注意力定义时, 每个已知标签的输入
通过 CNN 后的 embedding , 因此
是独立的,前后没有关系, 然后与
进行逐个对比, 并没有考虑到输入任务
改变 embedding
的方式, 而
应该是受
影响的。 为了实现这个功能,[1] 采用了双向 LSTM 。
在通过嵌入函数
和
处理后, 输出再次经过循环神经网络进一步加强 context 和个体之间的关系。
其中,
是相关的上下文,
为网络的 timesteps 。
因此,经过
步后的状态为:
特征提取器可采用常见的 VGG 或 Inception 网络, [1] 设计了一种简单的四级网络结构用于图像分类任务的特征提取, 每级网络由一个 64 通道的 3
3 卷积层,一个批规范化层, 一个 ReLU 激活层和一个 2
2 的最大池化层构成。 然后将最后一层输出的特征输入到 LSTM 网络中得到最终的特征映射
和
。
中所有图片
(假设有
个)和目标图片
(假设有 1 个) 全部通过 CNN 网络,获得它们的浅层变量表示。
个)浅层变量全部输入到 BiLSTM 中,获得
个输出, 然后使用余弦距离判断前
个输出中每个输出与最后一个输出之间的相似度。
中的标签信息
求解目标图片
的类别标签
。
表1 MN 在 Omniglot 上的分类结果。
Model | Matching Fn | Fine Tune | 5-way 1-shot | 5-way 5-shot | 20-way 1-shot | 20-way 5-shot |
---|---|---|---|---|---|---|
PIXELS | Cosine | N | 41.7 % \% % | 63.2 % \% % | 26.7 % \% % | 42.6 % \% % |
BASELINE CLASSIFIER | Cosine | N | 80.0 % \% % | 95.0 % \% % | 69.5 % \% % | 89.1 % \% % |
BASELINE CLASSIFIER | Cosine | Y | 82.3 % \% % | 98.4 % \% % | 70.6 % \% % | 92.0 % \% % |
BASELINE CLASSIFIER | Softmax | Y | 86.0 % \% % | 97.6 % \% % | 72.9 % \% % | 92.3 % \% % |
MANN (NO CNOV) | Cosine | N | 82.8 % \% % | 94.9 % \% % | – | – |
CONVOLUTIONAL SIAMESE NET | Cosine | Y | 96.7 % \% % | 98.4 % \% % | 88.0 % \% % | 96.5 % \% % |
CONVOLUTIONAL SIAMESE NET | Cosine | Y | 97.3 % \% % | 98.4 % \% % | 88.1 % \% % | 97.0 % \% % |
MATCHING NETS | Cosine | N | 98.1 % \% % | 98.9 % \% % | 93.8 % \% % | 98.5 % \% % |
MATCHING NETS | Cosine | Y | 97.9 % \% % | 98.7 % \% % | 93.5 % \% % | 98.7 % \% % |
63.2
26.7
42.6
BASELINE CLASSIFIERCosineN80.0
95.0
69.5
89.1
BASELINE CLASSIFIERCosineY82.3
98.4
70.6
92.0
BASELINE CLASSIFIERSoftmaxY86.0
97.6
72.9
92.3
MANN (NO CNOV)CosineN82.8
94.9
––CONVOLUTIONAL SIAMESE NETCosineY96.7
98.4
88.0
96.5
CONVOLUTIONAL SIAMESE NETCosineY97.3
98.4
88.1
97.0
MATCHING NETSCosineN98.1
98.9
93.8
98.5
MATCHING NETSCosineY97.9
98.7
93.5
98.7
表1 MN 在 miniImageNet 上的分类结果。
Model | Matching Fn | Fine Tune | 5-way 1-shot | 5-way 5-shot |
---|---|---|---|---|
PIXELS | Cosine | N | 23.0 % \% % | 26.6 % \% % |
BASELINE CLASSIFIER | Cosine | N | 36.6 % \% % | 46.0 % \% % |
BASELINE CLASSIFIER | Cosine | Y | 36.2 % \% % | 52.2 % \% % |
BASELINE CLASSIFIER | Cosine | Y | 38.4 % \% % | 51.2 % \% % |
MATCHING NETS | Cosine | N | 41.2 % \% % | 56.2 % \% % |
MATCHING NETS | Cosine | Y | 42.4 % \% % | 58.0 % \% % |
MATCHING NETS | Cosine (FCE) | N | 44.2 % \% % | 57.0 % \% % |
MATCHING NETS | Cosine (FCE) | Y | 46.6 % \% % | 60.0 % \% % |
26.6
BASELINE CLASSIFIERCosineN36.6
46.0
BASELINE CLASSIFIERCosineY36.2
52.2
BASELINE CLASSIFIERCosineY38.4
51.2
MATCHING NETSCosineN41.2
56.2
MATCHING NETSCosineY42.4
58.0
MATCHING NETSCosine (FCE)N44.2
57.0
MATCHING NETSCosine (FCE)Y46.6
60.0
的增长,每次迭代的计算量也会随之快速增长,导致计算速度降低。