bert 相似度任务训练完整版

任务

之前写了一个相似度任务的版本:bert 相似度任务训练简单版本,faiss 寻找相似 topk-CSDN博客

相似度用的是 0,1,相当于分类任务,现在我们相似度有评分,不再是 0,1 了,分数为 0-5,数字越大代表两个句子越相似,这一次的比较完整,评估,验证集,相似度模型都有了。

数据集

链接:https://pan.baidu.com/s/1B1-PKAKNoT_JwMYJx_zT1g 
提取码:er1z 
原始数据好几千条,我训练数据用了部分 2500 条,验证,测试 300 左右,使用 cpu 也用了好几个小时

train.py

import torch
import os
import time
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel, AdamW, get_cosine_schedule_with_warmup
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np# 设备选择
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = 'cpu'# 定义文本相似度数据集类
class TextSimilarityDataset(Dataset):def __init__(self, file_path, tokenizer, max_len=128):self.data = []with open(file_path, 'r', encoding='utf-8') as f:for line in f.readlines():text1, text2, similarity_score = line.strip().split('\t')inputs1 = tokenizer(text1, padding='max_length', truncation=True, max_length=max_len)inputs2 = tokenizer(text2, padding='max_length', truncation=True, max_length=max_len)self.data.append({'input_ids1': inputs1['input_ids'],'attention_mask1': inputs1['attention_mask'],'input_ids2': inputs2['input_ids'],'attention_mask2': inputs2['attention_mask'],'similarity_score': float(similarity_score),})def __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx]def cosine_similarity_torch(vec1, vec2, eps=1e-8):dot_product = torch.mm(vec1, vec2.t())norm1 = torch.norm(vec1, 2, dim=1, keepdim=True)norm2 = torch.norm(vec2, 2, dim=1, keepdim=True)similarity_scores = dot_product / (norm1 * norm2.t()).clamp(min=eps)return similarity_scores# 定义模型,这里我们不仅计算两段文本的[CLS] token的点积,而是整个句向量的余弦相似度
class BertSimilarityModel(torch.nn.Module):def __init__(self, pretrained_model):super(BertSimilarityModel, self).__init__()self.bert = BertModel.from_pretrained(pretrained_model)self.dropout = torch.nn.Dropout(p=0.1)  # 引入Dropout层以防止过拟合def forward(self, input_ids1, attention_mask1, input_ids2, attention_mask2):embeddings1 = self.dropout(self.bert(input_ids=input_ids1, attention_mask=attention_mask1)['last_hidden_state'])embeddings2 = self.dropout(self.bert(input_ids=input_ids2, attention_mask=attention_mask2)['last_hidden_state'])# 计算两个文本向量的余弦相似度embeddings1 = torch.mean(embeddings1, dim=1)embeddings2 = torch.mean(embeddings2, dim=1)similarity_scores = cosine_similarity_torch(embeddings1, embeddings2)# 映射到[0, 5]评分范围normalized_similarities = (similarity_scores + 1) * 2.5return normalized_similarities.unsqueeze(1)# 自定义损失函数,使用Smooth L1 Loss,更适合处理回归问题
class SmoothL1Loss(torch.nn.Module):def __init__(self):super(SmoothL1Loss, self).__init__()def forward(self, predictions, targets):diff = predictions - targetsabs_diff = torch.abs(diff)quadratic = torch.where(abs_diff < 1, 0.5 * diff ** 2, abs_diff - 0.5)return torch.mean(quadratic)def train_model(model, train_loader, val_loader, epochs=3, model_save_path='../output/bert_similarity_model.pth'):model.to(device)criterion = SmoothL1Loss()  # 使用自定义的Smooth L1 Lossoptimizer = AdamW(model.parameters(), lr=5e-5)  # 调整初始学习率为5e-5num_training_steps = len(train_loader) * epochsscheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=0.1*num_training_steps, num_training_steps=num_training_steps)  # 使用带有warmup的余弦退火学习率调度best_val_loss = float('inf')for epoch in range(epochs):model.train()for batch in train_loader:input_ids1 = batch['input_ids1'].to(device)attention_mask1 = batch['attention_mask1'].to(device)input_ids2 = batch['input_ids2'].to(device)attention_mask2 = batch['attention_mask2'].to(device)similarity_scores = batch['similarity_score'].to(device)optimizer.zero_grad()outputs = model(input_ids1, attention_mask1, input_ids2, attention_mask2)loss = criterion(outputs, similarity_scores.unsqueeze(1))loss.backward()optimizer.step()scheduler.step()# 验证阶段model.eval()with torch.no_grad():val_loss = 0total_val_samples = 0for batch in val_loader:input_ids1 = batch['input_ids1'].to(device)attention_mask1 = batch['attention_mask1'].to(device)input_ids2 = batch['input_ids2'].to(device)attention_mask2 = batch['attention_mask2'].to(device)similarity_scores = batch['similarity_score'].to(device)val_outputs = model(input_ids1, attention_mask1, input_ids2, attention_mask2)val_loss += criterion(val_outputs, similarity_scores.unsqueeze(1)).item()total_val_samples += len(similarity_scores)val_loss /= len(val_loader)print(f'Epoch {epoch + 1}, Validation Loss: {val_loss:.4f}')if val_loss < best_val_loss:best_val_loss = val_losstorch.save(model.state_dict(), model_save_path)def collate_to_tensors(batch):'''把数据处理为模型可用的数据,不同任务可能需要修改一下,'''input_ids1 = torch.tensor([example['input_ids1'] for example in batch])attention_mask1 = torch.tensor([example['attention_mask1'] for example in batch])input_ids2 = torch.tensor([example['input_ids2'] for example in batch])attention_mask2 = torch.tensor([example['attention_mask2'] for example in batch])similarity_score = torch.tensor([example['similarity_score'] for example in batch])return {'input_ids1': input_ids1, 'attention_mask1': attention_mask1, 'input_ids2': input_ids2,'attention_mask2': attention_mask2, 'similarity_score': similarity_score}# 加载数据集和预训练模型
tokenizer = BertTokenizer.from_pretrained('../bert-base-chinese')
model = BertSimilarityModel('../bert-base-chinese')# 加载数据并创建
train_data = TextSimilarityDataset('../data/STS-B/STS-B.train - 副本.data', tokenizer)
val_data = TextSimilarityDataset('../data/STS-B/STS-B.valid - 副本.data', tokenizer)
test_data = TextSimilarityDataset('../data/STS-B/STS-B.test - 副本.data', tokenizer)train_loader = DataLoader(train_data, batch_size=32, shuffle=True, collate_fn=collate_to_tensors)
val_loader = DataLoader(val_data, batch_size=32, collate_fn=collate_to_tensors)
test_loader = DataLoader(test_data, batch_size=32, collate_fn=collate_to_tensors)optimizer = AdamW(model.parameters(), lr=2e-5)# 开始训练
train_model(model, train_loader, val_loader)# 加载最佳模型进行测试
model.load_state_dict(torch.load('../output/bert_similarity_model.pth'))
test_loss = 0
total_test_samples = 0with torch.no_grad():for batch in test_loader:input_ids1 = batch['input_ids1'].to(device)attention_mask1 = batch['attention_mask1'].to(device)input_ids2 = batch['input_ids2'].to(device)attention_mask2 = batch['attention_mask2'].to(device)similarity_scores = batch['similarity_score'].to(device)test_outputs = model(input_ids1, attention_mask1, input_ids2, attention_mask2)test_loss += torch.nn.functional.mse_loss(test_outputs, similarity_scores.unsqueeze(1)).item()total_test_samples += len(similarity_scores)test_loss /= len(test_loader)
print(f'Test Loss: {test_loss:.4f}')

