bert 相似度任务训练简单版本,faiss 寻找相似 topk

目录

任务

代码

train.py

predit.py

faiss 最相似的 topk


任务

使用 bert-base-chinese 训练相似度任务,参考:微调BERT模型实现相似性判断 - 知乎

参考他上面代码,他使用的是 BertForNextSentencePrediction 模型,BertForNextSentencePrediction 原本是设计用于下一个句子预测任务的。在BERT的原始训练中,模型会接收到一对句子,并试图预测第二个句子是否紧跟在第一个句子之后;所以使用这个模型标签(label)只能是 0,1,相当于二分类任务了

但其实在相似度任务中,我们每一条数据都是【text1\ttext2\tlabel】的形式,其中 label 代表相似度,可以给两个文本打分表示相似度,也可以映射为分类任务,0 代表不相似,1 代表相似,他这篇文章利用了这种思想,对新手还挺有用的。

现在我搞了一个招聘数据,里面有办公区域列,处理过了,每一行代表【地址1\t地址2\t相似度】

只要两文本中有一个地址相似我就作为相似,标签为 1,否则 0

利用这数据微调,没有使用验证数据集,就最后使用测试集来看看效果。

代码

train.py

import json
import torch
from transformers import BertTokenizer, BertForNextSentencePrediction
from torch.utils.data import DataLoader, Dataset# 能用gpu就用gpu
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")bacth_size = 32
epoch = 3
auto_save_batch = 5000
learning_rate = 2e-5# 准备数据集
class MyDataset(Dataset):def __init__(self, data_file_paths):self.texts = []self.labels = []# 分词器用默认的self.tokenizer = BertTokenizer.from_pretrained('../bert-base-chinese')# 自己实现对数据集的解析with open(data_file_paths, 'r', encoding='utf-8') as f:for line in f:text1, text2, label = line.split('\t')self.texts.append((text1, text2))self.labels.append(int(label))def __len__(self):return len(self.texts)def __getitem__(self, idx):text1, text2 = self.texts[idx]label = self.labels[idx]encoded_text = self.tokenizer(text1, text2, padding='max_length', truncation=True, max_length=128, return_tensors='pt')return encoded_text, label# 训练数据文件路径
train_dataset = MyDataset('../data/train.txt')# 定义模型
# num_labels=5 定义相似度评分有几个
model = BertForNextSentencePrediction.from_pretrained('../bert-base-chinese', num_labels=6)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)# 训练模型
train_loader = DataLoader(train_dataset, batch_size=bacth_size, shuffle=True)
trained_data = 0
batch_after_last_save = 0
total_batch = 0
total_epoch = 0for epoch in range(epoch):trained_data = 0for batch in train_loader:inputs, labels = batch# 不知道为啥,出来的数据维度是 (batch_size, 1, 128),需要把第二维去掉inputs['input_ids'] = inputs['input_ids'].squeeze(1)inputs['token_type_ids'] = inputs['token_type_ids'].squeeze(1)inputs['attention_mask'] = inputs['attention_mask'].squeeze(1)# 因为要用GPU,将数据传输到gpu上inputs = inputs.to(device)labels = labels.to(device)optimizer.zero_grad()outputs = model(**inputs, labels=labels)loss, logits = outputs[:2]loss.backward()optimizer.step()trained_data += len(labels)trained_process = float(trained_data) / len(train_dataset)batch_after_last_save += 1total_batch += 1# 每训练 auto_save_batch 个 batch,保存一次模型if batch_after_last_save >= auto_save_batch:batch_after_last_save = 0model.save_pretrained(f'../output/cn_equal_model_{total_epoch}_{total_batch}.pth')print("保存模型:cn_equal_model_{}_{}.pth".format(total_epoch, total_batch))print("训练进度:{:.2f}%, loss={:.4f}".format(trained_process * 100, loss.item()))total_epoch += 1model.save_pretrained(f'../output/cn_equal_model_{total_epoch}_{total_batch}.pth')print("保存模型:cn_equal_model_{}_{}.pth".format(total_epoch, total_batch))

