LangChain之关于RetrievalQA input_variables 的定义与使用

最近在使用LangChain来做一个LLMs和KBs结合的小Demo玩玩,也就是RAG(Retrieval Augmented Generation)。
这部分的内容其实在LangChain的官网已经给出了流程图。在这里插入图片描述
我这里就直接偷懒了,准备对Webui的项目进行复刻练习,那么接下来就是照着葫芦画瓢就行。
那么我卡在了Retrieve这一步。先放有疑惑地方的代码:

if web_content:prompt_template = f"""基于以下已知信息,简洁和专业的来回答用户的问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。已知网络检索内容:{web_content}""" + """已知内容:{context}问题:{question}"""else:prompt_template = """基于以下已知信息,请简洁并专业地回答用户的问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息"。不允许在答案中添加编造成分。另外,答案请使用中文。已知内容:{context}问题:{question}"""prompt = PromptTemplate(template=prompt_template,input_variables=["context", "question"])......knowledge_chain = RetrievalQA.from_llm(llm=self.llm,retriever=vector_store.as_retriever(search_kwargs={"k": self.top_k}),prompt=prompt)knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(input_variables=["page_content"], template="{page_content}")knowledge_chain.return_source_documents = Trueresult = knowledge_chain({"query": query})return result

我对prompt_templateknowledge_chain.combine_documents_chain.document_prompt result = knowledge_chain({"query": query})这三个地方的input_key不明白为啥一定要这样设置。虽然我也看了LangChain的API文档。但是我并未得到详细的答案,那么只能一行行看源码是到底怎么设置的了。

注意:由于LangChain是一层层封装的,那么result = knowledge_chain({"query": query})可以认为是最外层,那么我们先看最外层。

result = knowledge_chain({“query”: query})

其实这部分是直接与用户的输入问题做对接的,我们只需要定位到RetrievalQA这个类就可以了,下面是RetrievalQA这个类的实现:

class RetrievalQA(BaseRetrievalQA):"""Chain for question-answering against an index.Example:.. code-block:: pythonfrom langchain.llms import OpenAIfrom langchain.chains import RetrievalQAfrom langchain.vectorstores import FAISSfrom langchain.schema.vectorstore import VectorStoreRetrieverretriever = VectorStoreRetriever(vectorstore=FAISS(...))retrievalQA = RetrievalQA.from_llm(llm=OpenAI(), retriever=retriever)"""retriever: BaseRetriever = Field(exclude=True)def _get_docs(self,question: str,*,run_manager: CallbackManagerForChainRun,) -> List[Document]:"""Get docs."""return self.retriever.get_relevant_documents(question, callbacks=run_manager.get_child())async def _aget_docs(self,question: str,*,run_manager: AsyncCallbackManagerForChainRun,) -> List[Document]:"""Get docs."""return await self.retriever.aget_relevant_documents(question, callbacks=run_manager.get_child())@propertydef _chain_type(self) -> str:"""Return the chain type."""return "retrieval_qa"

可以看到其继承了BaseRetrievalQA这个父类,同时对_get_docs这个抽象方法进行了实现。

这里要扩展的说一下,_get_docs这个方法就是利用向量相似性,在vector Base中选择与embedding之后的query最近似的Document结果。然后作为RetrievalQA的上下文。具体只需要看BaseRetrievalQA这个方法的_call和就可以了。
接下来我们只需要看BaseRetrievalQA这个类的属性就可以了。

