前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >度量学习:使用多类N对损失改进深度度量学习

度量学习:使用多类N对损失改进深度度量学习

作者头像
深度学习思考者
发布2023-10-17 16:26:12
3500
发布2023-10-17 16:26:12
举报

@度量学习系列

Author: 码科智能

使用多类N对损失改进深度度量学习

度量学习是ReID任务中常用的方式之一,今天来看下一篇关于如何改进度量学习的论文。来自2016年NeurIPS上的一篇论文,被引用超过900次。

论文:Improved Deep Metric Learning with Multi-class N-pair Loss Objective. 链接:论文.

1. 对比损失和三重损失

度量学习
度量学习
  • 令 x ∈ X 为输入数据,f ∈ {1, …, L} 为其输出标签。
  • f+ 和 f- 分别表示 f 的正例和负例,意思是 f 和 f+ 属于同一类,f- 属于 f 的不同类。
1.1. 对比损失
  • 对比损失将成对的样本作为网络模型的输入,通过训练网络来预测两个输入是否来自同一类。
在这里插入图片描述
在这里插入图片描述
  • 其中 m 是一个边距参数,它强制来自不同类的样本之间的距离大于 m。
1.2. 三重损失
  • Triplet loss 与 contrastive loss 具有相似的原理,但其由三元组组成,每个三元组由一个查询、一个正例(同查询一个类别)和一个负例组成:
在这里插入图片描述
在这里插入图片描述
  • 与contrastive loss相比,triplet loss只需要正例与查询样本的相似度和负例与查询点的相似度之差大于margin即可(即上述的边距参数m)。
  • Triplet loss 的作用是拉近正样本 f+ ,同时推开负样本 f- 。
  • 对比损失或三元组损失已用于许多应用,例如人脸识别和图像检索,例如DrLIM、DeepFace、DeepID2、FaceNet。但此类框架通常存在收敛速度慢和局部最优值差的问题,部分原因是损失函数在每次更新时仅使用一个负样本,而不与其他负样本交互。
  • Hard negative data mining 可以缓解这个问题,但是 hard negative example search 在网络训练中带来额外的时间开销。

2. (N+1)-Tuplet Loss for Multiple Negative Examples

在这里插入图片描述
在这里插入图片描述
  • 如上所示,(N+1)-tuplet loss 根据它们与输入样本的相似性,一次性推送 N-1 个负样本。
  • f+ 是 f 的正例(蓝色圆圈),{f2, …, fN-1} 是负例(粉色圆圈)。 (N+1)-tuplet 损失为:
在这里插入图片描述
在这里插入图片描述
  • 当 N=2 时,对应的 (2+1)-tuplet loss 与 triplet loss 非常相似,因为每对输入和正例只有一个负例:
在这里插入图片描述
在这里插入图片描述
  • 当 N>2 时,进一步论证了 (N+1)-tuplet loss 相对于 triplet loss 的优势。 根据理想 (L+1)-tuplet 损失的分配函数估计,将 (N+1)-tuplet 损失与三重损失进行比较,其中 (L+1)-tuplet 损失与每个负类的单个样本相结合,可以写成如下:
在这里插入图片描述
在这里插入图片描述
  • 回想一下,L 是类别的总数,上面的等式类似于多类逻辑损失(即 softmax 损失)。在监督学习里指的是这个数据集一共有多少类别,比如CV的ImageNet数据集有1000类,L就是1000。在度量学习中每个样本都应该有一个类别,那么在扩大数据规模时,比如当向量的维度是几百万的时候,计算复杂度是相当高的。
  • 为了克服这个问题,提出了一种高效的批量构建方法,它只需要 2N 个示例而不是 (N+1)N 来构建长度为 N+1 的 N 个元组。

3. N-pair Loss as Efficient Batch Construction Method

在这里插入图片描述
在这里插入图片描述
  1. Triplet Loss:对于一个f,有一个f+和一个f-。 Batch size N,一个batch需要N个f,有N个f+和N个f-。
  2. (N+1)-Tuplet Loss:对于一个f,有一个f+和N-1个f-。 总共有 N+1 个例子。 当 SGD 的 batch size 为 N 时,一次更新有 N(N+1) 个样本要通过 f。由于每个批次要评估的示例数量以二次方方式增长,因此为非常深的卷积网络扩展训练再次变得不切实际。
  3. N-pair-mc 损失:多类 N-pair 损失 (N-pair-mc),可以表示为:
