论文分享|NeurIPS2022‘华盛顿大学|俄罗斯套娃表示学习(OpenAI使用的文本表示学习技术)

论文题目:Matryoshka Representation Learning

来源:NeurIPS2022/华盛顿大学+谷歌

方向:表示学习

开源地址:https://github.com/RAIVNLab/MRL

摘要

学习表征对于现代机器学习很重要,广泛用于很多下游任务。大多数情况下,每个下游任务的计算资源和需求都是未知的,因此对于某个下游任务,一个固定长度的表征可能会过大或过小。这引出一个问题:我们能否设计出一个灵活的表示方式,以适应具有不同计算资源的多个下游任务?

本文提出了俄罗斯套娃表示学习Matryoshka Representation Learning (MRL),编码不同粒度的信息,并让一个编码能够适应不同计算资源的下游任务。MRL最小限度地修改了当前的表示学习流程,没有引入额外的推理和部署开销。MRL学习到了至少和单独训练的低维表征一样准确的从粗到细的表征。MRL具有以下出色性能:(a)在保证ImageNet-1K分类任务精确率的前提下,将表征大小缩小为原来的14倍。(b)将ImageNet-1K和4K上的大规模检索速度加快到原来的14倍。©将长尾少样本分类的准确率提升了2%。(d)提升的同时也能保证鲁棒性。最后本文发现MRL可以很容易地扩展到各种模态的模型,比如视觉模型ViT/ResNet,视觉语言模型ALIGN,语言模型BERT等。

介绍

深度表征进行部署时一般分为两步:(1)昂贵但是固定的前向推理计算表征 (2)在下游应用中使用表征使用成本一般需要使用编码维度数据量(N)标签空间(L)三方面进行计算。对于大规模数据的网络应用,使用成本一般会超过计算成本。固定的表征大小使得不同任务都需要使用高维编码向量,尽管许多应用的资源和精度需求不同。

人类对于自然世界的理解有很自然的从粗到细的粒度,然而由于梯度训练的归纳偏置,深度学习模型倾向于向整个表征向量扩散信息。之前的方法一般使用三种方式来得到有弹性的表示模型:训练多个低维度模型(ResNet);联合优化不同容量的子网络[1,2];向量压缩(SVD)。但每种方法都因为训练和维护费用**,大量昂贵的对所有的数据前向推理,存储和内存成本昂贵的实时特征选择准确率的显著下降**很难满足大规模部署。通过编码不同粗细粒度的与独立训练准确率相当的表征,本文以最小的额外开销学习一种可以在推理过程中无需额外成本自适应部署的表征。

MRL通过嵌套的形式在同一个高维向量中对 O ( log ⁡ ( d ) ) O(\log(d)) O(log(d)) 维度的低维向量做显式优化。下图展示了核心技术,随着维度的增加,表征的粒度也越来越细。

img

本文主要关注最重要的两项任务,大规模分类和检索。对于分类任务,使用MRL训练的模型中的可变大小表征进行自适应级联(小于等于某个阈值则使用更高维度的向量进行分类**),显著降低了实现特定精度所需的编码的平均维数。对于检索任务,先使用最前面的维度进行粗排**,再使用更多的维度对粗排结果进行精排。同样还可以用在长尾持续学习以及判断样本间分类的困难程度和信息瓶颈

方法

MRL的目的是学习许多个小于等于 ⌊ log ⁡ d ⌋ \lfloor \log d\rfloor logd 的前 M M M 维表征,即总维度的前 1 2 k \frac 1 {2^k} 2k1 维。以分类任务为例,其实就是将每个维度得到的表征通过各自的分类头计算多分类交叉熵损失再相加。对于ImageNet来说,本文选取了 M = { 8 , 16 , . . . , 1024 , 2048 } M=\{8,16,...,1024,2048\} M={8,16,...,1024,2048} c m c_m cm 是每个维数的重要性,本文都设置为1,即看成同等重要。尽管仅对这些前 ⌊ log ⁡ d ⌋ \lfloor \log d\rfloor logd 维度向量进行了优化,对于这些维度之间的维度也能取得比较好的效果。

img

一种比较高效的做法是将每个投射头 W ( m ) ∈ R L × m W^{(m)}\in \mathbb R^{L\times m} W(m)RL×m 看成是一个大投射头的 W ∈ R L × d W\in \mathbb R^{L\times d} WRL×d 的一部分,即 W ( m ) = W 1 : m W^{(m)}=W_{1:m} W(m)=W1:m ,这种做法在大输出空间时尤其重要,本文称之为Efficient Matryoshka Representation Learning (MRL–E)