训练好后的文件,输出的最后一个文件夹才是效果最好的模型:

predit.py

import torch
from transformers import BertTokenizer, BertForNextSentencePredictiontokenizer = BertTokenizer.from_pretrained('../bert-base-chinese')
model = BertForNextSentencePrediction.from_pretrained('../output/cn_equal_model_3_171.pth')with torch.no_grad():with open('../data/test.txt', 'r', encoding='utf8') as f:lines = f.readlines()correct = 0for i, line in enumerate(lines):text1, text2, label = line.split('\t')encoded_text = tokenizer(text1, text2, padding='max_length', truncation=True, max_length=128, return_tensors='pt')outputs = model(**encoded_text)res = torch.argmax(outputs.logits, dim=1).item()print(text1, text2, label, res)if str(res) == label.strip('\n'):correct += 1print(f'{i + 1}/{len(lines)}')print(f'acc:{correct / len(lines)}')

可以看到还是较好的学习了我数据特征:只要两文本中有一个地址相似我就作为相似,标签为 1,否则 0

faiss 最相似的 topk

使用 faiss 寻找 topk 相似的,从结果上看最相似的基本都还是找到排到较为靠前的位置

import torch
import faiss
import pandas as pd
import numpy as np
from transformers import BertTokenizer, BertModel# 假设有一个数据集df,其中包含'index'列和'text'列
df = pd.read_csv('../data/DataAnalyst.csv', encoding='gbk')  # 根据实际情况加载数据集
df = df.dropna().drop_duplicates().reset_index()
df['index'] = df.index
df = df[['index', '公司所在商区']]  # 保留所需列
df['公司所在商区'] = df['公司所在商区'].map(lambda row: ','.join(eval(row)))# device = torch.device('gpu' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')# 加载微调好的模型和tokenizer
tokenizer = BertTokenizer.from_pretrained('../bert-base-chinese')
model = BertModel.from_pretrained('../output/cn_equal_model_3_171.pth')
model.eval()# 将数据集转化为模型所需的格式并计算所有样本的向量表示
def encode_texts(df):text_vectors = []for index, row in df.iterrows():text = row['公司所在商区']inputs = tokenizer(text, padding='max_length', truncation=True, max_length=128, return_tensors='pt')with torch.no_grad():embeddings = model(**inputs.to(device))['last_hidden_state'][:, 0]text_vectors.append(embeddings.cpu().numpy())print(f'{index + 1}/{len(df)}')return np.vstack(text_vectors)# 加载数据集并计算所有样本的向量
print('enbedding all data...')
all_embeddings = encode_texts(df)# 初始化Faiss索引
print('init faiss all embedding...')
index = faiss.IndexFlatIP(all_embeddings.shape[1])  # 使用内积空间,适用于余弦相似度
index.add(all_embeddings)
print('init faiss all embedding finish~~~')# 定义查找最相似样本的函数
def find_top_k_similar(query_text, k=100):print('当前 query_text embedding.')query_embedding = encode_single_text(query_text)print('begin to search topk....')D, I = index.search(query_embedding, k)  # 返回距离和索引top_k_indices = df.iloc[I[0]].index.tolist()  # 将索引转换为原始数据集的索引return top_k_indices# 编码单个文本的函数
def encode_single_text(text):inputs = tokenizer(text, padding='max_length', truncation=True, max_length=128, return_tensors='pt')with torch.no_grad():embedding = model(**inputs.to(device))['last_hidden_state'][:, 0].cpu().numpy()print('当前 query_text embedding finish!')return embedding# 示例:找一个query_text的top10相似样本
query_text = "左家庄,国展,西坝河"
top10_indices = find_top_k_similar(query_text)
# 获取与查询文本最相似的前10条原始文本
top10_texts = [df.loc[index, '公司所在商区'] for index in top10_indices]print(f"与'{query_text}'最相似的前100条样本及其文本:")
for i, (idx, text) in enumerate(zip(top10_indices, top10_texts)):print(f"{i+1}. 索引:{idx},文本:{text}")

