Improved Deep Metric Learning with Multi-class N-pair Loss Objective

Improved Deep Metric Learning with Multi-class N-pair Loss Objective

来源:

  • NIPS’2016
  • NEC Laboratories America

文章目录

  • Improved Deep Metric Learning with Multi-class N-pair Loss Objective
    • Distance Metric Learning
    • Deep Metric Learning with Multiple Negative Examples
      • N-pair loss for efficient deep metric learning
    • 总结
    • 参考

找到这篇论文是因为看了淘宝搜索出品的论文Rethinking the Role of Pre-ranking in Large-scale E-Commerce 1,文中就提到了传统的list-wise损失 不适用于列表中存在多个正样本的场景。从样本构造的角度来看,这种方式应该也适用于多标签分类。

度量学习一直是我想了解的一个领域,就拿这篇论文做一个开始吧。

Distance Metric Learning

度量学习(metric learning)2,简言之:学习数据的嵌入表示,嵌入具有这样的性质,相似的数据点距离近不相似的数据点距离远。度量学习中常见的两种损失:对比损失和三元组损失,二者形式化的表示:
L c o n t ( x i , x j ; f ) = 1 { y i = y j } ∣ ∣ f i − f j ∣ ∣ 2 2 + 1 { y i ≠ y j } m a x ( 0 , m − ∣ ∣ f i − f j ∣ ∣ 2 ) 2 \mathcal{L}_{cont}(x_i, x_j; f) = \mathbb{1}\{y_i = y_j\}||f_i - f_j||_2^2 + \mathbb{1}\{y_i \neq y_j\}max(0, m - ||f_i - f_j||_2)^2 Lcont(xi,xj;f)=1{yi=yj}∣∣fifj22+1{yi=yj}max(0,m∣∣fifj2)2

L t r i ( x , x + , x − ; f ) = m a x ( 0 , ∣ ∣ f − f + ∣ ∣ 2 2 − ∣ ∣ f − f − ∣ ∣ 2 2 + m ) \mathcal{L}_{tri}(x, x^+, x^-; f) = max(0, ||f - f^+||_2^2 - ||f - f^-||_2^2 + m) Ltri(x,x+,x;f)=max(0,∣∣ff+22∣∣ff22+m)

其中 L c o n t \mathcal{L}_{cont} Lcont为对比损失(现在火起来的对比学习), L t r i \mathcal{L}_{tri} Ltri为三元组损失, f f f表示样本的嵌入。在对比损失中,要求来自同类别的样本距离近,不同类别的样本距离远;三元组损失中要求正( x + x^+ x+)、负( x − x^- x)样本到锚点( x x x,如搜图场景中的查询图)的距离要大于一定的阈值。

度量学习有一些现在很常见的应用,例如人脸识别、搜图等。度量学习的样本中通常只有一个负样本,容易导致收敛速度慢和局部最优的问题。难负样本挖掘(提一嘴:随着更多的实践,愈发觉得数据质量的重要性,如何构造好的数据集是一个值得研究的问题)能够减轻这些问题,但是如何找到难负样本本身就是一个难题。

与常见的三元组损失(triplet loss)中一个锚样本、一个正样本和一个负样本不一样,论文提出了一个 ( N + 1 ) (N+1) (N+1)元组的损失,来使一个正样本与 N − 1 N-1 N1个负样本区分开来。

Deep Metric Learning with Multiple Negative Examples

在三元组损失中,如果要使得损失尽可能低,显然有这么几种情况:

  • 缩短正样本与锚样本的距离;
  • 增大负样本与锚样本的距离;
  • 以上二者的结合。

从三元组损失的计算方式上也可以看出,再一次更新中只会比较锚样本与一个负样本,忽略了其他类别的负样本。这就导致:每次只能使锚样本远离一种负类,或许又被推到其他负类那里去了。最终学习到的嵌入可能会出现这样的情况:锚样本离训练数据中出现较多的负类远,而离某些负类又很近

当然,我们可以为锚样本配很多个三元组,囊括不同类别的负样本,这样在多轮、充足的训练后嵌入能够具有理想的性质。这样做就面临了不稳定以及收敛速度慢的问题。因此,文中就提出了 N + 1 N+1 N+1元组的损失,二者的区别如下图所示:
Triplet loss and (N+1)-tuplet loss