在这里插入图片描述
在这里插入图片描述
  • 提出的 N-pair-mc 损失是一个新颖的损失,由两个不可或缺的组成部分组成:(N+1)-tuplet 损失,作为构建块损失函数,以及 N-pair 构造,作为实现高度可扩展训练的关键。这意味着每个 f 的每个正 f+ 将变成另一个 f 的 f-,如上图 © 所示。

4. 难负类挖掘和正则化

  • 难负数据挖掘被认为是许多基于三元组的距离度量学习算法的重要组成部分。在这里,提出了负“类”挖掘,而不是负“实例”挖掘,后者以相对有效的方式贪婪地选择负类。
  • N-pair loss的负类挖掘可以按如下方式执行:
    1. Evaluate Embedding Vectors:随机选择大量的输出类C;对于每个类,随机传递一些(一个或两个)示例来提取它们的嵌入向量。
    2. 选择负类:从步骤 1 的 C 个类中随机选择一个类。接下来,贪婪地添加一个违反三重态约束的新类。选定的数量直到我们达到 N 个类别数。当出现平局时,我们随机选择一个平局类。
    3. 完成 N 对:从步骤 2 中选择的每个类中抽取两个示例。
    4. 此外,L2 范数正则化用于将嵌入向量的 L2 范数正则化为较小的。

5. 人脸验证和识别的实验结果

  • 人脸验证和识别是判断两张人脸图像是否为相同身份的问题(验证)和从具有许多负样本的图库中识别相同身份的人脸图像的问题(识别)。
  • 网络在 WebFace 数据库上进行训练,该数据库由来自 10,575 个身份的 494,414 张图像组成,并且使用不同度量学习目标训练的嵌入网络的质量在 Labeled Faces in the Wild (LFW) 数据库上进行评估。
在这里插入图片描述
在这里插入图片描述
  • 上述几个指标分别为LFW 数据集上的平均验证准确度 (MRF)、Rank-1 准确度和DIR@FAR=1% 开集识别率
  • Triplet loss 模型显示了 95.88% 的验证准确率,但在识别任务上表现不佳。N-pair-mc 损失模型显着提高了性能。 此外,通过将 N 增加到 320,可以观察到额外的改进,获得 98.33% 的验证、90.17% 的封闭集和 71.76% 的开放集识别精度。

6. N-pair-mc Loss 代码

代码语言:javascript
复制
// N-pair loss
import torch
import torch.nn.functional as F

class NPairMCLoss(torch.nn.Module):
    def __init__(self, margin=0.1):
        super(NPairMCLoss, self).__init__()
        self.margin = margin

    def forward(self, anchors, positives, negatives):
        # 计算anchor和positive之间的距离
        pos_distance = F.pairwise_distance(anchors, positives)
        
        # 计算anchor和negative之间的距离
        neg_distance = F.pairwise_distance(anchors, negatives)

        # 计算损失函数
        loss = torch.mean(torch.relu(pos_distance - neg_distance + self.margin))
        return loss
代码语言:javascript
复制
// 调用示例

# 创建NPairMCLoss对象
loss_fn = NPairMCLoss(margin=0.1)

# 假设有一批输入数据 anchors, positives, negatives
anchors = torch.randn(16, 128)
positives = torch.randn(16, 128)
negatives = torch.randn(16, 128)

# 计算损失
loss = loss_fn(anchors, positives, negatives)

# 打印损失值
print("Loss:", loss.item())
本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2023-05-31,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 使用多类N对损失改进深度度量学习
    • 1. 对比损失和三重损失
      • 1.1. 对比损失
      • 1.2. 三重损失
    • 2. (N+1)-Tuplet Loss for Multiple Negative Examples
      • 3. N-pair Loss as Efficient Batch Construction Method
        • 4. 难负类挖掘和正则化
          • 5. 人脸验证和识别的实验结果
            • 6. N-pair-mc Loss 代码
            相关产品与服务
            数据库
            云数据库为企业提供了完善的关系型数据库、非关系型数据库、分析型数据库和数据库生态工具。您可以通过产品选择和组合搭建,轻松实现高可靠、高可用性、高性能等数据库需求。云数据库服务也可大幅减少您的运维工作量,更专注于业务发展,让企业一站式享受数据上云及分布式架构的技术红利!
            领券
            问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档