数据

链接:https://pan.baidu.com/s/1Cpr-ZD9Neakt73naGdsVTw 
提取码:eryw 
链接:https://pan.baidu.com/s/1qHYjXC7UCeUsXVnYTQIPCg 
提取码:o8py 
链接:https://pan.baidu.com/s/1CTntG1Z6AIhiPt6i8Ad97Q 
提取码:x6sz 
 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/720721.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Tomcat概念、安装及相关文件介绍

目录 一、web技术 1、C/S架构与B/S架构 1.1 http协议与C/S架构 1.2 http协议与B/S架构 2、前端三大核心技术 2.1 HTML(Hypertext Markup Language) 2.2 css(Cascading Style Sheets) 2.3 JavaScript 3、同步和异步 4、…

Node.js与Webpack笔记(一)

这里使用的16.19.0版本,官网和github没找到,去黑马2023年课程里找 篇幅较大会卡,此篇幅不写Webpack部分,留着下一篇 初识 1.什么是Node.js? Node.js 是一个独立的 JavaScript 运行环境,能独立执行 JS 代码&#xff…

【Linux】Linux原生异步IO:AIO

1、IO模型 1.1 简述 相信大家在搜索的时候,都会看到下面这张图,IO的使用场景:同步、异步、阻塞、非阻塞,可以组合成四种情况: 同步阻塞I/O: 用户进程进行I/O操作,一直阻塞到I/O操作完成为止。同步非阻塞I/O: 用户程序可以通过设置文件描述符的属性O_NONBLOCK,I/O操作可…

向微队列添加任务的四种方式

向微队列添加任务的四种方式 关于微任务,微队列,事件循环,可参考:深入:微任务与 Javascript 运行时环境 - Web API 接口参考 | MDN (mozilla.org) 先说答案, 四种方法: Promise.resolve().then();Mutation…

【Web前端入门学习】——HTML

目录 HTML简介HTML文件结构常用文本标签标题标签段落标签有序列表和无序列表表格标签 HTML属性a标签—超链接标签图片标签 HTML区块块元素与行内元素 HTML表单 HTML简介 HTML全称是Hypertext Markup Language超文本标记语言。 HTML的作用: 为网页提供结构&#xff…

数据库管理-第158期 Oracle Vector DB AI-09(20240304)