Deep metric learning with (left) triplet loss and (right) (N+1)-tuplet loss.

上图中红色的圆表示负样本,蓝色的表示锚样本和正样本。从左侧可以看出, N + 1 N+1 N+1元组损失的一个很简单的出发点:既然一个负类的样本不够,那就每个负类都拿一个样本出来,组成一个 N + 1 N+1 N+1的元组。但是在类别很多的场景(比如人脸识别),计算的复杂度过高。文章的重点就在于如何设计这样一个计算上可行的损失函数。

下图是三元组损失(a)、 ( N + 1 ) (N+1) (N+1)元组损失(b)及其改进后的损失©的一个对比。 N N N-pair-mc loss(multi-class N-pair loss)损失就是文章最后提出的损失。

N-pair-mc loss
Triplet loss, (N+1)-tuplet loss, and multi-class N-pair loss with training batch construction.

( N + 1 ) (N+1) (N+1)元组损失可以定义如下:
L ( { x , x + , { x i } i = 1 N − 1 } ; f ) = l o g ( 1 + ∑ i = 1 N − 1 e x p ( f T f i − f T f + ) ) \mathcal{L}(\{x, x^+, \{x_i\}_{i=1}^{N-1}\}; f) = log(1 + \sum_{i=1}^{N-1} exp(f^T f_i - f^T f^+)) L({x,x+,{xi}i=1N1};f)=log(1+i=1N1exp(fTfifTf+))
N N N等于2的时候该损失是与三元组损失等价的。提一嘴,这个形式和softplus的形式是一样的:
s o f t p l u s ( x ) = l o g ( 1 + e x p ( x ) ) softplus(x) = log(1 + exp(x)) softplus(x)=log(1+exp(x))
( N + 1 ) (N+1) (N+1)元组的损失可以写为如下形式:
l o g ( 1 + ∑ i = 1 N − 1 e x p ( f T f i − f T f + ) ) = − l o g e x p ( f T f + ) e x p ( f T f + ) + ∑ i = 1 N − 1 e x p ( f T f i − f T f + ) ) log(1 + \sum_{i=1}^{N-1} exp(f^T f_i - f^T f^+)) = - log \frac{exp(f^T f^+)} {exp(f^T f^+) + \sum_{i=1}^{N-1} exp(f^T f_i - f^T f^+))} log(1+i=1N1exp(fTfifTf+))=logexp(fTf+)+i=1N1exp(fTfifTf+))exp(fTf+)
这样一看是不是就更顺眼了,这不就是多分类里的softmax loss嘛。

N-pair loss for efficient deep metric learning

