动手做一个最小RAG——TinyRAG

 Datawhale干货 

作者:宋志学,Datawhale成员

大家好,我是不要葱姜蒜。

接下来我会带领大家一步一步地实现一个简单的RAG模型,这个模型是基于RAG的一个简化版本,我们称之为Tiny-RAG。Tiny-RAG是一个基于RAG的简化版本,它只包含了RAG的核心功能,即Retrieval和Generation。Tiny-RAG的目的是为了帮助大家更好地理解RAG模型的原理和实现。

OK,让我们开始吧!

1. RAG 介绍

LLM会产生误导性的 “幻觉”,依赖的信息可能过时,处理特定知识时效率不高,缺乏专业领域的深度洞察,同时在推理能力上也有所欠缺。

正是在这样的背景下,检索增强生成技术(Retrieval-Augmented Generation,RAG)应时而生,成为 AI 时代的一大趋势。

RAG 通过在语言模型生成答案之前,先从广泛的文档数据库中检索相关信息,然后利用这些信息来引导生成过程,极大地提升了内容的准确性和相关性。RAG 有效地缓解了幻觉问题,提高了知识更新的速度,并增强了内容生成的可追溯性,使得大型语言模型在实际应用中变得更加实用和可信。

RAG的基本结构有哪些呢?

  • 要有一个向量化模块,用来将文档片段向量化。

  • 要有一个文档加载和切分的模块,用来加载文档并切分成文档片段。

  • 要有一个数据库来存放文档片段和对应的向量表示。

  • 要有一个检索模块,用来根据 Query (问题)检索相关的文档片段。

  • 要有一个大模型模块,用来根据检索出来的文档回答用户的问题。

OK,那上述这些也就是 TinyRAG 仓库的所有模块内容。

a3405c63b7919a09fc13eb45eb6a2d12.png

那接下来,让我们梳理一下 RAG 的流程是什么样的呢?

  • 索引:将文档库分割成较短的 Chunk,并通过编码器构建向量索引。

  • 检索:根据问题和 chunks 的相似度检索相关文档片段。

  • 生成:以检索到的上下文为条件,生成问题的回答。

那也就是下图所示的流程,图片出处  Retrieval-Augmented Generation for Large Language Models: A Survey

2db79056383ec70629633a18df5ece49.png

2. 向量化

首先让我们来动手实现一个向量化的类,这是RAG架构的基础。向量化的类主要是用来将文档片段向量化,将一段文本映射为一个向量。

那首先我们要设置一个 Embedding 基类,这样我们再用其他的模型的时候,只需要继承这个基类,然后在此基础上进行修改即可,方便代码扩展。

class BaseEmbeddings:"""Base class for embeddings"""def __init__(self, path: str, is_api: bool) -> None:self.path = pathself.is_api = is_apidef get_embedding(self, text: str, model: str) -> List[float]:raise NotImplementedError@classmethoddef cosine_similarity(cls, vector1: List[float], vector2: List[float]) -> float:"""calculate cosine similarity between two vectors"""dot_product = np.dot(vector1, vector2)magnitude = np.linalg.norm(vector1) * np.linalg.norm(vector2)if not magnitude:return 0return dot_product / magnitude

观察一下BaseEmbeddings基类都有什么方法,首先有一个get_embedding方法,这个方法是用来获取文本的向量表示的,然后有一个cosine_similarity方法,这个方法是用来计算两个向量之间的余弦相似度的。其次在初始化类的时候设置了,模型的路径或者是否是API模型。比如使用OpenAI的Embedding API的话就需要设置self.is_api=Ture

继承BaseEmbeddings类的话,就只需要编写get_embedding方法即可,cosine_similarity方法会被继承下来,直接用就行。这就是编写基类的好处。

class OpenAIEmbedding(BaseEmbeddings):"""class for OpenAI embeddings"""def __init__(self, path: str = '', is_api: bool = True) -> None:super().__init__(path, is_api)if self.is_api:from openai import OpenAIself.client = OpenAI()self.client.api_key = os.getenv("OPENAI_API_KEY")self.client.base_url = os.getenv("OPENAI_BASE_URL")def get_embedding(self, text: str, model: str = "text-embedding-3-large") -> List[float]:if self.is_api:text = text.replace("\n", " ")return self.client.embeddings.create(input=[text], model=model).data[0].embeddingelse:raise NotImplementedError