class BaseRetrievalQA(Chain):"""Base class for question-answering chains."""combine_documents_chain: BaseCombineDocumentsChain"""Chain to use to combine the documents."""input_key: str = "query"  #: :meta private:output_key: str = "result"  #: :meta private:return_source_documents: bool = False"""Return the source documents or not."""……def _call(self,inputs: Dict[str, Any],run_manager: Optional[CallbackManagerForChainRun] = None,) -> Dict[str, Any]:"""Run get_relevant_text and llm on input query.If chain has 'return_source_documents' as 'True', returnsthe retrieved documents as well under the key 'source_documents'.Example:.. code-block:: pythonres = indexqa({'query': 'This is my query'})answer, docs = res['result'], res['source_documents']"""_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()question = inputs[self.input_key]accepts_run_manager = ("run_manager" in inspect.signature(self._get_docs).parameters)if accepts_run_manager:docs = self._get_docs(question, run_manager=_run_manager)else:docs = self._get_docs(question)  # type: ignore[call-arg]answer = self.combine_documents_chain.run(input_documents=docs, question=question, callbacks=_run_manager.get_child())if self.return_source_documents:return {self.output_key: answer, "source_documents": docs}else:return {self.output_key: answer}

可以看到其有input_key这个属性,默认值是"query"。到这里我们就可以看到result = knowledge_chain({"query": query})是调用的BaseRetrievalQA_call,这里的question = inputs[self.input_key]就是其体现。

knowledge_chain.combine_documents_chain.document_prompt

这个地方一开始我很奇怪,为什么会重新定义呢?
我们可以先定位到,combine_documents_chain这个参数的位置,其是StuffDocumentsChain的方法。

@classmethod
def from_llm(cls,llm: BaseLanguageModel,prompt: Optional[PromptTemplate] = None,callbacks: Callbacks = None,**kwargs: Any,
) -> BaseRetrievalQA:"""Initialize from LLM."""_prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)llm_chain = LLMChain(llm=llm, prompt=_prompt, callbacks=callbacks)document_prompt = PromptTemplate(input_variables=["page_content"], template="Context:\n{page_content}")combine_documents_chain = StuffDocumentsChain(llm_chain=llm_chain,document_variable_name="context",document_prompt=document_prompt,callbacks=callbacks,)return cls(combine_documents_chain=combine_documents_chain,callbacks=callbacks,**kwargs,)

可以看到原始的document_prompt中PromptTemplate的template是“Context:\n{page_content}”。因为这个项目是针对中文的,所以需要将英文的Context去掉。

扩展

  1. 这里PromptTemplate(input_variables=[“page_content”], template=“Context:\n{page_content}”)的input_variablestemplate为什么要这样定义呢?其实是根据Document这个数据对象来定义使用的,我们可以看到其数据格式为:Document(page_content=‘……’, metadata={‘source’: ‘……’, ‘row’: ……})
    那么input_variables的输入就是Document的page_content。
  2. StuffDocumentsChain中有一个参数是document_variable_name。那么这个类是这样定义的This chain takes a list of documents and first combines them into a single string. It does this by formatting each document into a string with the document_prompt and then joining them together with document_separator. It then adds that new string to the inputs with the variable name set by document_variable_name. Those inputs are then passed to the llm_chain. 这个document_variable_name简单来说就是在document_prompt中的占位符,用于在Chain中的使用。
    因此我们上文prompt_template变量中的“已知内容: {context}”,用的就是context这个变量。因此在prompt_template中换成其他的占位符都不能正常使用这个Chain。

prompt_template

在上面的拓展中其实已经对prompt_template做了部分的讲解,那么这个字符串还剩下“问题:{question}”这个地方没有说通
还是回归源码:

return cls(combine_documents_chain=combine_documents_chain,callbacks=callbacks,**kwargs,)

我们可以在from_llm函数中看到其返回值是到了_call,那么剩下的我们来看这个函数:


......
uestion = inputs[self.input_key]
accepts_run_manager = ("run_manager" in inspect.signature(self._get_docs).parameters
)
if accepts_run_manager:docs = self._get_docs(question, run_manager=_run_manager)
else:docs = self._get_docs(question)  # type: ignore[call-arg]
answer = self.combine_documents_chain.run(input_documents=docs, question=question, callbacks=_run_manager.get_child()
)
......

