Improved Deep Metric Learning with Multi-class N-pair Loss Objective

Improved Deep Metric Learning with Multi-class N-pair Loss Objective

来源:

  • NIPS’2016
  • NEC Laboratories America

文章目录

  • Improved Deep Metric Learning with Multi-class N-pair Loss Objective
    • Distance Metric Learning
    • Deep Metric Learning with Multiple Negative Examples
      • N-pair loss for efficient deep metric learning
    • 总结
    • 参考

找到这篇论文是因为看了淘宝搜索出品的论文Rethinking the Role of Pre-ranking in Large-scale E-Commerce 1,文中就提到了传统的list-wise损失 不适用于列表中存在多个正样本的场景。从样本构造的角度来看,这种方式应该也适用于多标签分类。

度量学习一直是我想了解的一个领域,就拿这篇论文做一个开始吧。

Distance Metric Learning

度量学习(metric learning)2,简言之:学习数据的嵌入表示,嵌入具有这样的性质,相似的数据点距离近不相似的数据点距离远。度量学习中常见的两种损失:对比损失和三元组损失,二者形式化的表示:
L c o n t ( x i , x j ; f ) = 1 { y i = y j } ∣ ∣ f i − f j ∣ ∣ 2 2 + 1 { y i ≠ y j } m a x ( 0 , m − ∣ ∣ f i − f j ∣ ∣ 2 ) 2 \mathcal{L}_{cont}(x_i, x_j; f) = \mathbb{1}\{y_i = y_j\}||f_i - f_j||_2^2 + \mathbb{1}\{y_i \neq y_j\}max(0, m - ||f_i - f_j||_2)^2 Lcont(xi,xj;f)=1{yi=yj}∣∣fifj22+1{yi=yj}max(0,m∣∣fifj2)2

L t r i ( x , x + , x − ; f ) = m a x ( 0 , ∣ ∣ f − f + ∣ ∣ 2 2 − ∣ ∣ f − f − ∣ ∣ 2 2 + m ) \mathcal{L}_{tri}(x, x^+, x^-; f) = max(0, ||f - f^+||_2^2 - ||f - f^-||_2^2 + m) Ltri(x,x+,x;f)=max(0,∣∣ff+22∣∣ff22+m)

其中 L c o n t \mathcal{L}_{cont} Lcont为对比损失(现在火起来的对比学习), L t r i \mathcal{L}_{tri} Ltri为三元组损失, f f f表示样本的嵌入。在对比损失中,要求来自同类别的样本距离近,不同类别的样本距离远;三元组损失中要求正( x + x^+ x+)、负( x − x^- x)样本到锚点( x x x,如搜图场景中的查询图)的距离要大于一定的阈值。

度量学习有一些现在很常见的应用,例如人脸识别、搜图等。度量学习的样本中通常只有一个负样本,容易导致收敛速度慢和局部最优的问题。难负样本挖掘(提一嘴:随着更多的实践,愈发觉得数据质量的重要性,如何构造好的数据集是一个值得研究的问题)能够减轻这些问题,但是如何找到难负样本本身就是一个难题。

与常见的三元组损失(triplet loss)中一个锚样本、一个正样本和一个负样本不一样,论文提出了一个 ( N + 1 ) (N+1) (N+1)元组的损失,来使一个正样本与 N − 1 N-1 N1个负样本区分开来。

Deep Metric Learning with Multiple Negative Examples

在三元组损失中,如果要使得损失尽可能低,显然有这么几种情况:

  • 缩短正样本与锚样本的距离;
  • 增大负样本与锚样本的距离;
  • 以上二者的结合。

从三元组损失的计算方式上也可以看出,再一次更新中只会比较锚样本与一个负样本,忽略了其他类别的负样本。这就导致:每次只能使锚样本远离一种负类,或许又被推到其他负类那里去了。最终学习到的嵌入可能会出现这样的情况:锚样本离训练数据中出现较多的负类远,而离某些负类又很近

当然,我们可以为锚样本配很多个三元组,囊括不同类别的负样本,这样在多轮、充足的训练后嵌入能够具有理想的性质。这样做就面临了不稳定以及收敛速度慢的问题。因此,文中就提出了 N + 1 N+1 N+1元组的损失,二者的区别如下图所示:
Triplet loss and (N+1)-tuplet loss

Deep metric learning with (left) triplet loss and (right) (N+1)-tuplet loss.