predit.py

这个脚本是用来看看效果的,直接传入两个文本,使用训练好的模型来计算相似度的

import torch
from transformers import BertTokenizer, BertModeldef cosine_similarity_torch(vec1, vec2, eps=1e-8):dot_product = torch.mm(vec1, vec2.t())norm1 = torch.norm(vec1, 2, dim=1, keepdim=True)norm2 = torch.norm(vec2, 2, dim=1, keepdim=True)similarity_scores = dot_product / (norm1 * norm2.t()).clamp(min=eps)return similarity_scores# 定义模型,这里我们不仅计算两段文本的[CLS] token的点积,而是整个句向量的余弦相似度
class BertSimilarityModel(torch.nn.Module):def __init__(self, pretrained_model):super(BertSimilarityModel, self).__init__()self.bert = BertModel.from_pretrained(pretrained_model)self.dropout = torch.nn.Dropout(p=0.1)  # 引入Dropout层以防止过拟合def forward(self, input_ids1, attention_mask1, input_ids2, attention_mask2):'''如果是用来预测,forward 会被禁用'''pass# 加载预训练模型和分词器
tokenizer = BertTokenizer.from_pretrained('../bert-base-chinese')
model = BertSimilarityModel('../bert-base-chinese')
model.load_state_dict(torch.load('../output/bert_similarity_model.pth'))  # 请确保路径正确
model.eval()  # 设置模型为评估模式def calculate_similarity(text1, text2):# 对输入文本进行编码inputs1 = tokenizer(text1, padding='max_length', truncation=True, max_length=128, return_tensors='pt')inputs2 = tokenizer(text2, padding='max_length', truncation=True, max_length=128, return_tensors='pt')# 计算相似度with torch.no_grad():embeddings1 = model.bert(**inputs1.to('cpu'))['last_hidden_state'][:, 0]embeddings2 = model.bert(**inputs2.to('cpu'))['last_hidden_state'][:, 0]similarity_score = cosine_similarity_torch(embeddings1, embeddings2).item()# 映射到[0, 5]评分范围(假设训练时有此步骤)normalized_similarity = (similarity_score + 1) * 2.5return normalized_similarity# 示例
text1 = "瑞典驻利比亚班加西领事馆发生汽车炸弹袭击,无人员伤亡"
text2 = "汽车炸弹击中瑞典驻班加西领事馆,无人受伤。"
similarity = calculate_similarity(text1, text2)
print(f"两个句子的相似度为:{similarity:.2f}")

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/718168.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