这里是在run这个函数中传入了一个字典值,这个字典值有三个参数。

注意:

  1. 这三个参数就是kwargs,也就是_validate_inputs的参数input;
  2. 此时已经是在Chain这个基本类了)
def run(self,*args: Any,callbacks: Callbacks = None,tags: Optional[List[str]] = None,metadata: Optional[Dict[str, Any]] = None,**kwargs: Any,) -> Any:"""Convenience method for executing chain.The main difference between this method and `Chain.__call__` is that thismethod expects inputs to be passed directly in as positional arguments orkeyword arguments, whereas `Chain.__call__` expects a single input dictionarywith all the inputs"""

接下来调用__call__:

def __call__(self,inputs: Union[Dict[str, Any], Any],return_only_outputs: bool = False,callbacks: Callbacks = None,*,tags: Optional[List[str]] = None,metadata: Optional[Dict[str, Any]] = None,run_name: Optional[str] = None,include_run_info: bool = False,) -> Dict[str, Any]:"""Execute the chain.Args:inputs: Dictionary of inputs, or single input if chain expectsonly one param. Should contain all inputs specified in`Chain.input_keys` except for inputs that will be set by the chain'smemory.return_only_outputs: Whether to return only outputs in theresponse. If True, only new keys generated by this chain will bereturned. If False, both input keys and new keys generated by thischain will be returned. Defaults to False.callbacks: Callbacks to use for this chain run. These will be called inaddition to callbacks passed to the chain during construction, but onlythese runtime callbacks will propagate to calls to other objects.tags: List of string tags to pass to all callbacks. These will be passed inaddition to tags passed to the chain during construction, but onlythese runtime tags will propagate to calls to other objects.metadata: Optional metadata associated with the chain. Defaults to Noneinclude_run_info: Whether to include run info in the response. Defaultsto False.Returns:A dict of named outputs. Should contain all outputs specified in`Chain.output_keys`."""inputs = self.prep_inputs(inputs)......

这里的prep_inputs会调用_validate_inputs函数

def _validate_inputs(self,inputs: Dict[str, Any]) -> None:"""Check that all inputs are present."""missing_keys = set(self.input_keys).difference(inputs)if missing_keys:raise ValueError(f"Missing some input keys: {missing_keys}")

这里的input_keys通过调试,看到的就是有多个输入,分别是"input_documents"和"question"
这里的"input_documents"是来自于BaseCombineDocumentsChain

class BaseCombineDocumentsChain(Chain, ABC):"""Base interface for chains combining documents.Subclasses of this chain deal with combining documents in a variety ofways. This base class exists to add some uniformity in the interface these typesof chains should expose. Namely, they expect an input key related to the documentsto use (default `input_documents`), and then also expose a method to calculatethe length of a prompt from documents (useful for outside callers to use todetermine whether it's safe to pass a list of documents into this chain or whetherthat will longer than the context length)."""input_key: str = "input_documents"  #: :meta private:output_key: str = "output_text"  #: :meta private:

那为什么有两个呢,“question”来自于哪里?
StuffDocumentsChain继承BaseCombineDocumentsChain,其input_key是这样定义的:

  @propertydef input_keys(self) -> List[str]:extra_keys = [k for k in self.llm_chain.input_keys if k != self.document_variable_name]return super().input_keys + extra_keys

原来是重写了input_keys函数,其是对llm_chain的input_keys进行遍历。

那么llm_chain的input_keys是用其prompt的input_variables。(这里的input_variables是PromptTemplate中的[“context”, “question”])

	@propertydef input_keys(self) -> List[str]:"""Will be whatever keys the prompt expects.:meta private:"""return self.prompt.input_variables

至此,我们StuffDocumentsChain的input_keys有两个变量了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/135653.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

多语言多商户多货币跨境电商商城源码(一键铺货\订单返现商城源码搭建开发)