上图中红色的圆表示负样本,蓝色的表示锚样本和正样本。从左侧可以看出, N + 1 N+1 N+1元组损失的一个很简单的出发点:既然一个负类的样本不够,那就每个负类都拿一个样本出来,组成一个 N + 1 N+1 N+1的元组。但是在类别很多的场景(比如人脸识别),计算的复杂度过高。文章的重点就在于如何设计这样一个计算上可行的损失函数。

下图是三元组损失(a)、 ( N + 1 ) (N+1) (N+1)元组损失(b)及其改进后的损失©的一个对比。 N N N-pair-mc loss(multi-class N-pair loss)损失就是文章最后提出的损失。

N-pair-mc loss
Triplet loss, (N+1)-tuplet loss, and multi-class N-pair loss with training batch construction.

( N + 1 ) (N+1) (N+1)元组损失可以定义如下:
L ( { x , x + , { x i } i = 1 N − 1 } ; f ) = l o g ( 1 + ∑ i = 1 N − 1 e x p ( f T f i − f T f + ) ) \mathcal{L}(\{x, x^+, \{x_i\}_{i=1}^{N-1}\}; f) = log(1 + \sum_{i=1}^{N-1} exp(f^T f_i - f^T f^+)) L({x,x+,{xi}i=1N1};f)=log(1+i=1N1exp(fTfifTf+))
N N N等于2的时候该损失是与三元组损失等价的。提一嘴,这个形式和softplus的形式是一样的:
s o f t p l u s ( x ) = l o g ( 1 + e x p ( x ) ) softplus(x) = log(1 + exp(x)) softplus(x)=log(1+exp(x))
( N + 1 ) (N+1) (N+1)元组的损失可以写为如下形式:
l o g ( 1 + ∑ i = 1 N − 1 e x p ( f T f i − f T f + ) ) = − l o g e x p ( f T f + ) e x p ( f T f + ) + ∑ i = 1 N − 1 e x p ( f T f i − f T f + ) ) log(1 + \sum_{i=1}^{N-1} exp(f^T f_i - f^T f^+)) = - log \frac{exp(f^T f^+)} {exp(f^T f^+) + \sum_{i=1}^{N-1} exp(f^T f_i - f^T f^+))} log(1+i=1N1exp(fTfifTf+))=logexp(fTf+)+i=1N1exp(fTfifTf+))exp(fTf+)
这样一看是不是就更顺眼了,这不就是多分类里的softmax loss嘛。

N-pair loss for efficient deep metric learning