对于视觉模型、视觉语言模型中的对比学习,以及语言模型中的掩蔽语言模型,下一词预测等任务都可以转换为类似分类的方式进行处理。(毕竟万物皆是分类

应用

本文将MRL/MRL-E模型与单独训练的低维表征(FF),SVD分解,子网络[2]方法进行了比较

首先是分类任务。对于在ImageNet上训练的模型,线性分类准确率基本和FF保持一致,1-NN准确率甚至在低维时高于FF

img

对于大规模数据集上训练的模型也取得了很好的精度与速度间的平衡

img

对于适应性分类,期望的表征大小相比FF减小了14倍。

img

图像检索的结果也超越了baseline,最高超过了FF 3%。适应性图像检索也达到了效率和精度的权衡,16维度做粗排,2048维度做精排的准确率已经和直接使用2048维度做排序的精度还高,但计算量大幅减小。值得一提的是本文提出了一个漏斗检索方法,即使用逐渐增大的维度16-32-64-128-256-2048 对前200-100-50-25-10个样本的逐步重排,这种方法可以省去调参,应用比较方便。

img

分析与消融

鲁棒性:在分类和检索任务都具有一定鲁棒性。鲁棒性研究最好选用最近邻分类或检索而不是线性探针

少样本和长尾学习:长尾分布中的新类别可以有2%的提升。且高维度的表征对于比较难的类别分类更准确

不同维度间的分歧:某些类别和实例低维效果可能超过高维,MRL可能可以被用作分析信息瓶颈的工具

超类准确率:维度越高,超类准确率越高

img

结论

可能的优化方向:(1)优化嵌套损失的权重 (2)在不同的保真度下使用不同的损失,旨在解决自适应部署的特定需求,例如,8维的高召回率和2048维的鲁棒性。(3)在矩阵数据表示之上学习搜索数据结构,如可微k-d树,以实现数据集和表征感知检索。(4)最后,结合多目标MRL与端到端可学习搜索数据结构的联合优化,实现数据驱动的面向大规模搜索应用的自适应大规模检索

笔者评价:这篇文章主要是对检索过程中KNN相似度匹配部分的速度优化,且能够保持不同维度表征的性能和一致性,可以让大公司为不同用户(个人,开发者)的各种编码需求提供一个相对通用的服务


大家好,我是NLP研究者BrownSearch,如果你觉得本文对你有帮助的话,不妨点赞收藏支持我的创作,您的正反馈是我持续更新的动力!如果想了解更多LLM/检索的知识,记得关注我!

引用

[1]Cai H, Gan C, Wang T, et al. Once-for-all: Train one network and specialize it for efficient deployment[J]. arXiv preprint arXiv:1908.09791, 2019.

[2]Yu J, Yang L, Xu N, et al. Slimmable neural networks[J]. arXiv preprint arXiv:1812.08928, 2018.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/45872.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

java配置nginx网络安全,防止国外ip访问,自动添加黑名单,需手动重新加载nginx

通过访问日志自动添加国外ip黑名单 创建一个类,自己添加一个main启动类即可测试 import lombok.AccessLevel; import lombok.NoArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.json.JSONArray; import org.json.JSONObject; import org.sp…

【学习笔记】Redis学习笔记——第10章 RDB持久化

第10章 RDB持久化 RDB是用来做持久化的二进制压缩文件 10.1 RDB文件的创建与载入 1>SAVE命令阻塞主线程创建。 2>EGSAVE开子线程创建。 3>优先使用AOF进行初始化数据库,否则,使用RDB文件初始化,因为AOF文件的写入更加频繁&#x…

面试经验之谈

优质博文:IT-BLOG-CN ​通常面试官会把每一轮面试分为三个环节:① 行为面试 ② 技术面试 ③ 应聘者提问 行为面试环节 面试开始的5~10分钟通常是行为面试的时间,面试官会参照简历和你的自我介绍了解应聘者的过往经验和项目经历。由于面试官…

C++catch (...)陈述

catch (...)陈述 例外处理可以有多个catch&#xff0c;如果catch后的小括弧里面放...&#xff0c;就表示不限型态种类的任何例外。 举例如下 #include <iostream>int main() {int i -1;try {if (i > 0) {throw 0;}throw 2.0;}catch (const int e) {std::cout <…

nodejs模板引擎(一)

在 Node.js 中使用模板引擎可以让您更轻松地生成动态 HTML 页面&#xff0c;通过将静态模板与动态数据结合&#xff0c;您可以创建可维护且易于扩展的 Web 应用程序。以下是一个使用 Express 框架和 EJS 模板引擎的基本示例&#xff1a; 安装必要的依赖&#xff1a; 首先&#…

分享浏览器被hao123网页劫持,去除劫持的方式

昨天看python相关的自动化工作代码时&#xff0c;发现谷歌浏览器被hao123劫持了&#xff0c;把那些程序删了也不管用 方法1&#xff1a;删除hao123注册表&#xff0c;这个方式不太好用&#xff0c;会找不到注册表 方法2&#xff1a;看浏览器快捷方式的属性页面&#xff0c;一…

【C++】入门基础(命名空间、缺省参数、函数重载)

目录 一.命名空间&#xff1a;namespace 1.namespace的价值 2.namespace的定义 3.namespace的使用方法 3.1 域解析运算符:: 3.2 using展开 3.3 using域解析运算符 二.输入输出 三.缺省参数 四.函数重载 1.参数类型不同 2.参数个数不同 3.参数顺序不同 一.命名空间&…

APP专项测试之网络测试

背景 当前app网络环境比较复杂&#xff0c;越来越多的公共wifi&#xff0c;网络制式有2G、3G、4G网络&#xff0c;会对用户使用app造成一定影响&#xff1b;当前app使用场景多变&#xff0c;如进地铁、上公交、进电梯等&#xff0c;使得弱网测试显得尤为重要&#xff1b; 网络正…

链路追踪系列-02.演示zipkin

当本机启动docker es zipkinServer之后&#xff1a; 启动3个项目&#xff1a;先eureka-server&#xff0c;再 PaymentMain8001,… 浏览器打开&#xff1a;http://localhost:9001/consumer/payment/zipkin consumer代码 &#xff1a; provider: 此时查询es:

uboot如何选择启动设备

cpu选择启动设备有两种方式 1、通过bootpin选择某个设备 比如从SD卡启动、EMMC、USB启动。 2、通过bootpin选择某个顺序 比如&#xff1a; SD、SPI、NANDSPI、NAND、SD

3-2 多层感知机的从零开始实现

import torch from torch import nn from d2l import torch as d2lbatch_size 256 # 批量大小为256 train_iter, test_iter d2l.load_data_fashion_mnist(batch_size) # load进来训练集和测试集初始化模型参数 回想一下&#xff0c;Fashion-MNIST中的每个图像由 28 28 784…

学习C++,应该循序渐进的看哪些书?

学习C是一个循序渐进的过程&#xff0c;需要根据自己的基础和目标来选择合适的书籍。以下是一个推荐的学习路径&#xff0c;包含了从入门到进阶的书籍&#xff1a; 1. 入门阶段 《C Primer Plus 第6版 中文版》 推荐理由&#xff1a;这本书同样适合C零基础的学习者&#xff0…

python运行环境在新旧电脑间迁移

目录 方法1. 直接复制虚拟环境文件夹&#xff1a;方法2. 导出和导入依赖项&#xff1a;方法3. 用 Docker 镜像&#xff1a; 当您需要在不同电脑之间迁移 Python 运行环境时&#xff0c;可以采用以下方法之一&#xff1a; 方法1. 直接复制虚拟环境文件夹&#xff1a; 将整个虚…

[CISCN2018]2ex

啊!好恶心的mips寄存器 好多IDA都查不到,这寄存器~! fuck! 但是这种寄存器一般的题都不难 这道题就是 我用平常的方法,没找到 左边函数一个一个点 看见这里0X3F base64 密文呢? 我giao 外面的txt文件里面 脚本 import base64 import string# 定义你的自定义字符集 st…

PHP语言教程与实战案例

PHP是一种广泛使用的开源脚本语言&#xff0c;尤其适用于Web开发并可嵌入HTML中。它的语法吸收了C语言、Java和Perl的特点&#xff0c;易于学习&#xff0c;功能强大。本文将带领你从基础语法入手&#xff0c;通过一系列实用案例&#xff0c;逐步掌握PHP的核心概念和技巧。 项…

使用JDBC实现事务管理与隔离级别设置

使用JDBC实现事务管理与隔离级别设置 在Java开发中&#xff0c;JDBC&#xff08;Java Database Connectivity&#xff09;是常用的数据库连接方式。在处理数据库操作时&#xff0c;事务管理和隔离级别设置是保证数据一致性和可靠性的关键。本篇博客将通过示例代码&#xff0c;…

并发编程-11线程池详解

一 线程池基础和使用 1.1 什么是线程池 “线程池”&#xff0c;顾名思义就是一个线程缓存&#xff0c;线程是稀缺资源&#xff0c;如果被无限制的创建&#xff0c;不仅会消耗系统资源&#xff0c;还会降低系统的稳定性&#xff0c;因此Java中提供线程池对线程进行统一分配、调优…

聊点基础---Java和.NET开发技术异同全方位分析

1. C#语言基础 1.1 C#语法概览 欢迎来到C#的世界&#xff01;对于刚从Java转过来的开发者来说&#xff0c;你会发现C#和Java有很多相似之处&#xff0c;但C#也有其独特的魅力和强大之处。让我们一起来探索C#的基本语法&#xff0c;并比较一下与Java的异同。 程序结构 C#程序…

美团收银Android一面凉经(2024)

美团收银Android一面凉经(2024) 笔者作为一名双非二本毕业7年老Android, 最近面试了不少公司, 目前已告一段落, 整理一下各家的面试问题, 打算陆续发布出来, 供有缘人参考。今天给大家带来的是《美团收银Android一面凉经(2024)》。 应聘岗位: 美团餐饮PaaS平台Android开发工程师…

pnpm 如何安装指定版本

要安装特定版本的pnpm&#xff0c;可以使用npm命令来全局安装特定版本的pnpm&#xff0c;例如&#xff1a; npm install -g pnpm2.0.0在上面的命令中&#xff0c;使用了2.0.0来指定安装2.0.0版本的pnpm。您可以将2.0.0替换为您需要安装的版本号。 如果您使用的是yarn&#xf…