Embedding模型提升效果的方法之一:Whitening和pooling

0. 前言

Embedding模型的主流框架基本上分为三类——基于bert结构的,基于GPT结构的和基于T5结构的,当然这些结构都是Transformer的变形。对于Embedding模型,使用bert结构目前看是最好的。有篇论文论文对基于bert的Embedding模型和基于GPT的Embedding模型做过比较,暂时找不到了,后续找到会附上。另外本人也对基于bert的embedding模型和基于GPT的embedding模型做了比较试验,结果表明基于bert的embedding模型完胜。

要让 embedding 模型性能提升,除了在模型结构和训练数据上做文章之外,还可以使用Whitening方法,特殊的pooling方法和Simcse方法来提升效果。本文先介绍Whitening方法和特殊的pooling方法,下一篇介绍Simcse方法。

1. Whitening

1.1 为什么可以用白化方法提升模型效果

白化操作不仅可以提升模型效果,还可以对句子向量进行降维

白化(whitening)方法之所以能够在embedding模型上产生正向效果是因为我们通常会用两个句子向量的余弦相似度来衡量这两个句子的相似性。但是由于类似由BERT和GPT这样的预训练语言模型得到的句子向量往往是具备各向异性的,表现状态就是向量会不均匀分布,且充斥在一个狭窄的锥形空间下。但是,具备各向异性的向量之间直接计算余弦相似度会出现一些偏差,这就导致 embedding 模型的表现变差。所以,我们只要将各向异性的句子向量转化为一个各向同性的句子向量就可以提升 embedding 模型的效果。此时就用到了白化操作。

在这里插入图片描述

1.2 余弦相似度和各向异性

  • consin
    假设 x x x y y y 两个向量,维度都是 R d R^d Rd。那么,利用cosine的计算方法,他们的相似度为:
    在这里插入图片描述
    上述方程 (1) 仅在坐标基(二维向量)为标准正交基时才成立。余弦角度具有明显的几何意义,但方程(1) 是基于运算的,它取决于所选的坐标基。因此,内积的坐标公式随着坐标基的变化而变化,余弦值的坐标公式也会随之变化。

    (Li et al., 2020) 验证了来自 BERT (Devlin et al., 2019) 的句子嵌入虽然没有得到适当的利用,但已经包含了足够的语义。在这种情况下,如果在操作方程(1) 计算语义相似度的余弦值时句子向量表现不佳,原因可能是句子向量所属的坐标基不是标准正交基。从统计学的角度,我们可以推断,当我们为一组向量选择基时,应该保证每个基向量是独立且一致的。如果这组基是标准正交基,则相应的向量组应该具备各向同性。

    综上所述,上述启发式假设详尽地表明:如果一组向量满足各向同性,我们可以假设从标准正交基推导出来,这也表明我们可以通过方程 (1) 计算余弦相似度。否则,如果它是各向异性的,我们需要对原始向量进行变换以某种方式嵌入句子以强制它是各向同性的,然后使用等式 (1) 计算余弦相似度。

  • 各向异性

    定义:各向异性是指在不同的方向上物理性质(表达含义)不同,各向同性是指不同的方向上物理性质相同。

    BERT和GPT的各项异性是怎么产生的:假设一个句子的向量为 { x i } i = 1 N \{x_i\}_{i=1}^N {xi}i=1N,某2个字的向量分别为 x j = [ x j 1 , x j 2 , x j 3 , … , x j n ] x_j=[x_j^1,x_j^2,x_j^3,\dots,x_j^n] xj=[xj1,xj2,xj3,,xjn] x h = [ x h 1 , x h 2 , x h 3 , … , x h n ] x_h=[x_h^1,x_h^2,x_h^3,\dots,x_h^n] xh=[xh1,xh2,xh3,,xhn],其中 可以理解为参数句子长度sequence_length。由于 BERT 的Token Embedding与Position Embedding的设计结构,导致了生成的句子向量不仅仅包含某单一token的MASK信息,同时还具备了这个token在不同位置所代表的Position信息,这直接为各向异性创作了条件。简单理解,假设 x j 1 x_j^1 xj1 x j 2 x_j^2 xj2 分别代表这个 token 的 mask 信息和 position 信息(实际不是这样简单的),这两个维度就是不同方向有不同的性质。再举一个反例,one-hot词向量就不具备各向异性,而是具备了各向同性的特点。

1.3 白化计算