论文提出了一种高效的批构造方法,以降低额外的计算开销。方法的名字叫multi-class N N N-pair loss( N N N-pair-mc),其构造方式如上图 ( c )所示。来个说文解字,道一道作者的解决方法。方法名中有个N-pair,就从这入手。假若我们有 N N N个pair:
{ ( x 1 , x 1 + ) , ⋯ , ( x N , x N + } , y i ≠ y j , ∀ i ≠ j \{(x_1, x_1^+), \cdots, (x_N, x_N^+\},\ y_i \neq y_j, \forall i \neq j {(x1,x1+),,(xN,xN+}, yi=yj,i=j
每个pair的样本来自不同的类别,在这 N N N个pair的基础上构建 N N N个元组 { S i } i = 1 N \{S_i\}_{i=1}^N {Si}i=1N,其中:
S i = { x i , x 1 + , x 2 + , ⋯ , x N + } S_i = \{x_i, x_1^+, x_2^+, \cdots, x_N^+\} Si={xi,x1+,x2+,,xN+}
其中 x i x_i xi就是锚样本。显然, S i S_i Si就是一个包含了一个 i i i类别正样本, N − 1 N-1 N1个其他类别负样本的 N + 1 N+1 N+1元组了。因此,对于一个由 N N N个查询组成的batch,只需要准备 2 N 2 N 2N个样本,即 N N N个锚样本和 N N N个对应类别的正样本,每个batch只需要** 2 N 2 N 2N次前向计算**样本的嵌入就可以了。而在三元组损失和 N + 1 N+1 N+1元组损失中分别是 3 N 3 N 3N ( N + 1 ) N (N+1) N (N+1)N。因此,对于 N N N个查询组成的batch,其损失可以如下计算:
L N − p a i r − m c ( { ( x i , x i + } i = 1 N ; f ) = 1 N ∑ i = 1 N l o g ( 1 + ∑ j ≠ i e x p ( f i T f j + − f i T f i + ) ) \mathcal{L}_{N-pair-mc}(\{(x_i, x_i^+\}_{i=1}^N ; f) = \frac{1} {N} \sum_{i=1}^N log (1 + \sum_{j \neq i} exp(f_i^T f_j^+ - f_i^T f_i^+)) LNpairmc({(xi,xi+}i=1N;f)=N1i=1Nlog(1+j=iexp(fiTfj+fiTfi+))
以上就是论文的主要内容了,当然论文中还提到了负类别挖掘,这个就暂且不提了。

总结

简言之,这篇论文将度量学习中常见的三元组损失中只有一个负样本扩展到每个样本中包含 N − 1 N-1 N1个负样本,并且为了计算的效率提出了 N N N-pair的batch构造方法以降低计算量。其实,如果在三元组损失的batch中精心设计各种类别样本的配比,比如每个batch只训练一个类别,是否也能达到类似的效果呢?

参考


  1. Rethinking the Role of Pre-ranking in Large-scale E-Commerce, KDD 2023. ↩︎

  2. 漫谈-Distance Metric Learning那些事儿:https://zhuanlan.zhihu.com/p/458114525. ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/29799.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

时间复杂度空间复杂度相关练习题

1.消失的数字 【题目】:题目链接 思路1:排序——》qsort快排——》时间复杂度O(n*log2n) 不符合要求 思路2:(0123...n)-(a[0]a[1][2]...a[n-2]) ——》 时间复杂度O(N)空间复杂度…

Promise和async/await的使用及其应用场景

Promise 和 async/await 都是用于处理 JavaScript 异步操作的机制,它们在处理异步代码和处理回调地狱方面提供了更清晰和可维护的方式。 Promise 使用及原理: Promise 是一种处理异步操作的方式,它可以在异步操作完成时进行响应&#xff0c…

(Python)计算R方

计算R方是统计学中的一项重要任务,它可以评估一个模型的拟合程度 Python是一种广泛使用的编程语言,也是计算R方的一个强大工具 import numpy as np from sklearn.metrics import r2_score # 生成一些模拟数据 y_true np.array([1, 2, 3, 4, 5]) y_pred …

第十七章 定义 HL7 的 DTL 数据转换 - 空映射代码

文章目录 第十七章 定义 HL7 的 DTL 数据转换 - 空映射代码空映射代码 第十七章 定义 HL7 的 DTL 数据转换 - 空映射代码 空映射代码 有些 HL7 应用程序使用空映射约定。根据此约定,源应用程序可以发送一个由两个连续双引号字符 ("") 组成的字段&#x…

Java Api实现Elasticsearch的滚动查询

解决ES每次只能查询一万条数据的问题 Overridepublic List<ESHandleDto> getVisitorsNum(String startTime, String endTime, String schoolName, String typeFunction) throws IOException {List<ESHandleDto> esHandleDtos new ArrayList<>();SearchReque…

spring cloud智慧工地源码(项目端+监管端+数据大屏+APP)

spring cloud智慧工地源码&#xff08;项目端监管端数据大屏APP&#xff09; 系统功能介绍 【智慧工地PC项目端功能总览】 一.项目人员管理 包括&#xff1a;信息管理、信息采集、证件管理、考勤管理、考勤明细、工资管理、现场统计、WIFI教育、工种管理、分包商管理、班组管…

STM32 HAL 驱动PM2.5传感器(GP2Y10AU气体检测模块)

目录 1、简介 2、CubeMX初始化配置 2.1 基础配置 2.1.1 SYS配置 2.1.2 RCC配置 2.2 ADC外设配置 2.3 串口外设配置 2.4 项目生成 3、KEIL端程序整合 3.1 串口重映射 3.2 ADC数据采集 3.3 主函数代 3.4 效果展示 1、简介 本文通过STM32F103C8T6单片机通过HAL库方式对G…

RxJava的前世【RxJava系列之设计模式】

一. 前言 学习RxJava&#xff0c;少不了介绍它的设计模式。但我看大部分文章&#xff0c;都是先将其用法介绍一通&#xff0c;然后再结合其用法&#xff0c;讲解其设计模式。这样当然有很多好处&#xff0c;但我个人觉得&#xff0c;这种介绍方式&#xff0c;对于没有接触过Rx…

Qt 使用QLabel的派生类实现QLabel的双击响应

1 介绍 在QLabel中没有双击等事件响应&#xff0c;需要构建其派生类&#xff0c;自定义信号(signals)、重载事件函数(event)&#xff0c;最后在Qwidget中使用connect链接即可&#xff0c;进而实现响应功能。 对于其余没有需求事件响应的QObject同样适用。 此外&#xff0c;该功…

【HDFS】客户端读某个块时,如何对块的各个副本进行网络距离排序?

本文包含如下内容: ① 通过图解+源码分析/A1/B1/node1和 /A1/B2/node2 这两个节点的网络距离怎么算出来的 ② 客户端读文件时,副本的优先级。(怎么排序的,排序规则都有哪些?) ③ 我们集群发现的一个问题。 客户端读时,通过调用getBlockLocations RPC 获取文件的各个块。…

spring框架自带的http工具RestTemplate用法

1. RestTemplate是什么&#xff1f; RestTemplate是由Spring框架提供的一个可用于应用中调用rest服务的类它简化了与http服务的通信方式。 RestTemplate是一个执行HTTP请求的同步阻塞式工具类&#xff0c;它仅仅只是在 HTTP 客户端库&#xff08;例如 JDK HttpURLConnection&a…

python 相关框架事务开启方式

前言 对于框架而言&#xff0c;各式API接口少不了伴随着事务的场景&#xff0c;下面就列举常用框架的事务开启方法 一、Django import traceback from django.db import transaction from django.contrib.auth.models import User try:with transaction.atomic(): # 在with…

利用appium抓取app中的信息

一、appium简介 二、appium环境安装 三、联调测试环境 四、利用appium自动控制移动设备并提取数据

nginx location的执行规则和root/alias的区分

nginx location的执行规则和root/alias的区分 总结 看本篇文章不是教如何从0编写nginx配置&#xff0c;而是看懂已存在的nginx配置。 官方文档定义&#xff1a;location [ | ~ | ~* | ^~ ] uri { … } &#xff1a;严格匹配&#xff0c;且匹配成功则不继续往下&#xff0c;优…

【openGauss】分区表的介绍与使用

一、openGauss分区表介绍 在openGauss中&#xff0c;数据分区是在一个节点内部对数据按照用户指定的策略做进一步的水平分表&#xff0c;将表中的数据按照指定方式划分为多个互不重叠的部分。 对于大多数用户使用场景&#xff0c;分区表和普通表相比具有以下优点&#xff1a; …

年轻代频繁GC ParNew导致http变慢

背景介绍 某日下午大约四点多&#xff0c;接到合作方消息&#xff0c;线上环境&#xff0c;我这边维护的某http服务突然大量超时&#xff08;对方超时时间设置为300ms&#xff09;&#xff0c;我迅速到鹰眼平台开启采样&#xff0c;发现该服务平均QPS到了120左右&#xff0c;平…

希尔排序——C语言andPython

前言 步骤 代码 C语言 Python 总结 前言 希尔排序&#xff08;Shell Sort&#xff09;是一种改进的插入排序算法&#xff0c;它通过将数组分成多个子序列进行排序&#xff0c;逐步减小子序列的长度&#xff0c;最终完成整个数组的排序。希尔排序的核心思想是通过排序较远距…

SQL server 异地备份数据库

异地备份数据库 1.备份服务器中设置共享文件夹 2.源服务器数据库中添加异地备份代理作业 EXEC sp_configure show advanced options, 1;RECONFIGURE; EXEC sp_configure xp_cmdshell, 1;RECONFIGURE; declare machine nvarchar(50) 192.168.11.10 --服务器IP declare pa…

中科驭数亮相DPU峰会,分享HADOS软件生态实践和大数据计算方案,再获评“匠芯技术奖”

又是一年相逢时&#xff0c;8月4日&#xff0c;第三届DPU峰会在北京开幕&#xff0c;本届峰会由中国通信学会指导&#xff0c;江苏省未来网络创新研究院主办&#xff0c;SDNLAB社区承办&#xff0c;以“智驱创新芯动未来”为主题&#xff0c;沿袭技术创新、生态协同的共创效应&…

Vue-1.零基础学习Vue

当你从零开始学习 Vue.js 时&#xff0c;以下步骤可以帮助你系统地学习这个前端框架&#xff1a; 了解前端基础&#xff1a; 如果你对前端开发还不熟悉&#xff0c;可以先了解 HTML、CSS 和 JavaScript 的基础知识。这将为学习 Vue.js 奠定基础。 Vue.js 官方文档&#xff1…