3. 文档加载和切分

接下来我们来实现一个文档加载和切分的类,这个类主要是用来加载文档并切分成文档片段。

那我们都需要切分什么文档呢?这个文档可以是一篇文章,一本书,一段对话,一段代码等等。这个文档的内容可以是任何的,只要是文本就行。比如:pdf文件、md文件、txt文件等等。

这里只展示一部分内容了,完整的代码可以在 RAG/utils.py 文件中找到。在这个代码中可以看到,能加载的文件类型有:pdf、md、txt,只需要编写对应的函数即可。

def read_file_content(cls, file_path: str):# 根据文件扩展名选择读取方法if file_path.endswith('.pdf'):return cls.read_pdf(file_path)elif file_path.endswith('.md'):return cls.read_markdown(file_path)elif file_path.endswith('.txt'):return cls.read_text(file_path)else:raise ValueError("Unsupported file type")

那我们把文件内容都读取之后,还需要切分呀!那怎么切分呢,OK,接下来咱们就按 Token 的长度来切分文档。我们可以设置一个最大的 Token 长度,然后根据这个最大的 Token 长度来切分文档。这样切分出来的文档片段就是一个一个的差不多相同长度的文档片段了。

不过在切分的时候要注意,片段与片段之间最好要有一些重叠的内容,这样才能保证检索的时候能够检索到相关的文档片段。还有就是切分文档的时候最好以句子为单位,也就是按 \n 进行粗切分,这样可以基本保证句子内容是完整的。

def get_chunk(cls, text: str, max_token_len: int = 600, cover_content: int = 150):chunk_text = []curr_len = 0curr_chunk = ''lines = text.split('\n')  # 假设以换行符分割文本为行for line in lines:line = line.replace(' ', '')line_len = len(enc.encode(line))if line_len > max_token_len:print('warning line_len = ', line_len)if curr_len + line_len <= max_token_len:curr_chunk += linecurr_chunk += '\n'curr_len += line_lencurr_len += 1else:chunk_text.append(curr_chunk)curr_chunk = curr_chunk[-cover_content:]+linecurr_len = line_len + cover_contentif curr_chunk:chunk_text.append(curr_chunk)return chunk_text

4. 数据库 && 向量检索

上面,我们做好了文档切分,也做好了 Embedding 模型的加载。那接下来就得设计一个向量数据库用来存放文档片段和对应的向量表示了。

还有就是也要设计一个检索模块,用来根据 Query (问题)检索相关的文档片段。OK,我们冲冲冲!

一个数据库对于最小RAG架构来说,需要实现几个功能呢?

  • persist:数据库持久化,本地保存

  • load_vector:从本地加载数据库

  • get_vector:获得文档的向量表示

  • query:根据问题检索相关的文档片段

嗯嗯,以上四个模块就是一个最小的RAG结构数据库需要实现的功能,具体代码可以在 RAG/VectorBase.py 文件中找到。

class VectorStore:def __init__(self, document: List[str] = ['']) -> None:self.document = documentdef get_vector(self, EmbeddingModel: BaseEmbeddings) -> List[List[float]]:# 获得文档的向量表示passdef persist(self, path: str = 'storage'):# 数据库持久化,本地保存passdef load_vector(self, path: str = 'storage'):# 从本地加载数据库passdef query(self, query: str, EmbeddingModel: BaseEmbeddings, k: int = 1) -> List[str]:# 根据问题检索相关的文档片段pass

那让我们来看一下, query 方法具体是怎么实现的呢?

首先先把用户提出的问题向量化,然后去数据库中检索相关的文档片段,最后返回检索到的文档片段。可以看到咱们在向量检索的时候仅使用 Numpy 进行加速,代码非常容易理解和修改。

主要是方便改写和大家理解,并没有使用成熟的数据库,这样可以更好地理解RAG的原理。

def query(self, query: str, EmbeddingModel: BaseEmbeddings, k: int = 1) -> List[str]:query_vector = EmbeddingModel.get_embedding(query)result = np.array([self.get_similarity(query_vector, vector)for vector in self.vectors])return np.array(self.document)[result.argsort()[-k:][::-1]].tolist()

