NLP文本匹配任务Text Matching [无监督训练]:SimCSE、ESimCSE、DiffCSE 项目实践

NLP文本匹配任务Text Matching [无监督训练]:SimCSE、ESimCSE、DiffCSE 项目实践

文本匹配多用于计算两个文本之间的相似度,该示例会基于 ESimCSE 实现一个无监督的文本匹配模型的训练流程。文本匹配多用于计算两段「自然文本」之间的「相似度」。

例如,在搜索引擎中,我们通常需要判断用户的搜索内容是否相似:

A:蛋黄吃多了有什么坏处    B:吃鸡蛋白过多有什么坏处  ->  不相似
A:蛋黄吃多了有什么坏处    B:蛋黄可以多吃吗         ->  相似
...

那最直觉的思路就是让人工去标注文本对,再喂给模型去学习,这种方法称为基于「监督学习」训练出的模型:

但是,如果我们今天没有这么多的标注数据,只有一大堆的「未标注」数据,我们还能训练一个匹配模型吗?这种不依赖于「人工标注数据」的方式,就叫做「无监督」(或自监督)学习方式。我们今天要讲的 SimCSE, 就是一种「无监督」训练模型。

SimCSE: Simple Contrastive Learning of Sentence Embeddings

1.SimCSE 是如何做到无监督的?

SimCSE 将对比学习(Contrastive Learning)的思想引入到文本匹配中。对比学习的核心思想就是:将相似的样本拉近,将不相似的样本推远

但现在问题是:我们没有标注数据,怎么知道哪些文本是相似的,哪些是不相似的呢?SimCSE 相出了一种很妙的办法,由于预训练模型在训练的时候通常都会使用 dropout 机制。这就意味着:即使是同一个样本过两次模型也会得到两个不同的 embedding。而因为同样的样本,那一定是相似的,模型输出的这两个 embedding 距离就应当尽可能的相近;反之,那些不同的输入样本过模型后得到的 embedding 就应当尽可能的被推远。

具体来讲,一个 batch 内每个句子会过 2 次模型,得到 2 * batch 个向量,将这些句子中通过同样句子得到的向量设置为正例,其他设置为负例。

假设 a1 和 a2 是由句子 a 过两次模型得到的结果,那么一个 batch 内的正负例构建如下所示:

a1a2b1b2c1c2
a1-10010000
a21-1000000
b100-100100
b2001-10000
c10000-1001
c200001-100

其中,对角线上的 - 100 表示自身和自身不做相似度比较。

2. SimCSE 的缺点?

从 SimCSE 的正例构建中我们可以看出来,所有的正例都是由「同一个句子」过了两次模型得到的。这就会造成一个问题:模型会更倾向于认为,长度相同的句子就代表一样的意思。由于数据样本是随机选取的,那么很有可能在一个 batch 内采样到的句子长度是不相同的。

为了解决这个问题,我们最终采取的实现方式为 ESimCSE

3. ESimCSE 解决模型对文本长度的敏感问题

ESimCSE 通过随机重复单词(Word Repetition)的方式来构建正例,巧妙的解决了句子长度敏感性的问题:

ESimCSE: Enhanced Sample Building Method for Contrastive Learning of Unsupervised Sentence Embedding

要想消除模型对句子长度的敏感,我们就需要在构建正例的时候让输入句子的长度发生改变,如下所示:

那么,改变句子长度通常有 3 种方法:随机删除、随机添加、同义词替换,但它们均存在句意变化的风险:

方法原句子变换后的句子句意是否改变
随机删除我 [不] 喜欢你我喜欢你
随机添加今天的饭好吃今天的饭 [不] 好吃
同义词替换小明长得像一只 [狼]小明长得像一只 [狗]

用语义变换后的句子去构建正例,模型效果自然会受到影响。

那如果我们随机重复一些单词呢?

方法原句子变换后的句子句意是否改变
随机重复单词今天天气很好今今天天气很好好
随机重复单词我喜欢你我我喜欢欢你

可以看到,通过随机重复单词,既能够改变句子长度,又不会轻易改变语义。

