bert 相似度任务训练,简单版本

目录

任务

代码

train.py

predit.py

数据


任务

使用 bert-base-chinese 训练相似度任务,参考:微调BERT模型实现相似性判断 - 知乎

参考他上面代码,他使用的是 BertForNextSentencePrediction 模型,BertForNextSentencePrediction 原本是设计用于下一个句子预测任务的。在BERT的原始训练中,模型会接收到一对句子,并试图预测第二个句子是否紧跟在第一个句子之后;所以使用这个模型标签(label)只能是 0,1,相当于二分类任务了

但其实在相似度任务中,我们每一条数据都是【text1\ttext2\tlabel】的形式,其中 label 代表相似度,可以给两个文本打分表示相似度,也可以映射为分类任务,0 代表不相似,1 代表相似,他这篇文章利用了这种思想,对新手还挺有用的。

现在我搞了一个招聘数据,里面有办公区域列,处理过了,每一行代表【地址1\t地址2\t相似度】

只要两文本中有一个地址相似我就作为相似,标签为 1,否则 0

利用这数据微调,没有使用验证数据集,就最后使用测试集来看看效果。

代码

train.py

import json
import torch
from transformers import BertTokenizer, BertForNextSentencePrediction
from torch.utils.data import DataLoader, Dataset# 能用gpu就用gpu
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")bacth_size = 32
epoch = 3
auto_save_batch = 5000
learning_rate = 2e-5# 准备数据集
class MyDataset(Dataset):def __init__(self, data_file_paths):self.texts = []self.labels = []# 分词器用默认的self.tokenizer = BertTokenizer.from_pretrained('../bert-base-chinese')# 自己实现对数据集的解析with open(data_file_paths, 'r', encoding='utf-8') as f:for line in f:text1, text2, label = line.split('\t')self.texts.append((text1, text2))self.labels.append(int(label))def __len__(self):return len(self.texts)def __getitem__(self, idx):text1, text2 = self.texts[idx]label = self.labels[idx]encoded_text = self.tokenizer(text1, text2, padding='max_length', truncation=True, max_length=128, return_tensors='pt')return encoded_text, label# 训练数据文件路径
train_dataset = MyDataset('../data/train.txt')# 定义模型
# num_labels=5 定义相似度评分有几个
model = BertForNextSentencePrediction.from_pretrained('../bert-base-chinese', num_labels=6)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)# 训练模型
train_loader = DataLoader(train_dataset, batch_size=bacth_size, shuffle=True)
trained_data = 0
batch_after_last_save = 0
total_batch = 0
total_epoch = 0for epoch in range(epoch):trained_data = 0for batch in train_loader:inputs, labels = batch# 不知道为啥,出来的数据维度是 (batch_size, 1, 128),需要把第二维去掉inputs['input_ids'] = inputs['input_ids'].squeeze(1)inputs['token_type_ids'] = inputs['token_type_ids'].squeeze(1)inputs['attention_mask'] = inputs['attention_mask'].squeeze(1)# 因为要用GPU,将数据传输到gpu上inputs = inputs.to(device)labels = labels.to(device)optimizer.zero_grad()outputs = model(**inputs, labels=labels)loss, logits = outputs[:2]loss.backward()optimizer.step()trained_data += len(labels)trained_process = float(trained_data) / len(train_dataset)batch_after_last_save += 1total_batch += 1# 每训练 auto_save_batch 个 batch,保存一次模型if batch_after_last_save >= auto_save_batch:batch_after_last_save = 0model.save_pretrained(f'../output/cn_equal_model_{total_epoch}_{total_batch}.pth')print("保存模型:cn_equal_model_{}_{}.pth".format(total_epoch, total_batch))print("训练进度:{:.2f}%, loss={:.4f}".format(trained_process * 100, loss.item()))total_epoch += 1model.save_pretrained(f'../output/cn_equal_model_{total_epoch}_{total_batch}.pth')print("保存模型:cn_equal_model_{}_{}.pth".format(total_epoch, total_batch))

训练好后的文件,输出的最后一个文件夹才是效果最好的模型:

predit.py