随着全球化的加速和互联网的蓬勃发展,跨境电商已成为越来越多企业的必经之路。如何在竞争激烈的市场中脱颖而出,实现多语言、多商户的跨境商城运营,成为了很多电商企业亟待解决的问题。今天,我们将为您揭示一款多语言多商户跨境商…

2023年11月数据库流行度最新排名

点击查看最新数据库流行度最新排名(每月更新) 2023年11月数据库流行度最新排名 TOP DB顶级数据库索引是通过分析在谷歌上搜索数据库名称的频率来创建的 一个数据库被搜索的次数越多,这个数据库就被认为越受欢迎。这是一个领先指标。原始数…

开源DB-GPT实现连接数据库详细步骤

官方文档:欢迎来到DB-GPT中文文档 — DB-GPT 👏👏 0.4.1 第一步:安装Minicoda https://docs.conda.io/en/latest/miniconda.html 第二步:安装Git Git - Downloading Package 第三步:安装embedding 模型到…

seata事务回滚引起的skywalking数据库存储空间剧增的问题排查

基本信息 产品名称:ATS3.0 问题分类:编码问题 环境类型:环境无关 问题现象 11月1日上午华润DBA收到数据库磁盘空间告警,检查后发现skywalking连接的mysql数据库占用空间从之前一直是比较稳定的,但是10月31日…

antd Form 校验自定义复杂判断-validator