实现上,假设我们有一个 batch 的句子,我们先依次将每一个句子都进行随机单词重复(产生正例),如下:

origin ->     ['人和畜生的区别', '今天天气很好', '三星手机屏幕是不是最好的?']
repetition -> ['人人和畜生的的区别', '今今天天气很好好', '三星星手机屏屏幕是不是最最好好的?']

随后,我们将 origin 的 embedding(batch,768) 和 repetition 的 embedding(batch,768)做矩阵乘法,可以得到一个矩阵(batch,batch),矩阵对角线上就是正例,其余的均是负例:

句子 a句子 b句子 c
句子 a0.92480.23420.4242
句子 b0.31420.91230.1422
句子 c0.29030.18570.9983

矩阵中第(i,j)个元素代表 origin 列表中的第 i 个元素和 repetition 列表中第 j 个元素的相似度。

接下来就好构建训练标签了,因为 label 都在对角线上,所以第 n 行的 label 就是 n 。

labels = [i for i in range(len(origin))]     # labels = [0, 1, 2]

之后就用 CrossEntropyLoss 去计算并梯度回传就能开始训练啦。

def forward(self,query_input_ids: torch.tensor,query_token_type_ids: torch.tensor,doc_input_ids: torch.tensor,doc_token_type_ids: torch.tensor,device='cpu') -> torch.tensor:"""传入query/doc对,构建正/负例并计算contrastive loss。Args:query_input_ids (torch.LongTensor): (batch, seq_len)query_token_type_ids (torch.LongTensor): (batch, seq_len)doc_input_ids (torch.LongTensor): (batch, seq_len)doc_token_type_ids (torch.LongTensor): (batch, seq_len)device (str): 使用设备Returns:torch.tensor: (1)"""query_embedding = self.get_pooled_embedding(input_ids=query_input_ids,token_type_ids=query_token_type_ids)                                                           # (batch, self.output_embedding_dim)doc_embedding = self.get_pooled_embedding(input_ids=doc_input_ids,token_type_ids=doc_token_type_ids)                                                           # (batch, self.output_embedding_dim)cos_sim = torch.matmul(query_embedding, doc_embedding.T)    # (batch, batch)margin_diag = torch.diag(torch.full(                        # (batch, batch), 只有对角线等于margin值的对角矩阵size=[query_embedding.size()[0]], fill_value=self.margin)).to(device)cos_sim = cos_sim - margin_diag                             # 主对角线(正例)的余弦相似度都减掉 margincos_sim *= self.scale                                       # 缩放相似度,便于收敛labels = torch.arange(                                      # 只有对角上为正例,其余全是负例,所以这个batch样本标签为 -> [0, 1, 2, ...]0, query_embedding.size()[0], dtype=torch.int64).to(device)loss = self.criterion(cos_sim, labels)return loss

4.DiffCSE

结合句子间差异的无监督句子嵌入对比学习方法——DiffCSE主要还是在SimCSE上进行优化(可见SimCSE的重要性),通过ELECTRA模型的生成伪造样本和RTD(Replaced Token Detection)任务,来学习原始句子与伪造句子之间的差异,以提高句向量表征模型的效果。

其思想同样来自于CV领域(采用不变对比学习和可变对比学习相结合的方法可以提高图像表征的效果)。作者提出使用基于dropout masks机制的增强作为不敏感转换学习对比学习损失和基于MLM语言模型进行词语替换的方法作为敏感转换学习「原始句子与编辑句子」之间的差异,共同优化句向量表征。

在SimCSE模型中,采用pooler层(一个带有tanh激活函数的全连接层)作为句子向量输出。该论文发现,采用带有BN的两层pooler效果更为突出,BN在SimCSE模型上依然有效。

①对于掩码概率,经实验发现,在掩码概率为30%时,模型效果最优。
②针对两个损失之间的权重值,经实验发现,对比学习损失为RTD损失200倍时,模型效果最优。

参考链接:https://blog.csdn.net/PX2012007/article/details/127696477

5. 数据集准备

项目中提供了一部分示例数据,我们使用未标注的用户搜索记录数据来训练一个文本匹配模型,数据在 data/LCQMC

