大模型的 Embedding 模型该如何进行微调?

节前,我们星球组织了一场算法岗技术&面试讨论会,邀请了一些互联网大厂朋友、参加社招和校招面试的同学.

针对算法岗技术趋势、大模型落地项目经验分享、新手如何入门算法岗、该如何准备、面试常考点分享等热门话题进行了深入的讨论。

汇总合集:《大模型面试宝典》(2024版) 发布!


本文将会介绍如何使用 Sentence Transformers 对开源的Embedding模型bge-base-zh-v1.5进行微调,并验证 Embedding 模型微调后的效果。

在RAG框架或者语义相似度计算任务时,Embedding模型是我们常常会打交道的模型。

Sentence Transformers 是一个 Python 库,用于使用和训练各种应用的Embedding模型,例如检索增强生成 (RAG)、语义搜索、语义文本相似度、释义挖掘 (paraphrase mining) 等等。其 3.0 版本的更新是该工程自创建以来最大的一次,引入了一种新的训练方法。

本文将会以智源研究院(BAAI)开源的Embedding模型bge-base-zh-v1.5作为基准模型,展示如何使用Sentence Transformers进行评估,并对其进行微调,验证微调后的模型效果会有所提升。

评估指标Baseline

使用LlamaIndex框架对RAG流程中的各种Retrieve算法,包括Embedding模型召回,进行了评估,评估指标采用Hit RateMRR。本文将继续使用这篇文章中给出的数据集进行评估。

示例评估代码如下:

# -*- coding: utf-8 -*-
# @file: bge_base_zh_eval.py
import os
import json
import time
import torch
from pprint import pprint
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers.util import cos_simproject_dir = os.path.dirname(os.path.abspath(__file__)).split('/src')[0]# data process
# load dataset, get corpus, queries, relevant_docs
with open(os.path.join(project_dir, "data/doc_qa.json"), "r", encoding="utf-8") as f:content = json.loads(f.read())corpus = content['corpus']
queries = content['queries']
relevant_docs = content['relevant_docs']# # Load a model
# 替换成自己的模型完整路径或使用huggingface modl id
model_name = "bge-base-zh-v1.5"
model_path = os.path.join(project_dir, f"models/{model_name}")
model = SentenceTransformer(model_path, device="cuda" if torch.cuda.is_available() else "cpu")
print("Model loaded")s_time = time.time()# # Evaluate the model
evaluator = InformationRetrievalEvaluator(queries=queries,corpus=corpus,relevant_docs=relevant_docs,name=f"{os.path.basename(model_path)}",score_functions={"cosine": cos_sim}
)# Evaluate the model
result = evaluator(model)
pprint(result)
print(f"Time cost: {time.time() - s_time:.2f}s")

我们在评估器中传入queries, corpus, relevant_docs字典,加载完模型后即可进行评估。

评估结果在下文中给出,作为baseline(基准)指标。

技术交流群

前沿技术资讯、算法交流、求职内推、算法竞赛、面试交流(校招、社招、实习)等、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企开发者互动交流~

我们建了大模型算法岗技术与面试交流群, 想要交流、需要源码&资料、提升技术的同学,可以直接加微信号:mlc2060。加的时候备注一下:研究方向 +学校/公司+CSDN,即可。然后就可以拉你进群了。

方式①、微信搜索公众号:机器学习社区,后台回复:加群
方式②、添加微信号:mlc2060,备注:CSDN + 技术交流

微调数据合成

LlamaIndex框架中,可方便地使用generate_qa_embedding_pairs方法,利用Prompt工程对文本生成相关问题并进行关联。

Embedding模型的微调数据合成脚本如下:

# -*- coding: utf-8 -*-
# @file: make_ft_corpus.py
import os
from llama_index.legacy.finetuning import (generate_qa_embedding_pairs
)
from llama_index.llms.openai import OpenAI
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
from dotenv import load_dotenvload_dotenv()project_dir = os.path.dirname(os.path.abspath(__file__)).split('/src')[0]TRAIN_FILES = [os.path.join(project_dir, "data/ft_train.txt")]
VAL_FILES = [os.path.join(project_dir, "data/ft_test.txt")]TRAIN_CORPUS_FPATH = os.path.join(project_dir, "data/ft_train_corpus.json")
VAL_CORPUS_FPATH = os.path.join(project_dir, "data/ft_val_corpus.json")def load_corpus(files, verbose=False):if verbose:print(f"Loading files {files}")reader = SimpleDirectoryReader(input_files=files)docs = reader.load_data()if verbose:print(f"Loaded {len(docs)} docs")parser = SentenceSplitter(chunk_size=250, chunk_overlap=0)nodes = parser.get_nodes_from_documents(docs, show_progress=verbose)if verbose:print(f"Parsed {len(nodes)} nodes")return nodestrain_nodes = load_corpus(TRAIN_FILES, verbose=True)
val_nodes = load_corpus(VAL_FILES, verbose=True)llm = OpenAI(model="gpt-3.5-turbo", api_key=os.getenv("OPENAI_API_KEY"))qa_generate_prompt_tmpl = """\
Context information is below.---------------------
{context_str}
---------------------Given the context information and not prior knowledge.
generate only questions based on the below query.You are a Professor. Your task is to setup \
{num_questions_per_chunk} questions for an upcoming \
quiz/examination in Chinese. The questions should be diverse in nature \
across the document in Chinese. The questions should not contain options, not start with Q1/ Q2. \
Restrict the questions to the context information provided.
"""train_dataset = generate_qa_embedding_pairs(nodes=train_nodes, llm=llm, num_questions_per_chunk=1, qa_generate_prompt_tmpl=qa_generate_prompt_tmpl)
val_dataset = generate_qa_embedding_pairs(nodes=val_nodes, llm=llm, num_questions_per_chunk=1, qa_generate_prompt_tmpl=qa_generate_prompt_tmpl)train_dataset.save_json(TRAIN_CORPUS_FPATH)
val_dataset.save_json(VAL_CORPUS_FPATH)

输出结果如下:

Output:Loading files ['/Users/admin/PycharmProjects/embedding_model_exp/data/ft_train.txt']
Loaded 1 docs
Parsing nodes: 100%|██████████| 1/1 [00:00<00:00, 23.54it/s]
Parsing nodes:   0%|          | 0/1 [00:00<?, ?it/s]Parsed 137 nodes
Loading files ['/Users/admin/PycharmProjects/embedding_model_exp/data/ft_test.txt']
Loaded 1 docs
Parsing nodes: 100%|██████████| 1/1 [00:00<00:00, 45.84it/s]0%|          | 0/137 [00:00<?, ?it/s]Parsed 111 nodes
100%|██████████| 137/137 [03:34<00:00,  1.57s/it]
100%|██████████| 111/111 [01:55<00:00,  1.04s/it]

这样,我们就能得到微调数据集了,保存为ft_train_corpus.json和ft_val_corpus.json。

Embedding模型微调

接下来,我们将会对bge-base-zh-v1.5模型进行微调,微调的目的是让模型更适配我们自己的数据集,从而取得更好的召回效果。

使用 `sentence-transformers v3`

这里,我们使用的sentence-transformers模块的版本为V3.0.0。

利用该模块,我们不难实现Embedding模型微调,微调代码如下:

# -*- coding: utf-8 -*-
# @file: ft_sentence_transformers_trainer.py
import os
import json
import time
import torch
from datasets import Dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers.util import cos_sim
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers import SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers import SentenceTransformerTrainerstart_time = time.time()
project_dir = os.path.dirname(os.path.abspath(__file__)).split('/src')[0]# load eval dataset
with open(os.path.join(project_dir, "data/ft_val_dataset.json"), "r", encoding="utf-8") as f:eval_content = json.loads(f.read())corpus, queries, relevant_docs = eval_content['corpus'], eval_content['queries'], eval_content['relevant_docs']
# load train dataset
with open(os.path.join(project_dir, "data/ft_train_dataset.json"), "r", encoding="utf-8") as f:train_content = json.loads(f.read())train_anchor, train_positive = [], []
for query_id, context_id in train_content['relevant_docs'].items():train_anchor.append(train_content['queries'][query_id])train_positive.append(train_content['corpus'][context_id[0]])train_dataset = Dataset.from_dict({"positive": train_positive, "anchor": train_anchor})print(train_dataset)
print(train_dataset[0:5])# Load a model
model_name = 'bge-base-zh-v1.5'
# 替换成自己的模型完整路径或使用huggingface modl id
model_path = os.path.join(project_dir, f"models/{model_name}")
model = SentenceTransformer(model_path, device="cuda:0" if torch.cuda.is_available() else "cpu")
print("Model loaded")# # Evaluate the model
evaluator = InformationRetrievalEvaluator(queries=queries,corpus=corpus,relevant_docs=relevant_docs,name=f"{model_name}",score_functions={"cosine": cos_sim}
)
train_loss = MultipleNegativesRankingLoss(model)# define training arguments
args = SentenceTransformerTrainingArguments(output_dir=f"ft_{model_name}",  # output directory and hugging face model IDnum_train_epochs=5,  # number of epochsper_device_train_batch_size=2,  # train batch sizegradient_accumulation_steps=2,  # for a global batch size of 512per_device_eval_batch_size=4,  # evaluation batch sizewarmup_ratio=0.1,  # warmup ratiolearning_rate=2e-5,  # learning rate, 2e-5 is a good valuelr_scheduler_type="cosine",  # use constant learning rate scheduleroptim="adamw_torch_fused",  # use fused adamw optimizertf32=True,  # use tf32 precisionbf16=True,  # use bf16 precisionbatch_sampler=BatchSamplers.NO_DUPLICATES,eval_strategy="epoch",  # evaluate after each epochsave_strategy="epoch",  # save after each epochlogging_steps=10,  # log every 10 stepssave_total_limit=3,  # save only the last 3 modelsload_best_model_at_end=True,  # load the best model when training endsmetric_for_best_model=f"eval_{model_name}_cosine_ndcg@10",  # Optimizing for the best ndcg@10 score
)# train the model
trainer = SentenceTransformerTrainer(model=model,    # the model to trainargs=args,      # training argumentstrain_dataset=train_dataset.select_columns(["positive", "anchor"]),  # training datasetloss=train_loss,evaluator=evaluator
)trainer.train()
trainer.save_model()
print(f"cost time: {time.time() - start_time:.2f}s")

笔者在1张NVIDIA A800-SXM4-80GB型号的GPU上进行训练,耗时约63.10秒。同时,我们会将微调后的Embedding模型保存在GPU上。

总结

本文重点介绍了如何使用 Sentence Transformers 对开源的Embedding模型bge-base-zh-v1.5进行微调,并验证Embedding模型微调后的效果。

Sentence Transformers 是一个宝库,它介绍了关于Embedding模型方方面面的内容,是了解、深入Embedding模型必不可少的工具。后续笔者将会介绍Embedding模型量化、俄罗斯套娃嵌入模型(Matryoshka Representation Learning, MRL)等相关方面的内容。

参考文献

  1. Training and Finetuning Embedding Models with Sentence Transformers v3: https://huggingface.co/blog/train-sentence-transformers

  2. Fine-tune Embedding models for Retrieval Augmented Generation (RAG): https://www.philschmid.de/fine-tune-embedding-model-for-rag

  3. 俄罗斯套娃 (Matryoshka) 嵌入模型概述: https://huggingface.co/blog/zh/matryoshka

  4. Finetune Embeddings: https://docs.llamaindex.ai/en/stable/examples/finetuning/embeddings/finetune_embedding/

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/850011.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vue的基础知识:v-model的原理,由:value与@input合写。