antd Form 校验 加入自定义复杂逻辑 <Form.Itemlabel"编码"name"code"rules{[{required: true,validator: (_rule, value) > {if (value ) {return Promise.reject(请输入编码);}return IsExist(value).then((res) > {if (res?.statusCode 20…

强化学习中广义策略迭代

一、广义策略迭代 策略迭代包括两个同时进行的交互过程&#xff0c;一个使价值函数与当前策略保持一致&#xff08;策略评估&#xff09;&#xff0c;另一个使策略在当前价值函数下变得贪婪&#xff08;策略改进&#xff09;。在策略迭代中&#xff0c;这两个过程交替进行&…

【Qt之QAssociativeIterable】使用

介绍 QAssociativeIterable类是QVariant中一个关联式容器的可迭代接口。这个类允许多种访问在QVariant中保存的关联式容器元素的方法。如果一个QVariant可以转换为QVariantHash或QVariantMap&#xff0c;那么QAssociativeIterable的实例可以从中提取出来。 QHash<int, QSt…

软件版本控制系统VCS工具——cvs vss svn git

版本控制 版本控制系统&#xff08;Version Control System&#xff0c;VCS&#xff09;是用于跟踪和管理源代码和文档的工具。可追踪和管理修改历史&#xff0c;包括修改的内容、时间、作者等信息。有助于团队协作、追踪变更、恢复历史版本等。VCS的主要目的是帮助团队协作开…

电脑如何截屏?一起来揭晓答案!

在数字时代&#xff0c;截屏已经成为我们日常生活和工作中的必备技能。无论是为了捕捉有趣的网络瞬间&#xff0c;保存重要信息&#xff0c;还是为了协作和教育&#xff0c;电脑截屏都是一个强大而方便的工具。本文将介绍三种电脑如何截屏的方法&#xff0c;以满足各种需求&…

研发管理工具选型要考虑哪些内容?

研发管理工具选型要考虑哪些内容&#xff1f; 研发管理工具选型需要考虑六个因素&#xff0c;分别是&#xff1a;1、功能性&#xff1b;2、非功能性&#xff1b;3、易用性&#xff1b;4、产品价格&#xff1b;5、服务&#xff1b;6、厂商。其中功能性在研发管理工具选型过程中是…

精美好看又便于分享的电子相册制作,谁看了不心动呢?

很多人都喜欢用相机记录生活中的点点滴滴&#xff0c;可是当要分享到朋友圈的时候&#xff0c;觉得这张也好看&#xff0c;那张也不错&#xff0c;如果全部分享出去就霸屏了&#xff0c;然后就不知道怎么选择了。其实&#xff0c;我们可以把这些照片做成电子相册&#xff0c;然…

docker可视化

什么是portainer&#xff1f; portainer就是docker图形化界面的管理工具&#xff0c;提供一个后台面板供我们操作 目前先用portainer(先用这个)&#xff0c;以后还会用到Rancher(CI/CD在用) 1.下载portainer 9000是内网端口&#xff0c;8088是外网访问端口 docker run…

对话凯文·凯利:AI 会取代人的 90% 技能,并放大剩余的 10%

采访 | 邹欣&#xff0c;CSDN 副总裁 作者 | 王启隆 责编 | 唐小引 出品 | 《新程序员》编辑部 5000 天后&#xff0c;你都会做些什么&#xff1f; 是和 AI 助手一起编程&#xff0c;还是让生活完全由 AI 掌控&#xff0c;自己坐享其成&#xff1f;如果到时候还要上班&a…

采购劳保鞋如何选择合适的尺码

今天在某问答平台看到了这么一个话题&#xff0c;平常皮鞋穿40码&#xff0c;运动鞋穿41码&#xff0c;劳保鞋如何选择合适的尺码&#xff1f;小编发现很多朋友在选购劳保鞋的时候&#xff0c;对劳保鞋的尺码了解不是很清楚都会在这一块纠结。选择鞋子脚感舒适很重要&#xff0…

射频功率放大器应用中GaN HEMT的表面电势模型

标题&#xff1a;A surface-potential based model for GaN HEMTs in RF power amplifier applications 来源&#xff1a;IEEE IEDM 2010 本文中的任何第一人称都为论文的直译 摘要&#xff1a;我们提出了第一个基于表面电位的射频GaN HEMTs紧凑模型&#xff0c;并将我们的工…

如何在Linux上搭建本地Docker Registry并实现远程连接

Linux 本地 Docker Registry本地镜像仓库远程连接 文章目录 Linux 本地 Docker Registry本地镜像仓库远程连接1. 部署Docker Registry2. 本地测试推送镜像3. Linux 安装cpolar4. 配置Docker Registry公网访问地址5. 公网远程推送Docker Registry6. 固定Docker Registry公网地址…

C# OpenCvSharp 通过特征点匹配图片

SIFT匹配 SURF匹配 项目 代码 using OpenCvSharp; using OpenCvSharp.Extensions; using System; using System.Collections.Generic; using System.Drawing; using System.Linq; using System.Text.RegularExpressions; using System.Windows.Forms; using static System.Net…

企业防范数据安全的重要性与策略

随着信息技术的快速发展&#xff0c;企业的数据安全问题日益凸显。数据安全不仅关乎企业的商业机密&#xff0c;还涉及到客户的隐私和信任。因此&#xff0c;企业必须采取有效的防范措施&#xff0c;确保数据安全。本文将探讨企业防范数据安全的重要性&#xff0c;并介绍一些实…

【ARM Coresight OpenOCD 系列 1 -- OpenOCD 介绍】

请阅读【ARM Coresight SoC-400/SoC-600 专栏导读】 文章目录 1.1 OpenOCD 介绍1.1.1 OpenOCD 支持的JTAG 适配器1.1.2 OpenOCD 支持的调试设备1.1.3 OpenOCD 支持的 Flash 驱动 1.2 OpenOCD 安装与使用1.2.1 OpenOCD 代码获取及安装1.2.2 OpenOCD 使用1.2.3 OpenOCD 启用 GDB…

Vim基本使用操作

前言&#xff1a;作者也是初学Linux&#xff0c;可能总结的还不是很到位 Linux修炼功法&#xff1a;初阶功法 ♈️今日夜电波&#xff1a;美人鱼—林俊杰 0:21━━━━━━️&#x1f49f;──────── 4:14 …