N-Shot Learning:用最少的数据训练最多的模型

原标题 | N-Shot Learning: Learning More with Less Data

作 者 | Heet Sankesara

翻 译 | 天字一号(郑州大学)、邺调(江苏科技大学)

审 校 | 唐里、鸢尾、Pita

人工智能将引领新世纪 - Andrew NG

如果将AI比作电力的话,那么数据就是创造电力的煤。

不幸的是,正如我们看到可用煤是消耗品一样,许多 AI 应用程序可供访问的数据很少或根本就没有数据。

新技术已经弥补了物质资源的不足;同样需要新的技术来允许在数据很少时,保证程序的正常运行。这是正在成为一个非常受欢迎的领域,核心问题:N-shot Learning

1.

N-Shot Learning

你可能会问,什么是shot?好问题,shot只用一个样本来训练,在N-shot学习中,我们有N个训练的样本。术语“小样本学习”中的“小”通常在0-5之间,也就是说,训练一个没有样本的模型被称为 zero-shot ,一个样本就是 one-shot 学习,以此类推。

1-1 为什么需要N-Shot?

我们在 ImageNet 中的分类错误率已经小于 4% 了,为什么我们需要这个?

首先,ImageNet 的数据集包含了许多用于机器学习的示例,但在医学影像、药物发现和许多其他 AI 可能至关重要的领域中并不总是如此。典型的深度学习架构依赖于大量数据训练才能获得足够可靠的结果。例如,ImageNet 需要对数百张热狗图像进行训练,然后才能判断一幅新图像准确判断是否为热狗。一些数据集,就像7月4日庆祝活动后的冰箱缺乏热狗一样,是非常缺乏图像的。

机器学习有许多案例数据是都非常稀缺,这就是N-Shot技术的用武之地。我们需要训练一个包含数百万甚至数十亿个参数(全部随机初始化)的深度学习模型,但可用于训练的图像不超过 5 个图像。简单地说,我们的模型必须使用非常有限的热狗图像进行训练。

要处理像这个这样复杂的问题,我们首先需要清楚N-Shot的定义。

在 N-shot学习领域中,每K个类别,我们标记了 n 个示例,这 N*K个总示例被我们称为支持集 S 。我们还必须对查询集 Q 进行分类,其中每个示例位于其中一个 K 类中。N-shot 学习有三个主要子领域:zero-shot learning、one-shot learning和小样本学习,每个领域都值得关注。

1-2 Zero-Shot learning

对我来说,最有趣的子领域是Zero-shot learning,该领域的目标是不需要一张训练图像,就能够对未知类别进行分类。

没有任何数据可以利用的话怎么进行训练和学习呢?

想一下这种情况,你能对一个没有见过的物体进行分类吗?

是的,如果你对这个物体的外表、属性和功能有充足的信息的话,你是可以实现的。想一想,当你还是一个孩子的时候,是怎么理解这个世界的。在了解了火星的颜色和晚上的位置后,你可以在夜空中找到火星。或者你可以通过了解仙后座在天空中"基本上是一个畸形的'W'"这个信息中识别仙后座。