原理&#xff1a;v-model本质上是一个语法糖&#xff0c;比如应用在输入框上&#xff0c;就是value属性和input事件的合写。&#xff08;补充说明&#xff1a;语法糖就是语法的简写&#xff09; 作用&#xff1a;提供数据的双向绑定 1.数据变&#xff0c;视图&#xff08;也就…

[数据集][目标检测]叶子计数检测数据集VOC+YOLO格式240张1类别

数据集格式&#xff1a;Pascal VOC格式YOLO格式(不包含分割路径的txt文件&#xff0c;仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数)&#xff1a;240 标注数量(xml文件个数)&#xff1a;240 标注数量(txt文件个数)&#xff1a;240 标注类别…

2024年谷歌SEO如何快速出排名效果抢占首页制高点?

2024年&#xff0c;随着谷歌搜索引擎算法的不断更新&#xff0c;SEO策略也需要与时俱进才能快速出排名。本文将结合谷歌最新SEO趋势&#xff0c;平哥SEO分享一些实操性的快速排名技巧&#xff0c;帮助你在竞争激烈的搜索结果中脱颖而出。 额外话题&#xff1a;就是通过微信公众…

break、continue、return

break 程序示例&#xff1a; // 产生一个位于 [1, 100] 范围内的随机数&#xff0c;统计产生 100 所需要的次数 public static void main(String[] args) {// System.out.println(Math.random()); // [0,1)// System.out.println(Math.random() * 100); // [0,100)// Syste…

什么是XSS攻击?什么是SQL注入攻击?什么是CSRF攻击?

XSS攻击、SQL注入攻击和CSRF攻击是三种常见的网络安全威胁&#xff0c;它们分别针对不同的应用层面和安全漏洞。以下是对这三种攻击方式的详细介绍&#xff1a; 1. XSS攻击&#xff08;跨站脚本攻击&#xff0c;Cross-Site Scripting&#xff09; 业务场景&#xff1a; 用户…

Java Web学习笔记25——Vue组件库Element

什么是Element&#xff1f; Element: 是饿了么团队研发的&#xff0c;一套为开发者、设计师和产品经理准备的基于Vue2.0的桌面端组件库。 组件&#xff1a;组成网页的部件&#xff0c;例如&#xff1a;超链接、按钮、图片、表格、表单、分页条等等。 官网&#xff1a;https:…

C++设计模式---工厂模式

C中工厂模式是一种创建型设计模式&#xff0c;它允许客户端代码通过调用工厂方法来创建对象&#xff0c;而无需直接使用new运算符实例化具体类。这种模式有助于将类的创建与使用相分离&#xff0c;并且在需要添加新的具体类时可以减少对客户端代码的影响。 工厂模式通常有两种实…

深拷贝、浅拷贝、引用拷贝