经过上述的分析,想要让基于 bert 的 embedding 模型生成的句子向量用余弦相似度正确表示两个句子的相似性就得将句子向量转换到标准正交基下面

  • 标准正交基
    我们知道,对于两个向量 A A A B B B来说,如果 A ⋅ B = 0 A \cdot B=0 AB=0,那么,我们称这两个向量正交(零向量与任何向量正交)。 我们知道,在n维的欧式空间中,由n个向量组成的正交向量组称为正交基;由单位向量组成的正交基称为标准正交基。

已知:
在这里插入图片描述

在这里插入图片描述
A A A B B B 都不是0向量的时候,要让 A ⋅ B = 0 A \cdot B=0 AB=0,则 c o s ( A , B ) = 0 cos(A,B)=0 cos(A,B)=0,也就是

∑ i = 1 d a i b i = 0 \sum_{i=1}^da_ib_i=0 i=1daibi=0,而 ∑ i = 1 d a i b i \sum_{i=1}^da_ib_i i=1daibi 表示向量 A A A B B B 的协方差。

此时问题转换为:

已知原句子向量矩阵为 X X X,协方差矩阵为 C C C,目标是将 X X X 转换为协方差为0的向量矩阵 Y Y Y Y Y Y 的协方差矩阵为 D D D,求转换矩阵为 P P P

其中 X X X 可表示为:
在这里插入图片描述
其中 n n n 为sequence length, a a a b b b 代表不同的维度。

然后根据协方差的计算公式,可得:

在这里插入图片描述
我们可以看到这个矩阵对角线上的分别是两个变量的方差,而其它元素是 a 和 b 的协方差。两者被统一到了一个矩阵里。 我们很容易被推广到一般情况:

设我们有 m 个 n 维数据记录,将其排列成矩阵 X m , n X_{m,n} Xm,n ,设 C = 1 m X X T C = \frac{1}{m}XX^T C=m1XXT,则 C C C 是一个对称矩阵,其对角线分别对应各个变量的方差,而第 i 行 j 列和 j 行 i 列元素相同,表示 i 和 j 两个变量的协方差。

由此可知,我们需要将除对角线外的其它元素化为 0,并且在对角线上将元素按大小从上到下排列(变量方差尽可能大),这里就是将协方差转为一个单位矩阵,也就是矩阵的对角化。

推导一下 D D D C C C 的关系:
在这里插入图片描述
现在的目标变成了让 D D D 变成一个对角矩阵,这样的话 Y Y Y 的协方差就都为0了。并且对角元素按从大到小依次排列,那么 P P P 的前 K K K 行就是要寻找的基,用 P P P 的前 K K K 行组成的矩阵乘以 X X X 就使得 X X X N N N 维降到了 K K K 维并满足上述优化条件。

回到 embedding 模型本身的输出,根据上述的协方差矩阵,假设有一组句子向量,也可以写为行向量 { x } i = 1 N \{x\}_{i=1}^N {x}i=1N,在对它做线性变换之后生成一个均值为0、协方差矩阵为单位阵的目标向量 { x ~ } i = 1 N \{\tilde x\}_{i=1}^N {x~}i=1N

在这里插入图片描述
其中:
在这里插入图片描述

下面求解 W W W,将原始数据的协方差记为:
在这里插入图片描述
由上面推导出的 D = P C P T D=PCP^T D=PCPT,可以得到 ∑ ~ = W ∑ W T \tilde \sum=W\sum W^T ~=WWT,而我们的目标是 W ∑ W T = I W\sum W^T=I WWT=I,于是可得:
在这里插入图片描述
我们知道 ∑ \sum 是一个正定对称矩阵,正定对称矩阵都具有如下形式的SVD分解:
在这里插入图片描述
其中 U U U 是一个正交矩阵, ∧ \land 是一个正对角矩阵,则可以让 W − 1 = ∧ U T W^{-1}=\sqrt{\land}U^T W1= UT,则可以得到:

在这里插入图片描述
由于 ∧ \land U U U 均可以由 ∑ \sum 求得,所以 W W W 就被求出来了。

1.4 代码实现

def compute_kernel_bias(vecs, n_components=None):"""计算kernel和biasvecs.shape = [num_samples, embedding_size],最后的变换:y = (x + bias).dot(kernel):return kernel, bias"""if isinstance(vecs, list):vecs = np.concatenate(vecs, axis=0)mu = vecs.mean(axis=0, keepdims=True)cov = np.cov(vecs.T)u, s, vh = np.linalg.svd(cov)W = np.dot(u, np.diag(1 / np.sqrt(s)))print(W)print(-mu)if n_components is not None:return W[:, :n_components], -muelse:return W, -mu