根据今年NLP的趋势,Zero-shot learning 将变得更加有效(https://blog.floydhub.com/ten-trends-in-deep-learning-nlp/#9-zero-shot-learning-will-become-more-effective)。

计算机利用图像的元数据执行相同的任务。元数据只不过是与图像关联的功能。以下是该领域的几篇论文,这些论文取得了优异的成绩。

  • Learning to Compare: Relation Network for Few-Shot Learning(https://arxiv.org/pdf/1711.06025v2.pdf)
  • Learning Deep Representations of Fine-Grained Visual Descriptions(https://arxiv.org/pdf/1605.05395v1.pdf)
  • Improving zero-shot learning by mitigating the hubness problem(https://arxiv.org/abs/1412.6568v3)

1-3 One-Shot Learning

在one-shot learning中,我们每个类别只有一个示例。现在的任务是使用一个影像进行训练,最终完成将测试影像划分为各个类。为了实现这一目标,目前已经出现了很多不同的架构,例如Siamese Neural Networks(https://www.cs.cmu.edu/~rsalakhu/papers/oneshot1.pdf),它带来了重大进步,并达到了卓越的结果。然后紧接着是matching networks(https://arxiv.org/pdf/1606.04080.pdf),这也帮助我们在这一领域实现了巨大的飞跃。

  • Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks(https://arxiv.org/pdf/1703.03400v3.pdf)
  • One-shot Learning with Memory-Augmented Neural Networks(https://arxiv.org/pdf/1605.06065v1.pdf)
  • Prototypical Networks for Few-shot Learning(https://arxiv.org/pdf/1703.05175v2.pdf)

1-4 小样本学习

小样本学习只是one-shot learning 的灵活应用。在小样本学习中,我们有多个训练示例(通常为两到五个图像,尽管上述one-shot learning中的大多数模型也可用于小样本学习)。

在2019年计算机视觉和模式识别会议上,介绍了 Meta-Transfer Learning for Few-Shot Learning(https://arxiv.org/pdf/1812.02391v3.pdf)。这一模式为今后的研究开创了先例;它给出了最先进的结果,并为更复杂的元迁移学习方法铺平了道路。

这些元学习和强化学习算法中有许多都是与典型的深度学习算法相结合,并产生了显著的结果。原型网络是最流行的深度学习算法之一,并经常用于小样本学习 。

在本文中,我们将使用原型网络完成小样本学习,并了解其工作原理。

2.

原型网络背后的思想

上图为原型网络函数的示意图。编码器将图像进行编码映射到嵌入空间(黑圈)中的矢量中,支持图像用于定义原型(星形)。利用原型和编码查询图像之间的距离进行分类。图源:https://www.semanticscholar.org/paper/Gaussian-Prototypical-Networks-for-Few-Shot-on-Fort/feaecb5f7a8d29636650db7c0b480f55d098a6a7/figure/1

与典型的深度学习体系结构不同,原型网络不直接对图像进行分类,而是通过在度量空间(https://en.wikipedia.org/wiki/Metric_space)中寻找图像之间的映射关系。

对于任何需要复习数学的人来说,度量空间都涉及"距离"的概念。它没有一个可区分的"起源"点。相反,在度量空间中,我们只计算一个点与另一个点的距离。因此,这里缺少了矢量空间中加法和标量乘法(因为与矢量不同,点仅表示坐标,添加两个坐标或缩放坐标毫无意义!)请查看此链接,详细了解矢量空间和度量空间之间的差异:https://math.stackexchange.com/questions/114940/what-is-the-difference-between-metric-spaces-and-vector-spaces。

现在,我们已经学习了这一背景,我们可以开始了解原型网络是怎样不直接对图像进行分类,而是通过在度量空间中寻找图像之间的映射关系。如上图所示,同一类的图像经过编码器的映射之后,彼此之间的距离非常接近,而不同类的图像之间具有较长的距离。这意味着,每当给出新示例时,网络只需检查与新示例的图像最近的集合,并将该示例图像分到其相应的类。原型网络中将图像映射到度量空间的基础模型可以被称为"Image2Vector"模型,这是一种基于卷积神经网络 (CNN) 的体系结构。

现在,对于那些对 CNN 不了解的人,您可以在此处阅读更多内容:

  • 深度学习的课程:https://blog.floydhub.com/best-deep-learning-courses-updated-for-2019/.
  • 深度学习书籍:https://blog.floydhub.com/best-deep-learning-books-updated-for-2019/.
  • 快速学习和应用 Building Your First ConvNethttps://blog.floydhub.com/building-your-first-convnet/

2-1 原型网络简介

简单地说,他们的目标是训练分类器。然后,该分类器可以对在训练期间不可用的新类进行概括,并且只需要每个新类的少量示例。因此,训练集包含一组类的图像,而我们的测试集包含另一组类的图像,这与前一组完全不相关。在该模型中,示例被随机分为支持集和查询集。

2-2 原型网络综述

很少有镜头原型ck被计算为每个类的嵌入式支持示例的平均值。编码器映射新图像(x)并将其分类到最接近的类,如上图中的c2(图源:https://arxiv.org/pdf/1703.05175.pdf)。

在少镜头学习的情况下,训练迭代被称为一个片段。一个小插曲不过是我们训练网络一次,计算损失并反向传播错误的一个步骤。在每一集中,我们从训练集中随机选择NC类。对于每一类,我们随机抽取ns图像。这些图像属于支持集,学习模型称为ns-shot模型。另一个随机采样的nq图像属于查询集。这里nc、ns和nq只是模型中的超参数,其中nc是每次迭代的类数,ns是每个类的支持示例数,nq是每个类的查询示例数。

之后,我们通过“image2vector”模型从支持集图像中检索d维点。该模型利用图像在度量空间中的对应点对图像进行编码。对于每个类,我们现在有多个点,但是我们需要将它们表示为每个类的一个点。因此,我们计算每个类的几何中心,即点的平均值。之后,我们还需要对查询图像进行分类。

为此,我们首先需要将查询集中的每个图像编码为一个点。然后,计算每个质心到每个查询点的距离。最后,预测每个查询图像位于最靠近它的类中。一般来说,模型就是这样工作的。

但现在的问题是,这个“image2vector”模型的架构是什么?

2-3 Image2Vector 向量

论文汇总 Image2Vector 向量的结构

对于所有实际应用中,一般都会使用 4-5 CNN 模块。如上图所示,每个模块由一个 CNN 层组成,然后是批处理规范化,然后是 ReLu 激活函数,最后通向最大池层。在所有模块之后,剩余的输出将被展平并返回。这是本文中使用的网络结构(https://arxiv.org/pdf/1703.05175v2.pdf),您可以使用任何任何你喜欢的体系结构。有必要知道,虽然我们称之为"Image2Vector"模型,但它实际上将图像转换为度量空间中的 64 维的点。要更好地了解差异,请查看 math stack exchange(https://math.stackexchange.com/questions/645672/what-is-the-difference-between-a-point-and-a-vector)。

2-4 Loss函数

负log概率的原理,图源:https://ljvmiranda921.github.io/notebook/2017/08/13/softmax-and-the-negative-log-likelihood/#nll

现在,已经知道了模型是如何工作的,您可能更想知道我们将如何计算损失函数。我们需要一个足够强大的损失函数,以便我们的模型能够快速高效地学习。原型网络使用log-softmax损失,这只不过是对 softmax 损失取了对数。当模型无法预测正确的类时,log-softmax 的效果会严重惩罚模型,而这正是我们需要的。要了解有关损失函数的更多情况,请访问此处。这里是关于 softmax 和 log-softmax 的很好的讨论。

2-5 数据集概览

Omniglot数据集中的部分示例(图源:https://github.com/brendenlake/omniglot)

该网络在 Omniglot 数据集(https://github.com/brendenlake/omniglot)上进行了训练。Omniglot 数据集是专门为开发更类似于人类学习的算法而设计。它包含 50个不同的字母表,共计1623 个不同的手写字符。为了增加类的数量,所有图像分别旋转 90、180 和 270 度,每次旋转后的图像都当做一个新类。因此,类的总数达到 了6492(1,623 + 4)类别。我们将 4200 个类别的图像作为训练数据,其余部分则用于测试。对于每个集合,我们根据 64 个随机选择的类中的每个示例对模型进行了训练。我们训练了模型 1 小时,获得了约 88% 的准确率。官方文件声称,经过几个小时的训练和调整一些参数,准确率达到99.7%。

是时候亲自动手实践了!

您可以通过访问以下链接轻松运行代码:

代码地址:https://github.com/Hsankesara/Prototypical-Networks

运行地址:https://floydhub.com/run?template=https://github.com/Hsankesara/Prototypical-Networks

让我们深入学习一下代码!

class Net(nn.Module):
    """
    Image2Vector CNN which takes the image of dimension (28x28x3) and return column vector length 64
    """
    def sub_block(self, in_channels, out_channels=64, kernel_size=3):
        block = torch.nn.Sequential(
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=out_channels, padding=1),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU()
            torch.nn.MaxPool2d(kernel_size=2))
        return block    
    def __init__(self):
        super(Net, self).__init__()
        self.convnet1 = self.sub_block(3)
        self.convnet2 = self.sub_block(64)
        self.convnet3 = self.sub_block(64)
        self.convnet4 = self.sub_block(64)
    def forward(self, x):
        x = self.convnet1(x)
        x = self.convnet2(x)
        x = self.convnet3(x)
        x = self.convnet4(x)
        x = torch.flatten(x, start_dim=1)
        return x

以上的代码是 Image2Vector CNN结构的一个实现。它的输入图像的维度为28*28*3,返回特征向量的长度为64。

class PrototypicalNet(nn.Module):
    def __init__(self, use_gpu=False):
        super(PrototypicalNet, self).__init__()
        self.f = Net()
        self.gpu = use_gpu        if self.gpu:
            self.f = self.f.cuda()
    
    def forward(self, datax, datay, Ns,Nc, Nq, total_classes):
        """
        Implementation of one episode in Prototypical Net
        datax: Training images
        datay: Corresponding labels of datax
        Nc: Number  of classes per episode
        Ns: Number of support data per class
        Nq:  Number of query data per class
        total_classes: Total classes in training set
        """
        k = total_classes.shape[0]
        K = np.random.choice(total_classes, Nc, replace=False)
        Query_x = torch.Tensor()
        if(self.gpu):
            Query_x = Query_x.cuda()
        Query_y = []
        Query_y_count = []
        centroid_per_class  = {}
        class_label = {}
        label_encoding = 0
        for cls in K:
            S_cls, Q_cls = self.random_sample_cls(datax, datay, Ns, Nq, cls)
            centroid_per_class[cls] = self.get_centroid(S_cls, Nc)
            class_label[cls] = label_encoding
            label_encoding += 1
            Query_x = torch.cat((Query_x, Q_cls), 0) # Joining all the query set together
            Query_y += [cls]
            Query_y_count += [Q_cls.shape[0]]
        Query_y, Query_y_labels = self.get_query_y(Query_y, Query_y_count, class_label)
        Query_x = self.get_query_x(Query_x, centroid_per_class, Query_y_labels)
        return Query_x, Query_y    
    def random_sample_cls(self, datax, datay, Ns, Nq, cls):
        """
        Randomly samples Ns examples as support set and Nq as Query set
        """
        data = datax[(datay == cls).nonzero()]
        perm = torch.randperm(data.shape[0])
        idx = perm[:Ns]
        S_cls = data[idx]
        idx = perm[Ns : Ns+Nq]
        Q_cls = data[idx]
        if self.gpu:
            S_cls = S_cls.cuda()
            Q_cls = Q_cls.cuda()
        return S_cls, Q_cls    
    def get_centroid(self, S_cls, Nc):
        """
        Returns a centroid vector of support set for a class
        """
        return torch.sum(self.f(S_cls), 0).unsqueeze(1).transpose(0,1) / Nc    
    def get_query_y(self, Qy, Qyc, class_label):
        """
        Returns labeled representation of classes of Query set and a list of labels.
        """
        labels = []
        m = len(Qy)
        for i in range(m):
            labels += [Qy[i]] * Qyc[i]
        labels = np.array(labels).reshape(len(labels), 1)
        label_encoder = LabelEncoder()
        Query_y = torch.Tensor(label_encoder.fit_transform(labels).astype(int)).long()
        if self.gpu:
            Query_y = Query_y.cuda()
        Query_y_labels = np.unique(labels)
        return Query_y, Query_y_labels    
    def get_centroid_matrix(self, centroid_per_class, Query_y_labels):
        """
        Returns the centroid matrix where each column is a centroid of a class.
        """
        centroid_matrix = torch.Tensor()
        if(self.gpu):
            centroid_matrix = centroid_matrix.cuda()
        for label in Query_y_labels:
            centroid_matrix = torch.cat((centroid_matrix, centroid_per_class[label]))
        if self.gpu:
            centroid_matrix = centroid_matrix.cuda()
        return centroid_matrix    
    def get_query_x(self, Query_x, centroid_per_class, Query_y_labels):
        """
        Returns distance matrix from each Query image to each centroid.
        """
        centroid_matrix = self.get_centroid_matrix(centroid_per_class, Query_y_labels)
        Query_x = self.f(Query_x)
        m = Query_x.size(0)
        n = centroid_matrix.size(0)
        # The below expressions expand both the matrices such that they become compatible with each other in order to calculate L2 distance.
        centroid_matrix = centroid_matrix.expand(m, centroid_matrix.size(0), centroid_matrix.size(1)) # Expanding centroid matrix to "m".
        Query_matrix = Query_x.expand(n, Query_x.size(0), Query_x.size(1)).transpose(0,1) # Expanding Query matrix "n" times
        Qx = torch.pairwise_distance(centroid_matrix.transpose(1,2), Query_matrix.transpose(1,2))
        return Qx

上面的代码片段是原型网中单个结构的实现。如果你有任何疑问,只需在评论中询问或在这里创建一个问题,非常欢迎您的参与和评论。

网络概述。图源:https://youtu.be/wcKL05DomBU

代码的结构与解释算法的格式相同。我们为原型网络函数提供以下输入:输入图像数据、输入标签、每次迭代的类数(即Nc)、每个类的支持示例数(即Ns)和每个类的查询示例数(即Nq)。函数返回Queryx,它是从每个查询点到每个平均点的距离矩阵,Queryy 是包含与Queryx 对应的标签的向量。Queryy 存储Queryx 的图像实际所属的类。在上面的图像中,我们可以看到,使用3个类,即Nc =3,并且对于每个类,总共有5个示例用于训练,即Ns=5。上面的s表示包含这15个(Ns*Nc )图像的支持集,X 表示查询集。注意,支持集和查询集都通过f,它只不过是我们的“image2vector”函数。它在度量空间中映射所有图像。让我们一步一步地把整个过程分解。

首先,我们从输入数据中随机选择Nc 类。对于每个类,我们使用random_sample_cls函数从图像中随机选择一个支持集和一个查询集。在上图中,s是支持集,x是查询集。现在我们选择了类(C1、C2 C3 ),我们通过“image2vector”模型传递所有支持集示例,并使用get_centroid函数计算每个类的质心。在附近的图像中也可以观察到这一点。每个质心代表一个类,将用于对查询进行分类。

网络中的质心计算。图源:https://youtu.be/wcKL05DomBU

在计算每个类的质心之后,我们现在必须预测其中一个类的查询图像。为此,我们需要与每个查询对应的实际标签,这些标签是使用get_query_y函数获得的。Queryy 是分类数据,该函数将该分类文本数据转换为一个热向量,该热向量在列点对应的图像实际所属的行标签中仅为“1”,在列中为“0”。

之后,我们需要对应于每个Queryx图像的点来对其进行分类。我们使用“image2vector”模型得到这些点,现在我们需要对它们进行分类。为此,我们计算Queryx中每个点到每个类中心的距离。这给出了一个矩阵,其中索引 ij 表示与第i 个查询图像对应的点到第j 类中心的距离。我们使用get_query_x函数构造矩阵并将矩阵保存在Queryx 变量中。在附近的图像中也可以看到同样的情况。对于查询集中的每个示例,将计算它与C1、C2 C3 之间的距离。在这种情况下,X最接近C2 ,因此我们可以说X被预测属于C2 类。

以编程方式,我们可以使用一个简单的ARMmin函数来做同样的事情,即找出图像被预测的类。然后使用预测类和实际类计算损失并反向传播错误。

如果你想使用经过训练的模型,或者只需要重新训练自己,这里是我的实现。您可以使用它作为API,并使用几行代码来训练模型。你可以在这里找到这个网络。

3.

资源列表

这里有些资源可以帮你更全面的了解本文内容:

  • One Shot Learning with Siamese Networks using Keras(https://sorenbouma.github.io/blog/oneshot/)
  • One-Shot Learning: Face Recognition using Siamese Neural Network(https://towardsdatascience.com/one-shot-learning-face-recognition-using-siamese-neural-network-a13dcf739e)
  • Matching network official implementation(https://github.com/AntreasAntoniou/MatchingNetworks)
  • Prototypical Network official implementation.(https://github.com/orobix/Prototypical-Networks-for-Few-shot-Learning-PyTorch)
  • Meta-Learning for Semi-Supervised Few-Shot Classification(https://arxiv.org/abs/1803.00676)

4.

局限性

尽管原型网络的结果不错,但它们仍然有局限性。首先是缺乏泛化,它在Omniglot数据集上表现很好,因为其中的所有图像都是一个字符的图像,因此共享一些相似的特征。然而,如果我们试图用这个模型来分类不同品种的猫,它不会给我们准确的结果。猫和字符图像几乎没有共同的特征,可以用来将图像映射到相应度量空间的共同特征的数量可以忽略不计。

原型网络的另一个限制是只使用均值来确定中心,而忽略了支持集中的方差,这在图像有噪声的情况下阻碍了模型的分类能力。利用高斯原网络(https://arxiv.org/abs/1708.02735)类中的方差,利用高斯公式对嵌入点进行建模,克服了这一局限性。

5.

结论

小概率学习是近年来研究的热点之一。有许多使用原型网络的新方法,比如这种元学习方法,效果很好。研究人员也在探索强化学习,这也有很大的潜力。这个模型最好的地方在于它简单易懂,并且能给出令人难以置信的结果。

via https://blog.floydhub.com/n-shot-learning/

本文分享自微信公众号 - AI研习社(okweiwu)

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2019-09-24

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

扫码关注云+社区

领取腾讯云代金券