import torch
from transformers import BertTokenizer, BertForNextSentencePredictiontokenizer = BertTokenizer.from_pretrained('../bert-base-chinese')
model = BertForNextSentencePrediction.from_pretrained('../output/cn_equal_model_3_171.pth')with torch.no_grad():with open('../data/test.txt', 'r', encoding='utf8') as f:lines = f.readlines()correct = 0for i, line in enumerate(lines):text1, text2, label = line.split('\t')encoded_text = tokenizer(text1, text2, padding='max_length', truncation=True, max_length=128, return_tensors='pt')outputs = model(**encoded_text)res = torch.argmax(outputs.logits, dim=1).item()print(text1, text2, label, res)if str(res) == label.strip('\n'):correct += 1print(f'{i + 1}/{len(lines)}')print(f'acc:{correct / len(lines)}')

可以看到还是较好的学习了我数据特征:只要两文本中有一个地址相似我就作为相似,标签为 1,否则 0

数据

链接:https://pan.baidu.com/s/1Cpr-ZD9Neakt73naGdsVTw 
提取码:eryw 
链接:https://pan.baidu.com/s/1qHYjXC7UCeUsXVnYTQIPCg 
提取码:o8py 
 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/712869.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

thinkphp学习10-数据库的修改删除

数据修改 使用 update()方法来修改数据,修改成功返回影响行数,没有修改返回 0 public function index(){$data [username > 孙悟空1,];return Db::name(user)->where(id,11)->update($data);}如果修改数据包含了主键信息,比如 i…

STM32标准库开发——BKP备份RTC时钟

备份寄存器BKP(Backup Registers) 由于RTC与BKP关联性较高,所以RTC的时钟校准寄存器以及一些功能都放在了BKP中。TAMPER引脚主要用于防止芯片数据泄露,可以设计一个机关当TAMPER引脚发生电平跳变时自动清除寄存器内数据不同芯片BKP区别,主要体…

c++入门(2)

上期我们说到了部分c修补C语言的不足,今天我们将剩下的一一说清楚。 函数重载 (1).函数重载的形式 C语言不允许函数名相同的同时存在,但是C允许同名函数存在,但是有要求:函数名相同,参数不同,构成函数重…

【数据结构-图论】并查集

并查集(Union-Find)是一种数据结构,它提供了处理一些不交集的合并及查询问题的高效方法。并查集主要支持两种操作: 查找(Find):确定某个元素属于哪个子集,这通常意味着找到该子集的…

vue购物车实战

1.引入vue <script src"https://cdn.jsdelivr.net/npm/vue2.7.14/dist/vue.js"></script> 2.设置高亮部分的样式 <style> table tr{text-align: center;}.skyblue{background-color: skyblue;}</style> 3.设置body的基本样式 <div id&q…

人大金仓与mysql的差异与替换

人大金仓中不能使用~下面的符号&#xff0c;字段中使用”&#xff0c;无法识别建表语句 创建表时语句中只定义字段名.字段类型.是否是否为空 Varchar类型改为varchar&#xff08;长度 char&#xff09; Int(0) 类型为int4 定义主键&#xff1a;CONSTRAINT 键名 主键类型&#x…

Found option without preceding group in config file 问题解决

方法就是用记事本打开 然后 左上角点击 文件 有另存为 就可以选择编码格式

Linux设置程序任意位置执行(设置环境变量)

问题 直接编译出来的可执行程序在执行时需要写出完整路径比较麻烦&#xff0c;设置环境变量可以实现在任意位置直接运行。 解决 1.打开.bashrc文件 vim ~/.bashrc 2.修改该文件&#xff08;实现将/home/zhangziheng/file/seqrequester/build/bin&#xff0c;路径下的可执…

文件流【文件输入流】

文件流&#xff1a;使用文件输入流读取文件中的数据&#xff1a; public class FISDemo {public static void main(String[] args) throws IOException {//将fos.dat文件中的字节读取回来/*fos.dat文件中的数据:00000001 00000010*/FileInputStream fis new FileInputStream(…

第六节:Vben Admin权限-后端控制方式

