利用sentence bert 实现语义向量搜索

目录

基于pytorch的中文语言模型预训练:https://github.com/zhusleep/pytorch_chinese_lm_pretrain/tree/master

sentence_emb.py

search_faiss_robert768.py

faiss_index.py

gen_vec_save2_faiss.py


基于pytorch的中文语言模型预训练:https://github.com/zhusleep/pytorch_chinese_lm_pretrain/tree/master

sentence_emb.py

#from transformers import BertTokenizer, BertModel
import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")#
## First we initialize our model and tokenizer:
#tokenizer = BertTokenizer.from_pretrained('./result')
#model = BertModel.from_pretrained('./result')def split_batch(init_list, batch_size):groups = zip(*(iter(init_list),) * batch_size)end_list = [list(i) for i in groups]count = len(init_list) % batch_sizeend_list.append(init_list[-count:]) if count != 0 else end_listreturn end_list"""
param: sentence list
return: embeddings
"""
def encode(sentences, tokenizer, model):tokens = {'input_ids': [], 'attention_mask': []}data_num = len(sentences)for sentence in sentences:# 编码每个句子并添加到字典new_tokens = tokenizer.encode_plus(str(sentence), max_length=128,truncation=True, padding='max_length',return_tensors='pt')tokens['input_ids'].append(new_tokens['input_ids'][0])tokens['attention_mask'].append(new_tokens['attention_mask'][0])# 将张量列表重新格式化为一个张量tokens['input_ids'] = torch.stack(tokens['input_ids']).to(device)tokens['attention_mask'] = torch.stack(tokens['attention_mask']).to(device)model.eval()# We process these tokens through our model:with torch.no_grad():#添加这行代码outputs = model(**tokens)# odict_keys(['last_hidden_state', 'pooler_output'])# The dense vector representations of our text are contained within the outputs 'last_hidden_state' tensor, which we access like so:embeddings = outputs[0]# To perform this operation, we first resize our attention_mask tensor:attention_mask = tokens['attention_mask']# attention_mask.shapemask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()# mask.shape# 上面的每个向量表示一个单独token的掩码现在每个token都有一个大小为768的向量,表示它的attention_mask状态。然后将两个张量相乘:masked_embeddings = embeddings * mask# masked_embeddings.shape# torch.Size([2, 128, 768])torch.Size([data_num, 128, 768])summed = torch.sum(masked_embeddings, 1)summed_mask = torch.clamp(mask.sum(1), min=1e-9)mean_pooled = summed / summed_mask# print(mean_pooled)# print(type(mean_pooled))return mean_pooled#sentences = [
#    "你叫什么名字?",
#    "你的名字是什么?",
#    "你的名字是什么?",
#    "你的名字是什么?",
#    "你的名字是什么?",
#    "你的名字是什么?",
#    "你的名字是什么?",
#    "你的名字是什么?",
#    "你的名字是什么?",
#]
#sb = split_batch(sentences, 2)
#embs = []
#for batch in sb:
#	emb = encode(batch)
#	embs += emb
#
#print(embs)
#print(len(embs))

search_faiss_robert768.py

import pickle
from faiss_index import faissIndex
import pandas as pd
import numpy as np
# from sentence_transformers import SentenceTransformer
# Download model
# model = SentenceTransformer('paraphrase-MiniLM-L6-v2/')
from sentence_emb import encodefrom transformers import BertTokenizer, BertModel
import torch
# First we initialize our model and tokenizer:
tokenizer = BertTokenizer.from_pretrained('./result')
model = BertModel.from_pretrained('./result').cuda()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# faiss_index_path = "faiss_index384.pkl"
faiss_index_path = "faiss_index_robert.pkl"symptom_name_df = pd.read_csv("col2.csv")# 从本地加载faiss_index模型
def load_faiss_index(var_faiss_model_path):# 从本地加载faiss_index模型# with open('strategy/semantic_recall/model/tt.txt', 'r') as f:#     print(f.readlines())with open(var_faiss_model_path, mode='rb', errors=None) as fr:index = pickle.load(fr, encoding='ASCII', errors='ASCII')return indexdef symptom_name_recall(symptom_name):# 将参数中当前的文本编码成向量sentence = []sentence.append(symptom_name)# qyery_emb = model.encode(sentence)qyery_emb = encode(sentence,tokenizer,model)# 去faiss中检索相近的faiss索引# 加载faissloaded_faiss_index = load_faiss_index(faiss_index_path)# 寻找最近k个物料# R, D, I = loaded_faiss_index.search_items(qyery_emb.reshape([-1, 384]), k=10, n_probe=5)R, D, I = loaded_faiss_index.search_items(np.array(qyery_emb.reshape([-1, 768]).cpu()), k=10, n_probe=5)# 从faiss库中检索的物料ID进行转换result = []for id_list in R:for item in id_list:result.append(item)symptom_name_list = symptom_name_df[symptom_name_df['index'].isin(result)]['symptom_name'].to_list()# 从相似度检索的结果中,去除自己if symptom_name in symptom_name_list:symptom_name_list.remove(symptom_name)print(symptom_name + ' 的相近的词:' + str(symptom_name_list))word_lsit = ['头痛','恶心吧吐','期饮酒','出血','失眠']
for word in word_lsit:symptom_name_recall(word)

faiss_index.py

import faiss
import numpy as npclass faissIndex:def __init__(self, dim, n_centroids, metric):self.dim = dimself.n_centriods = n_centroidsassert metric in ('INNER_PRODUCT', 'L2'), "Input metric not in 'INNER_PRODUCT' or 'L2'"self.metric = faiss.METRIC_INNER_PRODUCT if metric == 'INNER_PRODUCT' else faiss.METRIC_L2self._build_index()returndef _build_index(self):self._quantizer = faiss.IndexFlatL2(self.dim)self.index = faiss.IndexIVFFlat(self._quantizer, self.dim, self.n_centriods, self.metric)self.is_trained = self.index.is_trainedself.n_samples = 0  # 查询向量池中的向量个数self.items = np.array([])  # 向量池中向量对应的item,数量应与self.n_samples保持一致,即向量与item一一对应return Truedef reset_index(self, dim, n_centroids, metric):self.dim = dimself.n_centriods = n_centroidsassert metric in ('INNER_PRODUCT', 'L2'), "Input metric not in 'INNER_PRODUCT' or 'L2'"self.metric = faiss.METRIC_INNER_PRODUCT if metric == 'INNER_PRODUCT' else faiss.METRIC_L2self._build_index()returndef train(self, vectors_train):self.index.train(vectors_train)self.is_trained = self.index.is_trainedreturndef add(self, vectors, items=None):if not items.empty:  # 当有输入items时,验证之前的item和vector数量是否匹配,以及当前输入assert len(vectors) == len(items), "Length of vectors ({n_vectors}) and items ({n_items}) don't match, please check your input.".format(n_vectors=len(vectors), n_items=len(items))assert self.n_samples == len(self.items), "Amounts of added vectors and items don't match, cannot add more items."self.items = np.append(self.items, items.to_numpy())else:assert len(self.items) == 0, "There were items added previously, please added corresponding items in this batch."self.index.add(vectors)self.n_samples += len(vectors)returndef search(self, query_vector, k, n_probe=1):assert query_vector.shape[1] == self.dim, "The dimension of query vector ({dim_vector}) doesn't match the training vector set ({dim_index})!".format(dim_vector=query_vector.shape[1], dim_index=self.dim)assert self.is_trained, "Faiss index is not trained, please train index first!"assert self.n_samples > 0, "Faiss index doesn't have any vector for query, please add vectors into index first!"self.index.nprobe = n_probeD, I = self.index.search(query_vector, k)return D, I# k = 30 # 对每条向量(每行)寻找最近k个物料# n_probe = 5 # 每次查询只查询最近邻n_probe个聚类def search_items(self, query_vector, k, n_probe=1):D, I = self.search(query_vector, k, n_probe)R = [self.items[i] for i in I]return R, D, I

gen_vec_save2_faiss.py

"""
# 训练语义向量并保存在faiss中
step1: 将句子生成向量
step2: 将向量保存在faiss中
"""
import pandas as pd
import numpy as np
# from sentence_transformers import SentenceTransformer
# Download model
# model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
from sentence_emb import encode
import pickle
from faiss_index import faissIndex
from tqdm import tqdmfaiss_index_path = "faiss_index_robert.pkl"from transformers import BertTokenizer, BertModel
import torchdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# First we initialize our model and tokenizer:
tokenizer = BertTokenizer.from_pretrained('./result')
model = BertModel.from_pretrained('./result').cuda()# ====================== 创建faiss index并进行训练 ======================
# 创建faiss index并进行训练
def build_faiss_index(df_resources, semantic_vector, n_centroids=5, metric='L2'):print("现在开始进行faiss index模型训练")# 构建faiss索引模型dim = semantic_vector.shape[1]print("训练数据维度:", dim)print("聚类中心个数:", n_centroids)print("向量距离指标:", metric)# 训练faiss索引index = faissIndex(dim, n_centroids, metric)# vectors = np.stack(df_resources['index'].values).astype('float32') # faiss只支持32位浮点数查询vectors = semantic_vectoritems = df_resources['index']index.train(vectors)index.add(vectors, items)print("faiss index模型已训练完成")return index# ====================== 保存faiss ======================
# 将index按照指定的日期命名并保存至本地
def save_index(index, path):print("现在开始将faiss index保存至本地")fw = open(path, mode='wb', errors=None)pickle.dump(index, fw)fw.close()print("faiss_index模型已保存至本地")def split_batch(init_list, batch_size):groups = zip(*(iter(init_list),) * batch_size)end_list = [list(i) for i in groups]count = len(init_list) % batch_sizeend_list.append(init_list[-count:]) if count != 0 else end_listreturn end_list"""
# 利用sentence transfermer 生成文本向量
# 训练faiss
# 保存faiss
param: 
"""def sentence2faiss_transfermer():df = pd.read_csv('col2.csv')train_json = df.to_dict('records')# 取文本将文本转化为向量title_list = [item['symptom_name'] for item in train_json]print(len(title_list))print("正在训练中.......")# title_list = title_list[:500]sb = split_batch(title_list, 8)embeddings = []# print(len(title_list))# emb = encode(title_list, tokenizer, model)# print(emb)# exit()for batch in tqdm(sb):try:emb = encode(batch, tokenizer, model)emb = np.array(emb.to("cpu"))for item in emb:embeddings.append(item)except Exception as e:print(e)# print(len(embeddings))# embeddings = np.array(embeddings)
#    print(embeddings)
#    print(len(embeddings))# exit()# embeddings = encode(title_list)# 创建faiss index并进行训练df_resources = pd.DataFrame(train_json)# print(embeddings.shape)print("==================================================")# emb = emb.cpu()# semantic_2d_array = np.array(embeddings.to("cpu"))# 将numpy数组转换成CUDA张量# semantic_2d_array= torch.tensor([item.cpu().detach().numpy() for item in semantic_2d_array]).cuda()print("开始build_faiss_index")# print(len(np.array(emb)))trained_index = build_faiss_index(df_resources, np.array(embeddings), n_centroids=5, metric='L2')print("开始save_index")# 保存faiss模型save_index(trained_index, faiss_index_path)sentence2faiss_transfermer()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/14513.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[协议]stm32读取AHT20程序示例

AHT20温度传感器使用程序&#xff1a; 使用i2c读取温度传感器数据很简单&#xff0c;但市面上有至少两个手册&#xff0c;我这个对应的手册贴出来&#xff1a; main: #include "stm32f10x.h" // Device header #include <stdint.h> #includ…

数智赋能内涝治理,四信城市排水防涝解决方案保障城市安全运行

由强降雨、台风造成城市低洼处出现大量积水、内涝的情况时有发生&#xff0c;给人们出行带来了极大不便和安全隐患&#xff0c;甚至危及群众生命财产安全。 为降低内涝造成的损失&#xff0c;一方面我们要大力加强城市排水基础设施的建设&#xff1b;另一方面要全面掌握城市内涝…

U-Boot menu菜单分析

文章目录 前言目标环境背景U-Boot如何自动调起菜单U-Boot添加自定义命令实践 前言 在某个厂家的开发板中&#xff0c;在进入它的U-Boot后&#xff0c;会自动弹出一个菜单页面&#xff0c;输入对应的选项就会执行对应的功能。如SD卡镜像更新、显示设置等&#xff1a; 目标 本…

docker命令详解大全

Docker是一种流行的容器化平台&#xff0c;用于快速部署应用程序并管理容器的生命周期。以下是一些常用的Docker命令及其用途的概述&#xff1a; docker run&#xff1a;创建一个新容器并运行一个命令。docker ps&#xff1a;列出当前运行的容器。docker stop&#xff1a;停止…

Unity射击游戏开发教程:(20)增加护盾强度

在本文中,我们将增强护盾,使其在受到超过 1 次攻击后才会被禁用。 Player 脚本具有 Shield PowerUp 方法,我们需要调整盾牌在被摧毁之前可以承受的数量,因此我们将声明一个 int 变量来设置盾牌可以承受的击中数量。

微信小程序画布显示图片绘制矩形选区

wxml <view class"page-body"><!-- 画布 --><view class"page-body-wrapper"><canvas canvas-id"myCanvas" type"2d" id"myCanvas" classmyCanvas bindtouchstart"touchStart" bindtouchmo…

OpenFeign快速入门 替代RestTemplate

1.引入依赖 <!--openFeign--><dependency><groupId>org.springframework.cloud</groupId><artifactId>spring-cloud-starter-openfeign</artifactId></dependency><!--负载均衡器--><dependency><groupId>org.spr…

【全网最全】2024电工杯数学建模B题问题一14页论文+19建模过程代码+py代码+2种保奖思路+数据等(后续会更新成品论文等)

您的点赞收藏是我继续更新的最大动力&#xff01; 一定要点击如下的卡片链接&#xff0c;那是获取资料的入口&#xff01; 【全网最全】2024电工杯数学建模B题问一论文19建模过程代码py代码2种保奖思路数据等&#xff08;后续会更新成品论文等&#xff09;「首先来看看目前已…

C++中的四种类型转换运算符

隐式类型转换是安全的&#xff0c;显式类型转换是有风险的&#xff0c;C语言之所以增加强制类型转换的语法&#xff0c;就是为了强调风险&#xff0c;让程序员意识到自己在做什么。但是&#xff0c;这种强调风险的方式还是比较粗放&#xff0c;粒度比较大&#xff0c;它并没有表…

MySQL中如何知道数据库表中所有表的字段的排序规则是什么?

查看所有表的字段及其排序规则&#xff1a; 你可以查询 information_schema 数据库中的 COLUMNS 表&#xff0c;来获取所有表的字段及其排序规则。以下是一个示例查询&#xff1a; SELECT TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, COLLATION_NAME FROM information_schema.COL…

【设计模式深度剖析】【5】【创建型】【原型模式】| 类比群发邮件,加深理解

&#x1f448;️上一篇:建造者模式 | 下一篇:创建型设计模式对比&#x1f449;️ 目录 原型模式(Prototype Pattern)概览定义英文原话直译 3个角色类图1. 抽象原型&#xff08;Prototype&#xff09;角色2. 具体原型&#xff08;Concrete Prototype&#xff09;角色3. 客户…

必示科技参与智能运维国家标准预研线下编写会议并做主题分享

近日&#xff0c;《信息技术服务 智能运维 第3部分&#xff1a;算法治理》&#xff08;拟定名&#xff09;国家标准预研阶段第一次编写工作会议在杭州举行。本次会议由浙商证券承办。 此次编写有来自银行、证券、保险、通信、高校研究机构、互联网以及技术方等29家单位&#xf…

在云计算环境中,如何实现资源的高效分配和调度?

在云计算环境中&#xff0c;可以通过以下几种方法实现资源的高效分配和调度&#xff1a; 负载均衡&#xff1a;通过负载均衡算法&#xff0c;将云计算集群的负载均匀地分配到各个节点上。常见的负载均衡算法有轮询、最小连接数、最短响应时间等。 资源调度算法&#xff1a;为了…

Linux基础(四):Linux系统文件类型与文件权限

各位看官&#xff0c;好久不见&#xff0c;在正式介绍Linux的基本命令之前&#xff0c;我们首先了解一下&#xff0c;关于文件的知识。 目录 一、文件类型 二、文件权限 2.1 文件访问者的分类 2.2 文件权限 2.2.1 文件的基本权限 2.2.2 文件权限值的表示方法 三、修改文…

CSS3 新增背景属性 + 新增边框属性(如果想知道CSS3新增背景属性和新增边框属性的知识点,那么只看这一篇就够了!)

前言&#xff1a;CSS3在CSS2的基础上&#xff0c;新增了很多强大的新功能&#xff0c;从而解决一些实际面临的问题&#xff0c;本篇文章主要讲解的为CSS3新增背景属性和新增边框属性。 ✨✨✨这里是秋刀鱼不做梦的BLOG ✨✨✨想要了解更多内容可以访问我的主页秋刀鱼不做梦-CSD…

视觉SLAM十四讲:从理论到实践(Chapter5:相机与图像)

前言 学习笔记&#xff0c;仅供学习&#xff0c;不做商用&#xff0c;如有侵权&#xff0c;联系我删除即可 目标 理解针孔相机的模型、内参与径向畸变参数。理解一个空间点是如何投影到相机成像平面的。掌握OpenCV的图像存储与表达方式。学会基本的摄像头标定方法。 一、相…

机器学习第四十周周报 WDN GGNN

文章目录 week40 WDN GGNN摘要Abstract一、文献阅读1. 题目2. abstract3. 网络架构3.1 问题提出3.2 GNN3.3 CSI GGNN 4. 文献解读4.1 Introduction4.2 创新点4.3 实验过程4.3.1 数据获取4.3.2 参数设置4.3.3 实验结果 5. 结论二、GGNN1. 代码解释2. 网络结构小结参考文献参考文…

Vue 2 和 Vue 3 中同步和异步

Vue 2 和 Vue 3 中同步和异步 Vue 2 同步和异步 同步更新 (Synchronous Updates) Vue 2 在数据更新后会进行同步渲染更新,但为了性能优化,Vue 会在内部队列中异步地进行 DOM 更新。这意味着数据变化会立即被捕捉到,但实际的 DOM 更新会被推迟到下一个事件循环队列中。new V…

基础3 探索JAVA图形编程桌面:逻辑图形组件实现

在一个宽敞明亮的培训教室里&#xff0c;阳光透过窗户柔和地洒在地上&#xff0c;教室里摆放着整齐的桌椅。卧龙站在讲台上&#xff0c;面带微笑&#xff0c;手里拿着激光笔&#xff0c;他的眼神中充满了热情和期待。他的声音清晰而洪亮&#xff0c;传遍了整个教室&#xff1a;…

Linux模拟考试

注意&#xff0c;以下答案仅供参考 1、某CentOS系统空间不够&#xff0c;现加一块100G的硬盘(是系统的第二块硬盘&#xff09;&#xff0c;分为一个区99G&#xff0c;挂载点是/data&#xff0c;请写出从分区到挂载并使用的整个步骤及相关命令。 1.创建分区&#xff1a; sudo f…