若想使用自定义数据训练,只需要仿照示例数据构建数据集即可:

  • 训练集:
喜欢打篮球的男生喜欢什么样的女生
我手机丢了,我想换个手机
大家觉得她好看吗
晚上睡觉带着耳机听音乐有什么害处吗?
学日语软件手机上的
...
  • 测试集:
开初婚未育证明怎么弄?	初婚未育情况证明怎么开?	1
谁知道她是网络美女吗?	爱情这杯酒谁喝都会醉是什么歌	0
人和畜生的区别是什么?	人与畜生的区别是什么!	1
男孩喝女孩的尿的故事	怎样才知道是生男孩还是女孩	0
...

由于是无监督训练,因此训练集(train.txt)中不需要记录标签,只需要大量的文本即可。

测试集(dev.tsv)用于测试无监督模型的效果,因此需要包含真实标签。

每一行用 \t 分隔符分开,第一部分部分为句子A,中间部分为句子B,最后一部分为两个句子是否相似(label)

6.模型训练

修改训练脚本 train.sh 里的对应参数, 开启模型训练:

python train.py \--model "nghuyong/ernie-3.0-base-zh" \--train_path "data/LCQMC/train.txt" \--dev_path "data/LCQMC/dev.tsv" \--save_dir "checkpoints/LCQMC" \--img_log_dir "logs/LCQMC" \--img_log_name "ERNIE-ESimCSE" \--learning_rate 1e-5 \--dropout 0.3 \--batch_size 64 \--max_seq_len 64 \--valid_steps 400 \--logging_steps 50 \--num_train_epochs 8 \--device "cuda:0"

正确开启训练后,终端会打印以下信息:

...
0%|          | 0/2 [00:00<?, ?it/s]
100%|██████████| 2/2 [00:00<00:00, 226.41it/s]
DatasetDict({train: Dataset({features: ['text'],num_rows: 477532})dev: Dataset({features: ['text'],num_rows: 8802})
})
global step 50, epoch: 1, loss: 0.34367, speed: 2.01 step/s
global step 100, epoch: 1, loss: 0.19121, speed: 2.02 step/s
global step 150, epoch: 1, loss: 0.13498, speed: 2.00 step/s
global step 200, epoch: 1, loss: 0.10696, speed: 1.99 step/s
global step 250, epoch: 1, loss: 0.08858, speed: 2.02 step/s
global step 300, epoch: 1, loss: 0.07613, speed: 2.02 step/s
global step 350, epoch: 1, loss: 0.06673, speed: 2.01 step/s
global step 400, epoch: 1, loss: 0.05954, speed: 1.99 step/s
Evaluation precision: 0.58459, recall: 0.87210, F1: 0.69997, spearman_corr: 
0.36698
best F1 performence has been updated: 0.00000 --> 0.69997
global step 450, epoch: 1, loss: 0.25825, speed: 2.01 step/s
global step 500, epoch: 1, loss: 0.27889, speed: 1.99 step/s
global step 550, epoch: 1, loss: 0.28029, speed: 1.98 step/s
global step 600, epoch: 1, loss: 0.27571, speed: 1.98 step/s
global step 650, epoch: 1, loss: 0.26931, speed: 2.00 step/s
...

logs/LCQMC 文件下将会保存训练曲线图:

7.模型推理

完成模型训练后,运行 inference.py 以加载训练好的模型并应用:

...if __name__ == '__main__':...sentence_pair = [('男孩喝女孩的故事', '怎样才知道是生男孩还是女孩'),('这种图片是用什么软件制作的?', '这种图片制作是用什么软件呢?')]...res = inference(query_list, doc_list, model, tokenizer, device)print(res)

运行推理程序:

python inference.py

得到以下推理结果:

[0.1527191698551178, 0.9263839721679688]   # 第一对文本相似分数较低,第二对文本相似分数较高

参考链接:https://github.com/HarderThenHarder/transformers_tasks/blob/main/text_matching/supervised

github无法连接的可以在:https://download.csdn.net/download/sinat_39620217/88214437 下载

更多优质内容请关注公号:汀丶人工智能;会提供一些相关的资源和优质文章,免费获取阅读。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/36430.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