系列文章目录 第一节:Vben Admin介绍和初次运行 第二节:Vben Admin 登录逻辑梳理和对接后端准备 第三节:Vben Admin登录对接后端login接口 第四节:Vben Admin登录对接后端getUserInfo接口 第五节:Vben Admin权限-前端控制方式 文章目录 系列文章目录前言一、角色权限(后端…

Java面试题总结6

Spring中有哪些设计模式 简单工厂&#xff1a;由一个工厂类根据传入的参数&#xff0c;动态决定应该创建哪一个产品类 工厂方法&#xff1a;实现FactoryBean接口的bean是一类叫做factory的bean 单例模式&#xff1a;保证一个类仅有一个实例&#xff0c;并提供一个访问它的全…

【办公类-18-03】(Python)中班米罗可儿证书批量生成打印(班级、姓名)

作品展示——米罗可儿证书打印幼儿姓名 背景需求 2024年3月1日&#xff0c;中4班孩子一起整理美术操作材料《米罗可儿》的操作本——将每一页纸撕下来&#xff0c;分类摆放、确保纸张上下位置正确。每位孩子们都非常厉害&#xff0c;不仅完成了自己的一本&#xff0c;还将没有…

C++数据结构与算法——二叉搜索树的属性

C第二阶段——数据结构和算法&#xff0c;之前学过一点点数据结构&#xff0c;当时是基于Python来学习的&#xff0c;现在基于C查漏补缺&#xff0c;尤其是树的部分。这一部分计划一个月&#xff0c;主要利用代码随想录来学习&#xff0c;刷题使用力扣网站&#xff0c;不定时更…

C++编程面试复盘:数组降重+快排+函数指针+类模板

面试真题 真题1 #include <iostream> // 在函数 find_repetnum 的参数列表中&#xff0c;int& length 中的 & 符号表示引用。通过将 length 声明为引用&#xff0c;函数可以修改传入的 length 变量的值&#xff0c;并且这种修改会在函数外部生效。 void find_r…

Vue2:路由history模式的项目部署后页面刷新404问题处理

一、问题描述 我们把Vue项目的路由模式&#xff0c;设置成history 然后&#xff0c;build 并把dist中的代码部署到nodeexpress服务中 访问页面后&#xff0c;刷新页面报404问题 二、原因分析 server.js文件 会发现&#xff0c;文件中配置的路径没有Vue项目中对应的路径 所以…

React withRouter的使用及源码实现

一 基本介绍 作用&#xff1a; 把不是通过路由切换过来的组件中&#xff0c;将react-router 的 history、location、match 三个对象传入props对象上。比如首页&#xff01; 默认情况下必须是经过路由匹配渲染的组件才存在this.props&#xff0c;才拥有路由参数&#xff0c;才能…

嵌入式学习笔记Day27

今天主要学习了进程间的通信&#xff0c;主要学习了通过管道进行通信 一、进程间的通信 进程间的通信方式有以下几种&#xff1a; 1.管道 2.信号 3.消息队列 4.共享内存 5.信号灯 6.套接字二、管道 2.1无名管道 无名管道只能用于具有亲缘关系的进程间通信 函数接口&#x…

Nacos进阶

目录 Nacos支持三种配置加载方案 Namespace方案 DataID方案 Group方案 同时加载多个配置集 Nacos支持三种配置加载方案 Nacos支持“Namespacegroupdata ID”的配置解决方案。 详情见&#xff1a;Nacos config alibaba/spring-cloud-alibaba Wiki GitHub Namespace方案…

《TCP/IP详解 卷一》第12章 TCP初步介绍

目录 12.1 引言 12.1.1 ARQ和重传 12.1.2 滑动窗口 12.1.3 变量窗口&#xff1a;流量控制和拥塞控制 12.1.4 设置重传的超时值 12.2 TCP的引入 12.2.1 TCP服务模型 12.2.2 TCP可靠性 12.3 TCP头部和封装 12.4 总结 12.1 引言 关于TCP详细内容&#xff0c;原书有5个章…

【C++ map和set】

文章目录 map和set序列式容器和关联式容器键值对setset的主要操作 mapmap主要操作 multiset和multimap map和set 序列式容器和关联式容器 之前我们接触的vector,list,deque等&#xff0c;这些容器统称为序列式容器&#xff0c;其底层为线性序列的的数据结构&#xff0c;里面存…