【论文解读】BERT Whitening
whitening计算详解

2. pooling

模型生成文本的 embedding 时会用 pooling 的方法进行降维,但是在降维操作的时候常见的mean pooling 和 last token pooling都有一定的局限性,比如都会忽略位置信息,如果使用position weighted mean pooling将位置信息加进来会有更好的效果。

hidden_state的形状为(batch_size, sequence_length, hidden_size),mean pooling 就是在sequence_length上进行平均,生成形状为(batch_size, hidden_size) 的embedding 矩阵。

而last token pooling是直接取 “[CLS]” 字短的embedding 作为整个文本的embedding。

position weighted mean pooling是在mean pooling 里面加上了位置信息。

  • mean pooling
def mean_pooling(hidden_state: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor:if attention_mask is None:return torch.mean(hidden_state, dim=1)attention_mask = attention_mask.float()return torch.sum(hidden_state * attention_mask.unsqueeze(-1), dim=1) / torch.sum(attention_mask, dim=-1, keepdim=True)
  • position weighted mean pooling
def position_weighted_mean_pooling(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor:weights = (torch.arange(start=1, end=last_hidden_state.shape[1] + 1).unsqueeze(0).unsqueeze(-1).expand(last_hidden_state.size()).float().to(last_hidden_state.device))input_mask_expanded = (attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float())# Perform weighted mean pooling across seq_len: bs, seq_len, hidden_dim -> bs, hidden_dimsum_embeddings = torch.sum(last_hidden_state * input_mask_expanded * weights, dim=1)sum_mask = torch.sum(input_mask_expanded * weights, dim=1)embeddings = sum_embeddings / sum_maskreturn embeddings

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/772292.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

集合(下)Map集合的使用

文章目录 前言一、Map接口二、Map接口的实现类 1.HashMap类2.TreeMap类总结 前言 Map集合没有继承Collection接口,不能像List集合和Set集合那样直接使用Collection接口的方法。Map集合其自身通过以key到value的映射关系实现的集合,也有相应的许多方法。类…

CUMT linux操作系统课程设计 任务2

先说题目: 调试Linux内核的启动过程,并在Linux 0.11内核进入保护模式之前添加提示信息 //这里吐槽一下,学校发的文档让你用断点去查看运行根本无法操作,报错如下: 所以别管这个报错,先跟着我来 第一题,调试…

各城市宗族文化姓氏占比数据

各城市宗族文化姓氏占比数据 1、指标:省份代码、所属省份、城市代码、所属城市、第1大姓氏、第2大姓氏、第3大姓氏、宗族文化强度 2、方法说明: 根据2005年全国1%的人口调查数据计算。其中第1大姓氏第一大姓人口数/总人口数,宗族文化强度(…

一文700字从0到1教你实现Jmeter分布式压力测试!

之前写过用jmeter做接口测试的文章,本篇我们继续介绍下用jmeter做分布式压力测试的例子。 用jmeter做压力测试,如果只用一台机器,有鉴于线程数的限制和一台机器的性能,可能无法满足压力测试的实际需求,解决这个问题&a…

力扣--并查集1631.最小体力消耗路径

这题将图论和并查集联系起来。把数组每个位置看成图中的一个节点。 这段代码的主要思路是: 遍历地图中的每个节点,将每个节点与其相邻的下方节点和右方节点之间的边加入到边集合中(因为从上到下和从下到上他们高度绝对值一样的,…

OpenHarmony之媒体组件模块简介

源码 本文基于OpenAtom OpenHarmony(以下简称“OpenHarmony”)3.2 Release源码foundation目录下的player_framework,在OpenHarmony 2.0 Release版本当中,这个模块的名字叫媒体组件模块,为了方便理解我们在本文中仍旧延…

VR全景展示:传统制造业如何保持竞争优势?

在结束不久的两会上,数字化经济和创新技术再度成为了热门话题。我国制造产业链完备,但是目前依旧面临着市场需求不足、成本传导压力加大等因素影响,那么传统制造业该如何保持竞争优势呢? 在制造行业中,VR全景展示的应用…

Query2doc——Query改写

大模型LLM最近一年比较火,但是可能由于数据量较大,存在一些矛盾的数据或者质量差的数据,就会导致大模型存在幻视情况,即存在严重不符合事实的情况。随着之而来,RAG(Retrieval Augmented Generation&#xf…

计算机组成原理(超详解!!) 第三节 运算器(浮点加减乘)

1.浮点加法、减法运算 操作过程 1.操作数检查 如果能够判断有一个操作数为0,则没必要再进行后续一系列操作,以节省运算时间。 2.完成浮点加减运算的操作 (1) 比较阶码大小并完成对阶 使二数阶码相同(即小数点位置对齐)…

windows@浏览器主页被篡改劫持@360篡改主页@广告和弹窗设置@极速版

文章目录 360篡改浏览器主页方法1锁定浏览器主页 方法2注册表修改 360广告和弹窗360极速版 小结 360篡改浏览器主页 如果您使用360,且不想卸载它,那么当你启动360后,它可能会篡改你的浏览器(比如edge)的主页start page为360早期可能是通过修改快捷方式的target等属性,但是现在…

《剑指 Offer》专项突破版 - 面试题 93 : 最长斐波那契数列(C++ 实现)

题目链接:最长斐波那契数列 题目: 输入一个没有重复数字的单调递增的数组,数组中至少有 3 个数字,请问数组中最长的斐波那契数列的长度是多少?例如,如果输入的数组是 [1, 2, 3, 4, 5, 6, 7, 8]&#xff0…

C++模版(基础)

目录 C泛型编程思想 C模版 模版介绍 模版使用 函数模版 函数模版基础语法 函数模版原理 函数模版实例化 模版参数匹配规则 类模版 类模版基础语法 C泛型编程思想 泛型编程:编写与类型无关的通用代码,是代码复用的一种手段。 模板是泛型编程…

【前端Vue】Vue3+Pinia小兔鲜电商项目第3篇:静态结构搭建和分类实现,1. 整体结构创建【附代码文档】

Vue3ElementPlusPinia开发小兔鲜电商项目完整教程(附代码资料)主要内容讲述:认识Vue3,使用create-vue搭建Vue3项目1. Vue3组合式API体验,2. Vue3更多的优势,1. 认识create-vue,2. 使用create-vue创建项目,1. setup选项的写法和执行…

【数据结构与算法】java有向带权图最短路径算法-Dijkstra算法(通俗易懂)

目录 一、什么是Dijkstra算法二、算法基本步骤三、java代码四、拓展(无向图的Dijkstra算法) 一、什么是Dijkstra算法 Dijkstra算法的核心思想是通过逐步逼近的方式,找出从起点到图中其他所有节点的最短路径。算法的基本步骤如下:…

应用层协议 - HTTP

文章目录 目录 文章目录 前言 1 . 应用层概要 2. WWW 2.1 互联网的蓬勃发展 2.2 WWW基本概念 2.3 URI 3 . HTTP 3.1 工作过程 3.2 HTTP协议格式 3.3 HTTP请求 3.3.1 URL基本格式 3.3.2 认识方法 get方法 post方法 其他方法 3.3.2 认识请求报头 3.3.3 认识请…

MyBatis是纸老虎吗?(七)

在上篇文章中,我们对照手动编写jdbc的开发流程,对MyBatis进行了梳理。通过这次梳理我们发现了一些之前文章中从未见过的新知识,譬如BoundSql等。本节我想继续MyBatis这个主题,并探索一下MyBatis中的缓存机制。在正式开始梳理前&am…

如何解决kafka rebalance导致的暂时性不能消费数据问题

文章目录 背景思考答案排它故障转移共享 背景 之前在review同组其它业务的时候,发现竟然把kafka去掉了,问了下原因,有一个单独的服务,我们可以把它称为agent,就是这个服务是动态扩缩容的,会采集一些指标&a…

使用C++实现一个简单的日志功能

日志对于一些大一些的项目来说,可以在项目运行出现问题时更好的帮助 项目的维护人员快速的定位到问题出现的地方并且知道出现问题的原因, 并且日志也可以帮助程序员很好的进行项目的Debug,那么今天我就来实 现一个C编写的一个简单的日志功能。…

深度学习中常用计算距离的几种算法对比与python实现

前言 距离度量在许多机器学习算法中扮演着至关重要的角色,无论是监督学习还是无监督学习。选择适当的距离度量可以显著影响模型的性能。 在高维数据集中,欧几里得距离可能会受到所谓的“维度诅咒”的影响,因为随着维度的增加,数…

海外媒体软文发稿:谷歌关键词优化细分人群成功案例,突破海外市场!

海外媒体软文发稿:谷歌关键词优化细分人群成功案例,突破海外市场! 引言 在全球化的时代,海外市场对于企业的发展至关重要。而在海外市场中,互联网媒体的作用不可忽视。本篇教程将介绍如何通过谷歌关键词优化细分人群…