一百五十三、Kettle——Linux上安装的kettle9.3启动后说缺少libwebkitgtk-1.0(真是坑爹啊,刚龟速下载又忍痛卸载)

一、问题 在kettle9.3可以在本地连接hive312后&#xff0c;在Linux中安装了kettle9.3&#xff0c;结果启动时报错WARNING: no libwebkitgtk-1.0 detected, some features will be unavailable 而且如果直接下载libwebkitgtk的话也没有用 [roothurys22 data-integration]# yu…

在线吉他调音

先看效果&#xff08;图片没有声&#xff0c;可以下载源码看看&#xff0c;比这更好~&#xff09;&#xff1a; 再看代码&#xff08;查看更多&#xff09;&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8&quo…

【第二阶段】kotlin的函数类型作为返回类型

fun main() {//调用,返回的是一个匿名类型&#xff0c;所以info就是一个匿名函数val infoshow("",0)//info接受的返回值为匿名类型&#xff0c;此时info就是一个匿名函数println(info("kotlin",20)) }//返回类型为一个匿名函数的返回类型fun show(name:Str…

01 - 工作区、暂存区、版本库、远程仓库 - 以一次连贯的提交操作为例

查看所有文章链接&#xff1a;&#xff08;更新中&#xff09;GIT常用场景- 目录 文章目录 1. 工作区、暂存区、版本库、远程仓库1.1 工作区1.2 工作区 > 暂存区&#xff1a;git add1.3 暂存区 > 版本库&#xff1a;git commit1.4 push到远程仓库 1. 工作区、暂存区、版本…

【生成式AI】ProlificDreamer论文阅读

ProlificDreamer 论文阅读 Project指路&#xff1a;https://ml.cs.tsinghua.edu.cn/prolificdreamer/ 论文简介&#xff1a;截止2023/8/10&#xff0c;text-to-3D的baseline SOTA&#xff0c;提出了VSD优化方法 前置芝士:text-to-3D任务简介 text-to-3D Problem text-to-3D…

解决校园网使用vmware桥接模式,虚拟机与物理机互相ping通,但是虚拟机ping不通百度的问题

遇到的问题 使用校园网时&#xff0c;桥接模式下&#xff0c;物理机可以ping通虚拟机&#xff0c;但是虚拟机ping不通主机 解决方法 在物理机中查看网络相关信息 ipconfig 修改虚拟机网卡信息 vim /etc/sysconfig/network-scripts/ifcfg-ens33 注意 /ifcfg-ens33需要根据…

C++ QT(一)

目录 初识QtQt 是什么Qt 能做什么Qt/C与QML 如何选择Qt 版本Windows 下安装QtLinux 下安装Qt安装Qt配置Qt Creator 输入中文配置Ubuntu 中文环境配置中文输入法 Qt Creator 简单使用Qt Creator 界面组成Qt Creator 设置 第一个Qt 程序新建一个项目项目文件介绍项目文件*.pro样式…

SpringBoot启动报错:java: 无法访问org.springframework.boot.SpringApplication

报错原因&#xff1a;jdk 1.8版本与SpringBoot 3.1.2版本不匹配 解决方案&#xff1a;将SpringBoot版本降到2系列版本(例如2.5.4)。如下图&#xff1a; 修改版本后切记刷新Meavn依赖 然后重新启动即可成功。如下图&#xff1a;

3.4 网络安全管理设备

数据参考&#xff1a;CISP官方 目录 IDS (入侵检测系统)网络安全审计漏洞扫描系统VPN&#xff08;虚拟专网&#xff09;堡垒主机安全管理平台 一、IDS (入侵检测系统) 入侵检测系统&#xff08;IDS&#xff09;是一种网络安全设备&#xff0c;用于监测和检测网络中的入侵行…

树莓派3B CSI摄像头配置

1.硬件连接 1、找到 CSI 接口(树莓派3B的CSI接口在HDMI接口和音频口中间)&#xff0c;需要拉起 CSI 接口挡板,如下&#xff1a; 2、将摄像头排线插入CSI接口。记住&#xff0c;有蓝色胶带的一面应该面向音频口或者网卡方向&#xff0c; 确认方向并插紧排线&#xff0c;将挡板…

计算机竞赛 python 机器视觉 车牌识别 - opencv 深度学习 机器学习

1 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; 基于python 机器视觉 的车牌识别系统 &#x1f947;学长这里给一个题目综合评分(每项满分5分) 难度系数&#xff1a;3分工作量&#xff1a;3分创新点&#xff1a;3分 &#x1f9ff; 更多资…

【设计模式】建造者模式

建造者模式&#xff08;Builder Pattern&#xff09;使用多个简单的对象一步一步构建成一个复杂的对象。这种类型的设计模式属于创建型模式&#xff0c;它提供了一种创建对象的最佳方式。 一个 Builder 类会一步一步构造最终的对象。该 Builder 类是独立于其他对象的。 介绍 …

LeetCode 31题:下一个排列

目录 题目 思路 代码 题目 整数数组的一个 排列 就是将其所有成员以序列或线性顺序排列。 例如&#xff0c;arr [1,2,3] &#xff0c;以下这些都可以视作 arr 的排列&#xff1a;[1,2,3]、[1,3,2]、[3,1,2]、[2,3,1] 。 整数数组的 下一个排列 是指其整数的下一个字典序…

Flink 火焰图

方式一 使用 Flink Web UI 的 Flame Graph Flink 自己也支持了 Task 粒度的 Flame Graphs 功能&#xff0c;并且可以细化到 subtask 粒度。 第一步&#xff1a;配置启用功能 Flink 作业动态参数里增加配置&#xff1a;“rest.flamegraph.enabled”: “true” 并重启作业。当前…

Blazor 简单组件(0):简单介绍

文章目录 前言说明环境安装 前言 Blazor 这个技术还是比较新&#xff0c;相关的UI组件还在完善&#xff0c;我这里提供一下我个人的组件开发。 说明 本UI组件是基于BootstrapBlazor(以下简称BB)开发。 BootstrapBlazor 文档 环境安装 C#小轮子&#xff1a;Visual Studio自…

C语言快速回顾(二)

前言 在Android音视频开发中&#xff0c;网上知识点过于零碎&#xff0c;自学起来难度非常大&#xff0c;不过音视频大牛Jhuster提出了《Android 音视频从入门到提高 - 任务列表》&#xff0c;结合我自己的工作学习经历&#xff0c;我准备写一个音视频系列blog。C/C是音视频必…

目前有哪些好用的免费开源wms仓储管理软件?

什么是开源&#xff1f; 开源指的是软件的源代码是公开可见和可自由使用的。开源软件的授权许可通常允许用户查看、修改和分发源代码&#xff0c;以及根据自己的需求进行定制和扩展。 开源工具的核心理念是共享和协作。通过开放源代码&#xff0c;开源软件鼓励用户之间的合作…

Tubi 前端测试:迁移 Enzyme 到 React Testing Library

前端技术发展迅速&#xff0c;即便不说是日新月异&#xff0c;每年也都推出新框架和新技术。Tubi 的产品前端代码仓库始建于 2015 年&#xff0c;至今 8 年有余。可喜的是&#xff0c;多年来紧随 React 社区的发展&#xff0c;Tubi 绝大多数的基础框架选型都遵循了社区流行的最…

CentOS-6.3安装MySQL集群

安装要求 安装环境&#xff1a;CentOS-6.3 安装方式&#xff1a;源码编译安装 软件名称&#xff1a;mysql-cluster-gpl-7.2.6-linux2.6-x86_64.tar.gz 下载地址&#xff1a;http://mysql.mirror.kangaroot.net/Downloads/ 软件安装位置&#xff1a;/usr/local/mysql 数据存放位…

达梦数据库(dm8) Centos7 高可用集群

国产数据库-达梦 一、环境详情二、Centos7 参数优化&#xff08;所有节点&#xff09;三、创建用户&#xff08;所有节点&#xff09;四、开始安装&#xff08;所有节点&#xff09;五、服务注册启动 当前安装&#xff1a;在指定版本环境下 测试&#xff0c;仅供参考 官网描述&…