5. 大模型模块

那就来到了最后一个模块了,大模型模块。这个模块主要是用来根据检索出来的文档回答用户的问题。

一样的,我们还是先实现一个基类,这样我们在遇到其他的自己感兴趣的模型就可以快速的扩展了。

class BaseModel:def __init__(self, path: str = '') -> None:self.path = pathdef chat(self, prompt: str, history: List[dict], content: str) -> str:passdef load_model(self):pass

BaseModel 包含了两个方法,chatload_model,如果使用API模型,比如OpenAI的话,那就不需要load_model方法,如果你要本地化运行的话,那还是会选择使用开源模型,那就需要load_model方法啦。

这里咱们以 InternLM2-chat-7B 模型为例

class InternLMChat(BaseModel):def __init__(self, path: str = '') -> None:super().__init__(path)self.load_model()def chat(self, prompt: str, history: List = [], content: str='') -> str:prompt = PROMPT_TEMPLATE['InternLM_PROMPT_TEMPALTE'].format(question=prompt, context=content)response, history = self.model.chat(self.tokenizer, prompt, history)return responsedef load_model(self):import torchfrom transformers import AutoTokenizer, AutoModelForCausalLMself.tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True)self.model = AutoModelForCausalLM.from_pretrained(self.path, torch_dtype=torch.float16, trust_remote_code=True).cuda()

可以用一个字典来保存所有的prompt,这样比较好维护。

PROMPT_TEMPLATE = dict(InternLM_PROMPT_TEMPALTE="""先对上下文进行内容总结,再使用上下文来回答用户的问题。如果你不知道答案,就说你不知道。总是使用中文回答。问题: {question}可参考的上下文:···{context}···如果给定的上下文无法让你做出回答,请回答数据库中没有这个内容,你不知道。有用的回答:"""
)

那这样的话,我们就可以利用InternLM2模型来做RAG啦!

6.  LLM Tiny-RAG Demo

那接下来,我们就来看一下Tiny-RAG的Demo吧!

from RAG.VectorBase import VectorStore
from RAG.utils import ReadFiles
from RAG.LLM import OpenAIChat, InternLMChat
from RAG.Embeddings import JinaEmbedding, ZhipuEmbedding没有保存数据库
docs = ReadFiles('./data').get_content(max_token_len=600, cover_content=150) # 获得data目录下的所有文件内容并分割
vector = VectorStore(docs)
embedding = ZhipuEmbedding() # 创建EmbeddingModel
vector.get_vector(EmbeddingModel=embedding)
vector.persist(path='storage') # 将向量和文档内容保存到storage目录下,下次再用就可以直接加载本地的数据库question = 'git的原理是什么?'content = vector.query(question, model='zhipu', k=1)[0]
chat = InternLMChat(path='model_path')
print(chat.chat(question, [], content))

当然我们也可以从本地加载已经处理好的数据库,毕竟我们在上面的数据库环节已经写过这个功能啦。

from RAG.VectorBase import VectorStore
from RAG.utils import ReadFiles
from RAG.LLM import OpenAIChat, InternLMChat
from RAG.Embeddings import JinaEmbedding, ZhipuEmbedding# 保存数据库之后
vector = VectorStore()vector.load_vector('./storage') # 加载本地的数据库question = 'git的原理是什么?'embedding = ZhipuEmbedding() # 创建EmbeddingModelcontent = vector.query(question, EmbeddingModel=embedding, k=1)[0]
chat = InternLMChat(path='model_path')
print(chat.chat(question, [], content))

7. 总结

经过上面的学习,你是否学会了如何搭建一个最小RAG架构呢?相信你一定学会啦,哈哈哈。

那让我们再来复习一下,一个最小RAG应该包含哪些内容叭?(此处默写!)

  • 向量化模块

  • 文档加载和切分模块

  • 数据库

  • 向量检索

  • 大模型模块

okk,你已经学会了,但别忘了给我的项目点个star哦!

项目地址:https://github.com/KMnO4-zx/TinyRAG