论文提出了一种高效的批构造方法,以降低额外的计算开销。方法的名字叫multi-class N N N-pair loss( N N N-pair-mc),其构造方式如上图 ( c )所示。来个说文解字,道一道作者的解决方法。方法名中有个N-pair,就从这入手。假若我们有 N N N个pair:
{ ( x 1 , x 1 + ) , ⋯ , ( x N , x N + } , y i ≠ y j , ∀ i ≠ j \{(x_1, x_1^+), \cdots, (x_N, x_N^+\},\ y_i \neq y_j, \forall i \neq j {(x1,x1+),,(xN,xN+}, yi=yj,i=j
每个pair的样本来自不同的类别,在这 N N N个pair的基础上构建 N N N个元组 { S i } i = 1 N \{S_i\}_{i=1}^N {Si}i=1N,其中:
S i = { x i , x 1 + , x 2 + , ⋯ , x N + } S_i = \{x_i, x_1^+, x_2^+, \cdots, x_N^+\} Si={xi,x1+,x2+,,xN+}
其中 x i x_i xi就是锚样本。显然, S i S_i Si就是一个包含了一个 i i i类别正样本, N − 1 N-1 N1个其他类别负样本的 N + 1 N+1 N+1元组了。因此,对于一个由 N N N个查询组成的batch,只需要准备 2 N 2 N 2N个样本,即 N N N个锚样本和 N N N个对应类别的正样本,每个batch只需要** 2 N 2 N 2N次前向计算**样本的嵌入就可以了。而在三元组损失和 N + 1 N+1 N+1元组损失中分别是 3 N 3 N 3N ( N + 1 ) N (N+1) N (N+1)N。因此,对于 N N N个查询组成的batch,其损失可以如下计算:
L N − p a i r − m c ( { ( x i , x i + } i = 1 N ; f ) = 1 N ∑ i = 1 N l o g ( 1 + ∑ j ≠ i e x p ( f i T f j + − f i T f i + ) ) \mathcal{L}_{N-pair-mc}(\{(x_i, x_i^+\}_{i=1}^N ; f) = \frac{1} {N} \sum_{i=1}^N log (1 + \sum_{j \neq i} exp(f_i^T f_j^+ - f_i^T f_i^+)) LNpairmc({(xi,xi+}i=1N;f)=N1i=1Nlog(1+j=iexp(fiTfj+fiTfi+))
以上就是论文的主要内容了,当然论文中还提到了负类别挖掘,这个就暂且不提了。

总结

简言之,这篇论文将度量学习中常见的三元组损失中只有一个负样本扩展到每个样本中包含 N − 1 N-1 N1个负样本,并且为了计算的效率提出了 N N N-pair的batch构造方法以降低计算量。其实,如果在三元组损失的batch中精心设计各种类别样本的配比,比如每个batch只训练一个类别,是否也能达到类似的效果呢?

参考


  1. Rethinking the Role of Pre-ranking in Large-scale E-Commerce, KDD 2023. ↩︎

  2. 漫谈-Distance Metric Learning那些事儿:https://zhuanlan.zhihu.com/p/458114525. ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/29799.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

时间复杂度空间复杂度相关练习题

1.消失的数字 【题目】:题目链接 思路1:排序——》qsort快排——》时间复杂度O(n*log2n) 不符合要求 思路2:(0123...n)-(a[0]a[1][2]...a[n-2]) ——》 时间复杂度O(N)空间复杂度…

spring cloud智慧工地源码(项目端+监管端+数据大屏+APP)

spring cloud智慧工地源码(项目端监管端数据大屏APP) 系统功能介绍 【智慧工地PC项目端功能总览】 一.项目人员管理 包括:信息管理、信息采集、证件管理、考勤管理、考勤明细、工资管理、现场统计、WIFI教育、工种管理、分包商管理、班组管…

STM32 HAL 驱动PM2.5传感器(GP2Y10AU气体检测模块)

目录 1、简介 2、CubeMX初始化配置 2.1 基础配置 2.1.1 SYS配置 2.1.2 RCC配置 2.2 ADC外设配置 2.3 串口外设配置 2.4 项目生成 3、KEIL端程序整合 3.1 串口重映射 3.2 ADC数据采集 3.3 主函数代 3.4 效果展示 1、简介 本文通过STM32F103C8T6单片机通过HAL库方式对G…

Qt 使用QLabel的派生类实现QLabel的双击响应

1 介绍 在QLabel中没有双击等事件响应,需要构建其派生类,自定义信号(signals)、重载事件函数(event),最后在Qwidget中使用connect链接即可,进而实现响应功能。 对于其余没有需求事件响应的QObject同样适用。 此外,该功…

利用appium抓取app中的信息

一、appium简介 二、appium环境安装 三、联调测试环境 四、利用appium自动控制移动设备并提取数据

年轻代频繁GC ParNew导致http变慢

背景介绍 某日下午大约四点多,接到合作方消息,线上环境,我这边维护的某http服务突然大量超时(对方超时时间设置为300ms),我迅速到鹰眼平台开启采样,发现该服务平均QPS到了120左右,平…

希尔排序——C语言andPython

前言 步骤 代码 C语言 Python 总结 前言 希尔排序(Shell Sort)是一种改进的插入排序算法,它通过将数组分成多个子序列进行排序,逐步减小子序列的长度,最终完成整个数组的排序。希尔排序的核心思想是通过排序较远距…

SQL server 异地备份数据库

异地备份数据库 1.备份服务器中设置共享文件夹 2.源服务器数据库中添加异地备份代理作业 EXEC sp_configure show advanced options, 1;RECONFIGURE; EXEC sp_configure xp_cmdshell, 1;RECONFIGURE; declare machine nvarchar(50) 192.168.11.10 --服务器IP declare pa…

中科驭数亮相DPU峰会,分享HADOS软件生态实践和大数据计算方案,再获评“匠芯技术奖”

又是一年相逢时,8月4日,第三届DPU峰会在北京开幕,本届峰会由中国通信学会指导,江苏省未来网络创新研究院主办,SDNLAB社区承办,以“智驱创新芯动未来”为主题,沿袭技术创新、生态协同的共创效应&…

【打印整数二进制的奇数位和偶数位】

打印整数二进制的奇数位和偶数位 1.题目 获取一个整数二进制序列中所有的偶数位和奇数位,分别打印出二进制序列 2.题目分析 打印一个整数的二进制位中的偶数位和奇数位,可以对整数进行移位操作,再将移位的二进制位与1进行&操作。 按位&a…

【Azure】office365邮箱测试的邮箱账号因频繁连接邮箱服务器而被限制连接 引起邮箱显示异常

azure微软office365邮箱会对频繁连接自身邮箱服务器的IP地址进行,连接邮箱服务器IP限制,也就是黑名单,释放时间不确定,但至少一天及以上。 解决办法,换一个IP,或者新注册一个office365邮箱再重试。 以下是…

Java课题笔记~ AspectJ 对 AOP 的实现(掌握)

AspectJ 对 AOP 的实现(掌握) 对于 AOP 这种编程思想,很多框架都进行了实现。Spring 就是其中之一,可以完成面向切面编程。然而,AspectJ 也实现了 AOP 的功能,且其实现方式更为简捷,使用更为方便,而且还支…

JVM 类加载和垃圾回收

JVM 1. 类加载1.1 类加载过程1.2 双亲委派模型 2. 垃圾回收机制2.1 死亡对象的判断算法2.2 垃圾回收算法 1. 类加载 1.1 类加载过程 对应一个类来说, 它的生命周期是这样的: 其中前 5 步是固定的顺序并且也是类加载的过程,其中中间的 3 步我们都属于连接&#xf…

用node.js搭建一个视频推流服务

由于业务中有不少视频使用的场景,今天来说说如何使用node完成一个视频推流服务。 先看看效果: 这里的播放的视频是一个多个Partial Content组合起来的,每个Partial Content大小是1M。 一,项目搭建 (1)初…

macOS下Django环境搭建-docker运行Django

1. macOS升级pip /Library/Developer/CommandLineTools/usr/bin/python3 -m pip install --upgrade pip 2. 卸载Python3.9.5版本 $ sudo rm -rf /usr/local/bin/python3 $ sudo rm -rf /usr/local/bin/pip3 $ sudo rm -rf /Library/Frameworks/Python.framework 3. 安装P…

微服务——ES实现自动补全

效果展示 在搜索框根据拼音首字母进行提示 拼音分词器 和IK中文分词器一样的用法,按照下面的顺序执行。 # 进入容器内部 docker exec -it elasticsearch /bin/bash# 在线下载并安装 ./bin/elasticsearch-plugin install https://github.com/medcl/elasticsearch…

【二分】CF1623 C

Problem - 1623C - Codeforces 题意: 思路: 肯定是二分,我们去二分最小值,然后check的时候最小值要大于mid check的时候要让最小值尽可能大 注意到我们不需要去管最大值,只需要最小值尽可能大就好了,因…

dirsearch_暴力扫描网页结构

python3 dirsearch 暴力扫描网页结构(包括网页中的目录和文件) 下载地址:https://gitee.com/xiaozhu2022/dirsearch/repository/archive/master.zip 下载解压后,在dirsearch.py文件窗口,打开终端(任务栏…

SpringBoot案例-部门管理-查询

查看页面原型,明确需求需求 页面原型 需求分析 阅读接口文档 接口文档链接如下: https://onedrive.live.com/?cidC62793E731F0C1BE&idC62793E731F0C1BE%2148 思路分析 用户发送请求,交由对应的Controller类进行处理,Con…

通讯协议034——全网独有的OPC HDA知识一之聚合(三)时间加权平均

本文简单介绍OPC HDA规范的基本概念,更多通信资源请登录网信智汇(wangxinzhihui.com)。 本节旨在详细说明HDA聚合的要求和性能。其目的是使HDA聚合标准化,以便HDA客户端能够可靠地预测聚合计算的结果并理解其含义。如果用户需要聚合中的自定义功能&…