EasyRecovery易恢复2024免费文件数据恢复软件下载

一、软件概述 EasyRecovery易恢复中文文件数据恢复软件是一款专为中文用户设计的强大数据恢复工具。该软件致力于帮助用户从各种存储设备中恢复因各种原因丢失的中文文件&#xff0c;如文档、图片、视频、音频等。凭借其核心技术和多年的研发经验&#xff0c;EasyRecovery易恢…

STM32-SPI通信协议

串行外设接口SPI&#xff08;Serial Peripheral Interface&#xff09;是由Motorola公司开发的一种通用数据总线。 在某些芯片上&#xff0c;SPI接口可以配置为支持SPI协议或者支持I2S音频协议。 SPI接口默认工作在SPI方式&#xff0c;可以通过软件把功能从SPI模式切换…

【数据结构与算法】常见排序算法(Sorting Algorithm)

文章目录 相关概念1. 冒泡排序&#xff08;Bubble Sort&#xff09;2. 直接插入排序&#xff08;Insertion Sort&#xff09;3. 希尔排序&#xff08;Shell Sort&#xff09;4. 直接选择排序&#xff08;Selection Sort&#xff09;5. 堆排序&#xff08;Heap Sort&#xff09;…

【脑科学相关合集】有关脑影像数据相关介绍的笔记及有关脑网络的笔记合集

【脑科学相关合集】有关脑影像数据相关介绍的笔记及有关脑网络的笔记合集 前言脑模板方面相关笔记清单 基于脑网络的方法方面数据基本方面 前言 这里&#xff0c;我将展开有关我自己关于脑影像数据相关介绍的笔记及有关脑网络的笔记合集。其中&#xff0c;脑网络的相关论文主要…

TOMCAT的安装与基本信息

一、TOMCAT简介 Tomcat 服务器是一个免费的开放源代码的Web 应用服务器&#xff0c;属于轻量级应用服务器&#xff0c;在中小型系统和并发访问用户不是很多的场合下被普遍使用&#xff0c;是开发和调试JSP 程序的首选。对于一个初学者来说&#xff0c;可以这样认为&#xff0c…

IO 与 NIO

优质博文&#xff1a;IT-BLOG-CN 一、阻塞IO / 非阻塞NIO 阻塞IO&#xff1a;当一条线程执行read()或者write()方法时&#xff0c;这条线程会一直阻塞直到读取到了一些数据或者要写出去的数据已经全部写出&#xff0c;在这期间这条线程不能做任何其他的事情。 非阻塞NIO&…

记录踩过的坑-macOS下使用VS Code

目录 切换主题 安装插件 搭建Python开发环境 装Python插件 配置解释器 打开项目 打开终端 切换主题 安装插件 方法1 方法2 搭建Python开发环境 装Python插件 配置解释器 假设解释器已经通过Anaconda建好&#xff0c;只需要在VS Code中关联。 打开项目 打开终端

ArmV8架构

Armv8/armv9架构入门指南 — Armv8/armv9架构入门指南 v1.0 documentation 上面只是给了一个比较好的参考文档 其他内容待补充

AutoSAR(基础入门篇)13.5-Mcal Mcu时钟的配置

目录 一、EB的Mcu模块结构 二、时钟的配置 对Mcu的配置主要就是其时钟树的配置,但是EB将GTM、ERU等很多重要的模块也都放在了Mcu里面做配置,所以这里的Mcu是一个很庞大的模块, 我们目前只讲时钟树的部分 一、EB的Mcu模块结构 1. 所有的模块都基本上是这么些配置类别,Mc…

阅读笔记 | Transformers in Time Series: A Survey