深拷贝和浅拷贝的区别 1. 引用拷贝2. 对象拷贝 1. 引用拷贝 两个对象指向同一个地址值。 创建一个指向对象的引用变量的拷贝Teacher teacher new Teacher("Taylor",26); Teacher otherteacher teacher; System.out.println(teacher); System.out.println(otherte…

cocos入门8:向量叉乘

在cocos creator中&#xff0c;向量叉乘&#xff08;Cross Product&#xff09;是一个重要的概念&#xff0c;主要用于3D图形学中的方向计算和法线计算。叉乘的结果是一个垂直于两个输入向量的新向量&#xff0c;其长度等于输入向量围成的平行四边形的面积。以下是对向量叉乘的…

前端多人项目开发中,如何保证CSS样式不冲突?

在前端项目开发中&#xff0c;例如突然来了一个大项目&#xff0c;很可能就需要多人一起开发&#xff0c;领导说了&#xff0c;要快&#xff0c;要快&#xff0c;要快&#xff0c;你们给我快。然后下面大伙就一拥而上&#xff0c;干着干着发现&#xff0c;一更新代码&#xff0…

【AI论文与新生技术】Follow-Your-Emoji:精细可控且富有表现力的自由式人像动画技术

我们提出了 Follow-Your-Emoji&#xff0c;这是一种基于扩散的肖像动画框架&#xff0c;它使用目标地标序列对参考肖像进行动画处理。肖像动画的主要挑战是保留参考肖像的身份并将目标表情转移到该肖像&#xff0c;同时保持时间一致性和保真度。为了应对这些挑战&#xff0c;Fo…

JFinal学习07 控制器——接收数据之getBean()和getModel()

JFinal学习07 控制器——接收数据之getBean()和getModel() 视频来源https://www.bilibili.com/video/BV1Bt411H7J9/?spm_id_from333.337.search-card.all.click 文章目录 JFinal学习07 控制器——接收数据之getBean()和getModel()一、接收数据的类型二、getBean()和getModel()…

GDPU JavaWeb Ajax请求

异步请求可以提升用户体验并优化页面性能。 ajax登录 实现ajax异步登录。 注意&#xff0c;ajax用到了jQuery库&#xff0c;先下载好相应的js库&#xff0c;然后复制导入到工程的web目录下&#xff0c;最好与你的前端页面同一层级。然后编写时路径一定要找准&#xff0c;“pag…

WinRAR安装教程

WinRAR安装教程 1. 下载WinRAR 访问官方网站&#xff1a;打开浏览器&#xff0c;访问WinRAR的官方网站&#xff08;如&#xff1a;www.winrar.com.cn&#xff09;。选择版本&#xff1a;根据您的操作系统&#xff08;32位或64位&#xff09;选择合适的WinRAR版本。下载软件&a…

转让北京公司带旅行许可证流程和要求

旅行社是开展旅游服务业务的专项经济主体&#xff0c;旅行社开展相关业务必须持有旅行社业务经营许可证。该资质又分为国内旅行社经营许可证和出境旅行社经营许可证。主要区别在于能否开展出境旅游业务&#xff0c;下面老耿带大家了解&#xff0c;新成立国内旅行社要求及出境旅…

python-windows10普通笔记本跑bert mrpc数据样例0.1.001

python-windows10普通笔记本跑bert mrpc数据样例0.1.000 背景参考章节获取数据下载bert模型下载bert代码windows10的cpu执行结果注意事项TODOLIST背景 看了介绍说可以在gpu或者tpu上去微调,当前没环境,所以先在windows10上跑一跑,看是否能顺利进行,目标就是训练的过程中没…

【Vue2/3】使用Provide/Inject 依赖注入跨组件通信

历史小剧场 什么东西&#xff0c;都有使用年限&#xff0c;比如大米、王朝。 不同的是&#xff0c;大米的年限看得见&#xff0c;王朝的年限看不见。看不见&#xff0c;却依然存在。对于气数&#xff0c;崇祯是不信的&#xff0c;开始不信。等到崇祯十四年&#xff0c;怕什么来…

js--hasOwnProperty()讲解与使用

@TOC 前言 hasOwnProperty(propertyName)方法 是用来检测属性是否为对象的自有属性 object.hasOwnProperty(propertyName) // true/false 讲解 hasOwnProperty() 方法是 Object 的原型方法(也称实例方法),它定义在 Object.prototype 对象之上,所有 Object 的实例对象都会继…

6.7 输入输出流

输入&#xff1a;将数据放到程序&#xff08;内存&#xff09;中 输出&#xff1a;将数据从程序&#xff08;内存&#xff09;放到设备中 C的输入输出分为3种形式&#xff1a; 从键盘屏幕中输入输出&#xff0c;称为标准IO 对于磁盘进行标准输入输出&#xff0c;称为文件IO…

go 读取json文件内容,并且解析内容到interface、 map、 struct

1&#xff0c;解析到interface、 map func ReadAllFileContent(fileName string) {file, err : os.Open(fileName)if err ! nil {log.Fatal(err)}defer file.Close()// buf : make([]byte, 2024)data, err : ioutil.ReadAll(file) //读取的结果是[]byte类型if err ! nil {log.…