数据库管理158期 2024-03-04 数据库管理-第158期 Oracle Vector DB & AI-09(20240304)1 创建示例表2 添加过滤条件的向量近似查询示例1示例2示例3示例4示例5示例6示例7 总结 数据库管理-第158期 Oracle Vector DB & AI-09(20240304&a…

C#插入排序算法

插入排序实现原理 插入排序算法是一种简单、直观的排序算法,其原理是将一个待排序的元素逐个地插入到已经排好序的部分中。 具体实现步骤如下 首先咱们假设数组长度为n,从第二个元素开始,将当前元素存储在临时变量temp中。 从当前元素的前一…

iOS 17.0 UIGraphicsBeginImageContextWithOptions 崩溃处理

在升级到iOS17后你会发现,之前版本运行的很好,这个版本突然会出现一个运行闪退。报错日志为*** Assertion failure in void _UIGraphicsBeginImageContextWithOptions(CGSize, BOOL, CGFloat, BOOL)(), UIGraphics.m:410 跟踪到具体的报错位置如下所示&a…

第4章 HSA运行时

HSA运行时是一种精简的用户模式应用程序编程接口API,它提供了主机将计算内核启动到可用HSA代理程序所必须的接口。它可以分为两类:核心和扩展。HSA核心运行时API旨在支持HSA系统平台体系结构规范所需的操作,并且必须得到任何符合HSA的系统的支…

Java多线程导入Excel示例

在导入Excel的时候,如果文件比较大,行数很多,一行行读往往速度比较慢,为了加快导入速度,我们可以采用多线程的方式 话不多说直接上代码 首先是Controller import com.sakura.base.service.ExcelService; import com.s…

智慧城市中的数字孪生:数字孪生技术助力智慧城市提高公共服务水平

目录 一、引言 二、数字孪生技术概述 三、数字孪生技术在智慧城市中的应用 1、智慧交通管理 2、智慧能源管理 3、智慧环保管理 4、智慧公共安全 四、数字孪生技术助力智慧城市提高公共服务水平的价值 五、挑战与前景 六、结论 一、引言 随着信息技术的飞速发展&…

【LeetCode】升级打怪之路 Day 13:优先级队列的应用

今日题目: 23. 合并 K 个升序链表 | LeetCode378. 有序矩阵中第 K 小的元素 | LeetCode373. 查找和最小的 K 对数字 | LeetCode703. 数据流中的第 K 大元素 | LeetCode347. 前 K 个高频元素 | LeetCode 目录 Problem 1:合并多个有序链表 【classic】LC 2…

【蓝牙协议栈】【BR/EDR】【AVDTP】音视频分布传输协议

1. AVDTP概念 AVDTP即 AUDIO/VIDEO DISTRIBUTION TRANSPORT PROTOCOL(音视频分配传输协议),主要负责 A/V stream的协商、建立及传输程序,还指定了设备之前传输A/V stream的消息格式. AVDTP的传输机制和消息格式是以 RTP为基础的。RTP由 RTP Data Transfer Protocol (RTP)和…

【软考高项】【计算专题】- 5 - 进度类 - 横道图/甘特图

一、知识点 1、基本定义 甘特图(Gantt chart )又称为横道图、条状图(Bar chart),通过条状图来显示项目各活动的进 度情况。以提出者亨利劳伦斯甘特( Henry Laurence Gantt)先生的名字命名。 目前许多文档工具都可以画甘特图。 (1)我的举例 …

07. Nginx进阶-Nginx负载均衡

简介 负载均衡 什么是负载均衡? 负载均衡,英文名称为Load Balance,其含义就是指将负载(工作任务)进行平衡、分摊到多个操作单元上进行运行。 Nginx负载均衡 什么是Nginx负载均衡? Nginx负载均衡可以大…

计算机网络-典型网络组网架构

前面基本网络知识已经能够满足中小企业的需要了,今天来看下一些基本网络组网架构。 首先网络是分层架构,从接入层到汇聚层再到核心层,然后接入运营商出口。内部包括有线网络、无线网络,出口一般可以使用路由器或者防火墙进行安全防…

StarRocks实战——vivo基于 StarRocks 构建实时大数据平台

目录 前言 一、数据挑战 1.1 时效性挑战,业务分析决策需加速 1.2 访问量挑战,性能与稳定性亟待提高,支撑业务稳定运行 1.3 计算场景挑战,难以满足业务复杂查询需求 1.4. 运维挑战,用户查询体验需优化 二、OLA…

WebDAV之π-Disk派盘+人生Life

人生Life是一款日程软件,在这款待办的日程软件当中各种功能极为的完善,完全的足够用户在日常当中的使用,你的待办方面的各种内容都能够在软件上面进行规划和填充,通过待办事项来帮助用户提高在日常当中的效率,对于用户来说这款待办事项的软件是绝佳的选择。 π-Disk派盘 …

java面试(jvm)

JVM内存模型 细分Eden: java类加载过程?双亲委派机制?一个对象从加载到JVM,再到被GC清除过程? JAVA类加载器:AppClassLoader - ExtClassLoader - BootStrapClassLoader。每种类加载器都有他自己的加载目录…

2024年聚合工艺证模拟考试题库及聚合工艺理论考试试题

题库来源:安全生产模拟考试一点通公众号小程序 2024年聚合工艺证模拟考试题库及聚合工艺理论考试试题是由安全生产模拟考试一点通提供,聚合工艺证模拟考试题库是根据聚合工艺最新版教材,聚合工艺大纲整理而成(含2024年聚合工艺证…