efc2eef19a4ee8c4f20d16ebafa35f63.png
一起“赞”三连

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/736049.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SpringBoot源码

SpringBoot核心前置内容 1.Spring注解编程的发展过程 1.1 Spring 1.x 2004年3月24日&#xff0c;Spring1.0 正式发布&#xff0c;提供了IoC&#xff0c;AOP及XML配置的方式。 在Spring1.x版本中提供的是纯XML配置的方式&#xff0c;也就是在该版本中必须要提供xml的配置文件…

八、词嵌入语言模型(Word Embedding)

词嵌入&#xff08;Word Embedding, WE&#xff09;&#xff0c;任务是把不可计算、非结构化的词转换为可以计算、结构化的向量&#xff0c;从而便于进行数学处理。 一个更官方一点的定义是&#xff1a;词嵌入是是指把一个维数为所有词的数量的高维空间&#xff08;one-hot形式…

小迪安全36WEB 攻防-通用漏洞XSS 跨站MXSSUXSSFlashXSSPDFXSS

#XSS跨站系列内容:1. XSS跨站-原理&分类&手法 XSS跨站-探针&利用&审计XSS跨站另类攻击手法利用 XSS跨站-防御修复&绕过策略 #知识点&#xff1a; 1、XSS 跨站-原理&攻击&分类等 2、XSS 跨站-MXSS&UXSS&FlashXss&PDFXSS 等 1、原…

HCS-华为云Stack-计算节点内部网络结构

HCS-华为云Stack-计算节点内部网络结构 图中表示的仅为计算节点是两网口的模式&#xff0c;如果是四网口模式&#xff0c;系统会再自动创建一个网桥出来 图中未画出存储平面和Internal Base平面&#xff0c;它们和tunnel bearing、External OM-样&#xff0c;都是通过trunk0的…

信息系统项目管理师006:车联网(1信息化发展—1.2现代化基础设施—1.2.3车联网)

文章目录 1.2.3 车联网1.体系框架2.链接方式3.场景应用 记忆要点总结 1.2.3 车联网 车联网是新一代网络通信技术与汽车、电子、道路交通运输等领域深度融合的新兴产业形态。智能网联汽车是搭载先进的车载传感器、控制器、执行器等装置&#xff0c;并融合现代通信与网络技术&…

Linux常用命令之top监测

(/≧▽≦)/~┴┴ 嗨~我叫小奥 ✨✨✨ &#x1f440;&#x1f440;&#x1f440; 个人博客&#xff1a;小奥的博客 &#x1f44d;&#x1f44d;&#x1f44d;&#xff1a;个人CSDN ⭐️⭐️⭐️&#xff1a;传送门 &#x1f379; 本人24应届生一枚&#xff0c;技术和水平有限&am…

for、while、do While、for in、forEach、map、reduce、every、some、filter的使用

for、while、do While、for in、forEach、map、reduce、every、some、filter的使用 for let arr [2, 4, 6, 56, 7, 88];//for for (let i 0; i < arr.length; i) {console.log(i : arr[i]) //0:2 1:4 2:6 3:56 4:7 5:88 }普通的for循环可以用数组的索引来访问或者修改…

代码随想录day32 Java版