阅读论文&#xff1a; Wen, Qingsong, et al. “Transformers in time series: A survey.” arXiv preprint arXiv:2202.07125 (2022). 这篇综述主要对基于Transformer的时序建模方法进行介绍。论文首先简单介绍了Transformer的基本原理&#xff0c;包括位置编码、多头注意力机…

回归预测 | Matlab实现RIME-BP霜冰算法优化BP神经网络多变量回归预测

回归预测 | Matlab实现RIME-BP霜冰算法优化BP神经网络多变量回归预测 目录 回归预测 | Matlab实现RIME-BP霜冰算法优化BP神经网络多变量回归预测预测效果基本描述程序设计参考资料 预测效果 基本描述 1.Matlab实现RIME-BP霜冰算法优化BP神经网络多变量回归预测&#xff08;完整…

自动化测试介绍、selenium用法(自动化测试框架+爬虫可用)

文章目录 一、自动化测试1、什么是自动化测试&#xff1f;2、手工测试 vs 自动化测试3、自动化测试常见误区4、自动化测试的优劣5、自动化测试分层6、什么项目适合自动化测试 二、Selenuim1、小例子2、用法3、页面操作获取输入内容模拟点击清空文本元素拖拽frame切换窗口切换/标…

十五 超级数据查看器 讲解稿 外观设置

十五 超级数据查看器 讲解稿 外观设置 视频讲座地址 讲解稿全文: 大家好&#xff0c;今天讲解超级数据查看器,详情界面的外观设置。 首先&#xff0c;我们打开超级数据查看器。 本节课以成语词典为例来做讲述。 我们打开成语词典这个表&#xff0c;随便选一条记录点击&#x…

【虚拟机安装centos7后找不到网卡问题】

最近开始学习linux&#xff0c;看着传智播客的教学视频学习&#xff0c;里面老师用的是centos6.5&#xff0c;我这边装的是centos7最新版的 结果到了网络配置的这一节&#xff0c;卡了我好久。 我在centos一直找不到我的网卡eth0&#xff0c;只有一个回环网口&#xff0c;在/…

第五套CCF信息学奥赛c++练习题 CSP-J认证初级组 中小学信奥赛入门组初赛考前模拟冲刺题(选择题)

第五套中小学信息学奥赛CSP-J考前冲刺题 1、不同类型的存储器组成了多层次结构的存储器体系,按存取速度从快到慢排列的是 A、快存/辅存/主存 B、外存/主存/辅存 C、快存/主存/辅存 D、主存/辅存/外存 答案&#xff1a;C 考点分析&#xff1a;主要考查计算机相关知识&…

静态链表(3)

尾插函数 尾插就比头插多了一步找尾巴&#xff0c;其他均一样 尾插步骤画图 1.找到空闲结点3 2.空链踢空点&#xff0c;穿透删除 先绑后面 再接前面&#xff0c;就完成插入了 综上所述&#xff0c;静态链表就是处理两条链表&#xff0c;静态链表总的执行一次插入或删除&#…

【大厂AI课学习笔记NO.62】模型的部署

我们历尽千辛万苦&#xff0c;总算要部署模型了。这个系列也写到62篇&#xff0c;不要着急&#xff0c;后面还有很多。 这周偷懒了&#xff0c;一天放出太多的文章&#xff0c;大家可能有些吃不消&#xff0c;从下周开始&#xff0c;本系列将正常更新。 这套大厂AI课&#xf…

【剑指offer--C/C++】JZ3 数组中重复的数字

一、题目 二、本人思路及代码 这道题目它要求的时间空间利用率都是n&#xff0c;那么可以考虑创建一个长度为n的数组repeat初始化为0&#xff0c;下标代码出现的数字&#xff0c;下标对应的数组内容代表该下标数字出现的次数。然后遍历提供的数组&#xff0c;每出现一个数字&a…

超详细多表查询详解-多表关系-多表查询-子查询

多表关系 一对多关系&#xff1a;这是最常见的关系类型&#xff0c;它表示在两个表之间&#xff0c;一个表中的记录可以与另一个表中的多个记录相关联。例如&#xff0c;一个班级&#xff08;父表&#xff09;可以有多个学生&#xff08;子表&#xff09;&#xff0c;但每个学…

市场复盘总结 20240301

仅用于记录当天的市场情况&#xff0c;用于统计交易策略的适用情况&#xff0c;以便程序回测 短线核心&#xff1a;不参与任何级别的调整&#xff0c;采用龙空龙模式 一支股票 10%的时候可以操作&#xff0c; 90%的时间适合空仓等待 二进三&#xff1a; 进级率中 40% 最常用的…