LangGraph 自适应 RAG
- 介绍
- 索引
- LLMs
- web 搜索工具
- graph
- graph state
- graph flow
- build graph
- 执行
介绍
自适应 RAG 是一种 RAG 策略,它将 (1) 查询分析与 (2) 主动/自校正 RAG 结合起来。
在原论文中,作者通过查询分析将问题路由到以下几种检索策略:
- No Retrieval
- Single-shot RAG
- Iterative RAG
让我们使用 LangGraph 在此基础上进行构建。
在我们的实现中,将在以下两种方式之间进行路由:
- 网络搜索:与最近事件相关的问题
- 自校正 RAG:针对与我们的索引相关的问题
索引
from typing import List

import requests
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Vearch
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel
from langchain_text_splitters import RecursiveCharacterTextSplitter

from common.constant import VEARCH_ROUTE_URL, BGE_M3_EMB_URL

# Target vector length. The BGE-M3 dense vector (1024 dims per the original
# note) is zero-padded to this size to fit the OpenAI-sized vector interface.
_ADAPTED_DIM = 1536


class Bgem3Embeddings(BaseModel, Embeddings):
    """LangChain ``Embeddings`` adapter backed by the remote BGE-M3 service."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed each text with :func:`cop_embeddings`, one request per text.

        Args:
            texts: Texts to embed (e.g. document chunks being indexed).

        Returns:
            One vector per input text, in input order. A text that
            ``cop_embeddings`` cannot embed yields ``[]`` in its slot.
        """
        # Previously this printed the texts and returned [], so the vector
        # store received no embeddings for any indexed document.
        return [cop_embeddings(text) for text in texts]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string; empty text yields ``[]``."""
        if not text:
            return []
        return cop_embeddings(text)


def cop_embeddings(input: str, timeout: float = 30.0) -> list:
    """Convert text into a BGE-M3 dense vector adapted to 1536 dimensions.

    Posts ``{"sentences": [input], "type": "dense"}`` to ``BGE_M3_EMB_URL``
    and uses the first vector in the response's ``embeddings`` field.

    Args:
        input: Text to embed. Blank or whitespace-only text is not sent.
        timeout: Seconds to wait for the embedding service. Without it a
            stalled service would block the caller indefinitely.

    Returns:
        A list of exactly 1536 floats: the service vector truncated or
        zero-padded. Returns ``[]`` for blank input, a non-200 response, or a
        response without embeddings.

    Raises:
        requests.exceptions.RequestException: On connection errors or timeout.
    """
    if not input.strip():
        return []
    headers = {"Content-Type": "application/json"}
    params = {"sentences": [input], "type": "dense"}
    response = requests.post(
        BGE_M3_EMB_URL, headers=headers, json=params, timeout=timeout
    )
    if response.status_code != 200:
        print(f"cop_embeddings error: {response.text}")
        return []

    result = response.json()
    if not result or "embeddings" not in result or not result["embeddings"]:
        return []

    original_vector = result["embeddings"][0]
    # Keep at most _ADAPTED_DIM values, then zero-pad any remainder so every
    # vector has the same length expected by the OpenAI-compatible index.
    adaptor_vector = list(original_vector[:_ADAPTED_DIM])
    adaptor_vector.extend([0.0] * (_ADAPTED_DIM - len(adaptor_vector)))
    return adaptor_vector


# Docs to index
urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]

# Load documents: one WebBaseLoader per URL, each returning a list of Documents.
docs = [WebBaseLoader(page_url).load() for page_url in urls]
# Flatten the per-URL lists into a single list of Documents.
docs_list = [page for per_url in docs for page in per_url]

# Split into ~500-token chunks (length measured with tiktoken), no overlap.
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=500,
    chunk_overlap=0,
)
doc_splits = text_splitter.split_documents(docs_list)

# Embed the chunks with BGE-M3 and store them in Vearch.
embeddings_model = Bgem3Embeddings()
vectorstore = Vearch.from_documents(
    documents=doc_splits,
    embedding=embeddings_model,
    path_or_url=VEARCH_ROUTE_URL,
    table_name="lanchain_autogpt",
    db_name="lanchain_autogpt_db",
    # NOTE(review): flag=3 is passed straight to Vearch; its meaning is not
    # visible here — confirm against the Vearch integration.
    flag=3,
)
# Retriever over the indexed chunks (imported by other modules as `retriever`).
retriever = vectorstore.as_retriever()
LLMs
### Router
from typing import Literal

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI

from common.common import PROXY_URL, API_KEY
from index1 import retriever


# Data model
class RouteQuery(BaseModel):
    """将用户查询路由到最相关的数据源。"""

    # Restricted to the two labels route_question() maps to graph nodes.
    # This class is passed to with_structured_output, so its docstring and the
    # field description are part of what the model sees — keep them verbatim.
    datasource: Literal["vectorstore", "web_search"] = Field(
        ...,
        description="给定一个用户问题,选择将其发送到web_search或vectorstore。",
    )


# LLM with function call
# Shared chat model used by every chain in this module; temperature=0 keeps
# routing and grading as deterministic as the API allows.
llm = ChatOpenAI(
    model_name="gpt-4o",
    temperature=0,
    api_key=API_KEY,
    base_url=PROXY_URL,
)
# Router model: replies are parsed into a RouteQuery instance.
structured_llm_router = llm.with_structured_output(RouteQuery)

# Prompt: describe what the vector store covers so the model can route.
system = """你是将用户问题传送到vectorstore或web_search的专家。
vectorstore包含与agents、prompt engineering和adversarial attacks相关的文档。
使用向量库回答有关这些主题的问题。否则,请使用web_search。"""
route_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "{question}"),
    ]
)
# Input: {"question": ...}; output: RouteQuery with a `datasource` label.
question_router = route_prompt | structured_llm_router

# Example 1: choosing a data source
# print(question_router.invoke({"question": "谁将成为NFL选秀的第一人?"}))
# print(question_router.invoke({"question": "Agent memory有哪些类型?"}))
# Data model
class GradeDocuments(BaseModel):
    """Binary score for relevance check on retrieved documents."""

    # Plain str; graph nodes compare it against "yes". The docstring and the
    # description are part of the structured-output schema — keep verbatim.
    binary_score: str = Field(description="文档与问题相关, 'yes' or 'no'")


structured_llm_grader = llm.with_structured_output(GradeDocuments)

# Prompt
system = """你是一个评估检索到的文档和用户问题的相关性的分级员。 \n 如果文档包含与用户问题相关的关键字或语义,则将其评为相关。 \n它不需要是一个严格的测试。目标是过滤掉错误的检索。 \n给出二进制分数 'yes' or 'no' 表示文档是否与问题相关。"""
grade_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "检索到的文档: \n\n {document} \n\n 用户问题: {question}"),
    ]
)
# Input: {"question", "document"}; output: GradeDocuments.
retrieval_grader = grade_prompt | structured_llm_grader

# Sample inputs for the demo below.
# NOTE: this performs a live retrieval every time the module is imported.
question = "agent memory"
# `invoke` is the Runnable entry point; `get_relevant_documents` is deprecated.
docs = retriever.invoke(question)
# Use the second hit as before, but fall back to the first (or "") instead of
# raising IndexError at import time when fewer documents come back.
doc_txt = docs[min(1, len(docs) - 1)].page_content if docs else ""
# Example 2: retrieval grading
# print(retrieval_grader.invoke({"question": question, "document": doc_txt}))

# Used by the RAG-prompt section that follows.
from langchain import hub
# `hub` is imported here as well so this section works on its own; the
# original import was swallowed by a trailing comment.
from langchain import hub
from langchain_core.output_parsers import StrOutputParser

# Prompt: the public "rlm/rag-prompt" template, pulled from the LangChain Hub
# over the network at import time.
prompt = hub.pull("rlm/rag-prompt")


# Post-processing
def format_docs(docs):
    """Join each document's ``page_content`` with a blank line between them.

    NOTE(review): not used by ``rag_chain``; the graph passes documents
    straight into the ``context`` variable.
    """
    return "\n\n".join(doc.page_content for doc in docs)


# Chain: prompt -> LLM -> plain string. Invoked elsewhere with
# {"context": ..., "question": ...}.
rag_chain = prompt | llm | StrOutputParser()

# Run
# generation = rag_chain.invoke({"context": docs, "question": question})
# Example 3: generation with prompt = hub.pull("rlm/rag-prompt")
# print(generation)
# Data model
class GradeHallucinations(BaseModel):
    """Binary score for hallucination present in generation answer."""

    # 'yes' means the generation is grounded in the supplied facts. The
    # docstring and description form the structured-output schema — verbatim.
    binary_score: str = Field(description="答案以事实为基础, 'yes' 或 'no'")


# Rebinds the module-level name; graders built earlier keep their own instance.
structured_llm_grader = llm.with_structured_output(GradeHallucinations)

# Prompt
system = """你是一名评估LLM生成是否以一组检索到的事实为基础/支持的分级员。 \n 给一个二进制分数 'yes' 或 'no'. 'Yes' 意味着答案以一系列事实为基础。"""
hallucination_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        # "事实集合" = "set of facts". The previous label "设置事实" reads as the
        # verb "set up facts" and mislabels the documents for the model.
        ("human", "事实集合: \n\n {documents} \n\n 大模型生成: {generation}"),
    ]
)
# Input: {"documents", "generation"}; output: GradeHallucinations.
hallucination_grader = hallucination_prompt | structured_llm_grader

# hgr = hallucination_grader.invoke({"documents": docs, "generation": generation})
# Example 4: hallucination grading
# print(hgr)
### Answer Grader
# Data model
class GradeAnswer(BaseModel):
    """Binary score to assess answer addresses question."""

    # 'yes' means the generation resolves the user question. Docstring and
    # description are part of the schema given to the LLM — keep verbatim.
    binary_score: str = Field(description="答案是否解决了问题, 'yes' 或 'no'")


# Rebinds the module-level name; graders built earlier keep their own instance.
structured_llm_grader = llm.with_structured_output(GradeAnswer)

# Prompt: judge whether the answer addresses the question.
system = """你是一名评估答案是否能解决问题的评分员 \n 给一个二进制分数 'yes' 或 'no'。 'Yes' 意味着答案解决了问题。"""
answer_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "用户问题: \n\n {question} \n\n 大模型生成: {generation}"),
    ]
)
# Input: {"question", "generation"}; output: GradeAnswer.
answer_grader = answer_prompt | structured_llm_grader

# Example 5: answer grading
# agr = answer_grader.invoke({"question": question, "generation": generation})
# print(agr)
# Prompt
system = """你是一个问题重写器,可以将输入问题转换为更好的版本,该版本针对向量库检索进行了优化。
查看输入并尝试推理潜在的语义意图/含义。"""
re_write_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        (
            "human",
            "下面是原始问题: \n\n {question} \n 提出一个改进的问题。",
        ),
    ]
)
# Question rewriter: input {"question": ...}, output the rewritten question
# as a plain string.
question_rewriter = re_write_prompt | llm | StrOutputParser()

# Example 6: question rewriting
# qrr = question_rewriter.invoke({"question": question})
# print(qrr)
web 搜索工具
import os

from langchain_community.tools.tavily_search import TavilySearchResults

from common.common import TAVILY_API_KEY

# Apply for a key in advance at https://app.tavily.com/home
os.environ["TAVILY_API_KEY"] = TAVILY_API_KEY

# The graph-flow module imports `web_search_tool` from here, so export that
# name. `max_results` is the tool's result-count field; the previous `k=3`
# is not a field of TavilySearchResults, so the intended limit never applied.
web_search_tool = TavilySearchResults(max_results=3)
# Backward-compatible alias for the original name.
tavily_tool = web_search_tool
graph
将上述流程表示为一张图(graph)。
graph state
from typing import List

from langchain_core.documents import Document
from typing_extensions import TypedDict


class GraphState(TypedDict):
    """Represents the state of our graph.

    Attributes:
        question: The user question (transform_query may replace it with a rewrite).
        generation: LLM generation.
        documents: Documents from the vector store, or built from web results.
    """

    question: str
    generation: str
    # Nodes store Document objects here and read `.page_content` from them;
    # the previous `List[str]` annotation did not match what is stored.
    documents: List[Document]
graph flow
from langchain.schema import Document
from index1 import retriever
from llm2 import rag_chain, retrieval_grader, question_rewriter, question_router, hallucination_grader, answer_grader
from webstool3 import web_search_tool# 检索向量库中的doc
def retrieve(state):
    """Fetch documents from the vector store for the current question.

    Args:
        state (dict): The current graph state; must contain ``"question"``.

    Returns:
        dict: State update with ``"documents"`` (the retriever's results) and
        the unchanged ``"question"``.
    """
    print("---RETRIEVE---")
    query = state["question"]
    return {"documents": retriever.invoke(query), "question": query}


# LLM generation
def generate(state):
    """Produce an answer from the current question and documents.

    Args:
        state (dict): The current graph state; reads ``"question"`` and
            ``"documents"``.

    Returns:
        dict: State update carrying the same question and documents plus a
        new ``"generation"`` string from ``rag_chain``.
    """
    print("---GENERATE---")
    question, documents = state["question"], state["documents"]
    answer = rag_chain.invoke({"context": documents, "question": question})
    return {"documents": documents, "question": question, "generation": answer}


# Document grading
def grade_documents(state):
    """Keep only the retrieved documents the grader judges relevant.

    Args:
        state (dict): The current graph state; reads ``"question"`` and
            ``"documents"`` (objects exposing ``page_content``).

    Returns:
        dict: State update whose ``"documents"`` holds only the relevant
        documents (possibly empty), plus the unchanged ``"question"``.
    """
    print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
    question = state["question"]
    documents = state["documents"]

    # Score each doc
    filtered_docs = []
    for d in documents:
        score = retrieval_grader.invoke(
            {"question": question, "document": d.page_content}
        )
        # binary_score is a free-form str, so accept "Yes" / " yes" as well.
        # An exact "yes" match would silently discard relevant documents.
        if score.binary_score.strip().lower() == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
    return {"documents": filtered_docs, "question": question}


# Rewrite the input question
def transform_query(state):
    """Rewrite the question into a form better suited to vector retrieval.

    Args:
        state (dict): The current graph state; reads ``"question"`` and
            ``"documents"``.

    Returns:
        dict: State update with the rewritten ``"question"`` and the
        documents passed through unchanged.
    """
    print("---TRANSFORM QUERY---")
    original = state["question"]
    carried_docs = state["documents"]
    rewritten = question_rewriter.invoke({"question": original})
    return {"documents": carried_docs, "question": rewritten}


# Web search
def web_search(state):
    """Search the web for the question and wrap the results as one Document.

    Args:
        state (dict): The current graph state; ``"question"`` is the query.

    Returns:
        dict: State update whose ``"documents"`` is a one-element list with a
        Document of the newline-joined result contents. Previously stored
        documents are replaced, not appended to.
    """
    print("---WEB SEARCH---")
    question = state["question"]

    # Web search
    docs = web_search_tool.invoke({"query": question})
    if isinstance(docs, str):
        # NOTE(review): the tool may return a plain string (e.g. an error
        # message) instead of a list of result dicts; keep it as the content
        # rather than iterating over its characters.
        web_results = docs
    else:
        web_results = "\n".join(d["content"] for d in docs)
    # A list keeps "documents" the same shape as after retrieve/grade_documents.
    return {"documents": [Document(page_content=web_results)], "question": question}


### Edges ###


def route_question(state):
    """Route the question to web search or to the vector store (RAG).

    Args:
        state (dict): The current graph state; reads ``"question"``.

    Returns:
        str: ``"web_search"`` or ``"vectorstore"``, the next route to take.
    """
    print("---ROUTE QUESTION---")
    question = state["question"]
    source = question_router.invoke({"question": question})
    if source.datasource == "web_search":
        print("---ROUTE QUESTION TO WEB SEARCH---")
        return "web_search"
    elif source.datasource == "vectorstore":
        print("---ROUTE QUESTION TO RAG---")
        return "vectorstore"


# Generate an answer, or re-generate the question
def decide_to_generate(state):
    """Decide whether to generate an answer or rewrite the question.

    Args:
        state (dict): The current graph state; ``"documents"`` holds the
            documents that survived relevance grading.

    Returns:
        str: ``"transform_query"`` when no documents remain, else ``"generate"``.
    """
    print("---ASSESS GRADED DOCUMENTS---")
    # (A dead `state["question"]` expression statement was removed here.)
    filtered_documents = state["documents"]

    if not filtered_documents:
        # Every document was filtered out as irrelevant, so rewrite the query.
        print("---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---")
        return "transform_query"
    # We have relevant documents, so generate an answer.
    print("---DECISION: GENERATE---")
    return "generate"


# Determine whether the generation is grounded in the documents and answers the question.
def grade_generation_v_documents_and_question(state):
    """Check that the generation is grounded in the documents and answers the question.

    Args:
        state (dict): The current graph state; reads ``"question"``,
            ``"documents"`` and ``"generation"``.

    Returns:
        str: ``"not supported"`` if the generation is not grounded in the
        documents, ``"not useful"`` if grounded but not answering the question,
        ``"useful"`` otherwise.
    """
    print("---CHECK HALLUCINATIONS---")
    question = state["question"]
    documents = state["documents"]
    generation = state["generation"]

    score = hallucination_grader.invoke(
        {"documents": documents, "generation": generation}
    )
    # The grader prompts themselves write 'Yes', so the model may reply with a
    # capital letter or padding. Normalize before comparing so a grounded
    # answer is not sent back for endless regeneration.
    if score.binary_score.strip().lower() != "yes":
        print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
        return "not supported"

    print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
    # Check question-answering
    print("---GRADE GENERATION vs QUESTION---")
    score = answer_grader.invoke({"question": question, "generation": generation})
    if score.binary_score.strip().lower() == "yes":
        print("---DECISION: GENERATION ADDRESSES QUESTION---")
        return "useful"
    print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
    return "not useful"
build graph
from langgraph.graph import END, StateGraph
from pprint import pprint

from common.common import show_img
from gflow5 import (
    web_search,
    retrieve,
    grade_documents,
    generate,
    transform_query,
    route_question,
    decide_to_generate,
    grade_generation_v_documents_and_question,
)
from gstate4 import GraphState

# Define the workflow over the shared GraphState schema.
workflow = StateGraph(GraphState)

# Nodes: one per step of the adaptive-RAG flow.
workflow.add_node("web_search", web_search)
workflow.add_node("retrieve", retrieve)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("generate", generate)
workflow.add_node("transform_query", transform_query)

# Entry point: the router chooses web search or the vector store.
workflow.set_conditional_entry_point(
    route_question,
    {
        "web_search": "web_search",
        "vectorstore": "retrieve",
    },
)
# Web results go straight to generation, without relevance grading.
workflow.add_edge("web_search", "generate")
# Vector-store results are graded before generation.
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "transform_query": "transform_query",
        "generate": "generate",
    },
)
# A rewritten question is retried against the vector store.
workflow.add_edge("transform_query", "retrieve")
# Self-check after generation. NOTE(review): "not supported" loops back to
# "generate" with no retry cap here; presumably only LangGraph's recursion
# limit bounds it.
workflow.add_conditional_edges(
    "generate",
    grade_generation_v_documents_and_question,
    {
        "not supported": "generate",
        "useful": END,
        "not useful": "transform_query",
    },
)

# Compile
app = workflow.compile()
执行
# Run
inputs = {"question": "熊队的哪位球员有望在2024年的NFL选秀中获得第一名?"
}
last_update = None
for output in app.stream(inputs):
    for key, value in output.items():
        # Node
        pprint(f"Node '{key}':")
        # Optional: print the full state update produced by each node
        # pprint(value, indent=2, width=80, depth=None)
        last_update = value
    pprint("\n---\n")

# Final generation. The only edge to END leaves "generate", so the last
# update normally carries the "generation" key; previously it was never shown.
if isinstance(last_update, dict):
    pprint(last_update.get("generation"))