62.不同路径 public static int uniquePaths(int m, int n) {int[][] dp new int[m][n];//初始化for (int i 0; i < m; i) {dp[i][0] 1;}for (int i 0; i < n; i) {dp[0][i] 1;}for (int i 1; i < m; i) {for (int j 1; j < n; j) {dp[i][j] dp[i-1][j]dp…

Java享元模式源码剖析及使用场景

享元模式 一、介绍二、基本原理三、企业资源管理系统中使用案例三、Java 中的字符串常量池使用了享元模式四、总结优缺点以及使用经验 一、介绍 享元模式是一种结构型设计模式&#xff0c;旨在最大程度地减少内存使用或计算开销。这种模式通过共享对多个类似对象实例所需的状态…

04 数据结构之队列

循环队列 /* squence_queue.h */ #ifndef _SQUENCE_QUEUE_H_ #define _SQUENCE_QUEUE_H_#include <stdio.h> #include <string.h> #include <stdlib.h>#define QUEUE_SIZE 128 #define DEBUG(msg) \printf("--%s--, %s", __func__, msg)typedef i…

SAP BTP Hyperscaler PostgreSQL都有哪些Performance监控 (一)

前言 SAP BTP云平台中&#xff0c;除了自身的HANA数据库作为首选以外&#xff0c;它还支持PostgreSQL的整套服务&#xff0c;并以PaaS的形式提供给客户。你可以按照实例为单位进行购买申请不同标准规格的PG实例&#xff0c;然后构建自己的业务逻辑。Hyperscaler是这套产品或方…

【Python-Docx库】Word与Python的完美结合

今天给大家分享Python处理Word的第三方库&#xff1a;Python-Docx。 什么是Python-Docx&#xff1f; Python-Docx是用于创建和更新Microsoft Word&#xff08;.docx&#xff09;文件的Python库。 日常需要经常处理Word文档&#xff0c;用Python的免费第三方包&#xff1a;Pyt…

【Linux】Shell及Linux权限

Shell Shell的定义 Shell最简单的定义是&#xff1a;命令行解释器。 Shell的主要任务&#xff1a;1. 将使用者的命令翻译给核心进行处理。2.将核心的处理结果翻译给使用者 为什么要有Shell? 使用者和内核的关系就相当于两个完全陌生的外国人之间的关系&#xff0c;他们要进…

springboot、vue、uniapp项目的部署和运行(超链接可直接跳过去)

springboot、vue项目环境配置 1、首先要安装jdk、maven、mysql、nodejs 软件安装 2、安装idea、HbuilderX、navicat 运行项目 3、运行springboot项目、运行vue项目、运行uniapp项目

Dockerfile编写实践篇

Docker通过一种打包和分发的软件&#xff0c;完成传统容器的封装。这个用来充当容器分发角色的组件被称为镜像。Docker镜像是一个容器中运行程序的所有文件的捆绑快照。当使用Docker分发软件&#xff0c;其实就是分发这些镜像&#xff0c;并在接收的机器上创建容器。镜像在Dock…

Linux:线程互斥与同步

目录 线程互斥 锁的初始化 加锁 解锁 锁的初始化 锁的原理 死锁 线程同步 方案一&#xff1a;条件变量 条件变量初始化 等待 唤醒 条件变量的代码示例 基于阻塞队列的生产消费模型 方案二&#xff1a;POSIX信号量 初始化信号量&#xff1a; 销毁信号量 等待信…

JAVA基础-数据结构一(线性表、链表、栈、队列)

一、数组线性表&#xff08;ADT&#xff09; 线性表&#xff1a;又称动态数组&#xff0c;核心是动态数组&#xff0c;可以自行扩容&#xff0c;支持增删改查四种功能 java中有ArrayList也可以自行扩容&#xff0c;二者功能较为相似&#xff0c;且ArrayList也支持转换为数组。 …

中国大学生计算机设计大赛--智慧物流挑战赛基础

文章目录 一、Ubuntu基础1.1 基本操作1.2 文本编辑 二、ROS基础介绍2.1 概念与特点2.2 基本结构2.3 创建工程2.4 节点和节点管理器2.5 启动文件 三、ROS通信机制3.1 话题3.2 服务3.3 动作3.4 参数服务器 四、ROS可视化工具4.1 rviz4.2 rqt4.3 tf 五、Python实现简单的ROS节点程…

01-分析同步通讯/异步通讯的特点及其应用

同步通讯/异步通讯 微服务间通讯有同步和异步两种方式 同步通讯: 类似打电话场景需要实时响应(时效性强可以立即得到结果方便使用),而且通话期间不能响应其他的电话(不支持多线操作)异步通讯: 类似发邮件场景不需要马上回复并且可以多线操作(适合高并发场景)但是时效性弱响应…

MQ高可用相关设置

文章目录 前言MQ如何保证消息不丢失RabbitMQRocketMQKafkaMQ MQ如何保证顺序消息RabbitMQRocketMQKafka MQ刷盘机制/集群同步RabbitMQRocketMQKafka 广播消息&集群消息RabbitMQRocketMQ MQ集群架构RabbitMQRocketMQKafka 消息重试RabbitMQRockeMqKafka 死信队列RocketMQKaf…