本教程将介绍如何使用LangGraph库构建和测试状态图。我们将通过一系列示例代码,逐步解释程序的运行逻辑。
1. 基本状态图构建
首先,我们定义一个状态图的基本结构和节点。
定义状态类
from langgraph.graph import StateGraph, START, END
from typing_extensions import TypedDict
from pydantic import BaseModel# 定义整个图的状态(这是在节点间共享的公共状态)
class OverallState(BaseModel):a: str
定义节点函数
def node(state: OverallState):return {"a": "goodbye"}
构建状态图
# 构建状态图
builder = StateGraph(OverallState)
builder.add_node(node) # 添加节点
builder.add_edge(START, "node") # 从起始节点开始
builder.add_edge("node", END) # 在节点后结束
graph = builder.compile()
测试状态图
# 使用有效输入测试图
graph.invoke({"a": "hello"})
输出结果:
{'a': 'goodbye'}
可视化状态图
from IPython.display import Image, displaytry:display(Image(graph.get_graph(xray=True).draw_mermaid_png()))
except Exception:# 这需要一些额外的依赖项,是可选的print("x")pass
测试无效输入
try:graph.invoke({"a": 123}) # 应该是字符串
except Exception as e:print("An exception was raised because `a` is an integer rather than a string.")print(e)
输出结果:
An exception was raised because `a` is an integer rather than a string.
1 validation error for OverallState
aInput should be a valid string [type=string_type, input_value=123, input_type=int]For further information visit https://errors.pydantic.dev/2.8/v/string_type
2. 处理无效节点
接下来,我们添加一个返回无效状态的节点,并观察其影响。
定义节点函数
def bad_node(state: OverallState):return {"a": 123 # 无效}def ok_node(state: OverallState):return {"a": "goodbye"}
构建状态图
# 构建状态图
builder = StateGraph(OverallState)
builder.add_node(bad_node)
builder.add_node(ok_node)
builder.add_edge(START, "bad_node")
builder.add_edge("bad_node", "ok_node")
builder.add_edge("ok_node", END)
graph = builder.compile()
可视化状态图
from IPython.display import Image, displaytry:display(Image(graph.get_graph(xray=True).draw_mermaid_png()))
except Exception:# 这需要一些额外的依赖项,是可选的pass
测试状态图
# 使用有效输入测试图
try:graph.invoke({"a": "hello"})
except Exception as e:print("An exception was raised because bad_node sets `a` to an integer.")print(e)
输出结果:
An exception was raised because bad_node sets `a` to an integer.
1 validation error for OverallState
aInput should be a valid string [type=string_type, input_value=123, input_type=int]For further information visit https://errors.pydantic.dev/2.8/v/string_type
3. 输入和输出状态的分离
接下来,我们定义输入和输出状态的分离。
定义状态类
from typing_extensions import TypedDict# 定义输入状态的架构
class InputState(TypedDict):question: str# 定义输出状态的架构
class OutputState(TypedDict):answer: str# 定义整体架构,结合输入和输出
class OverallState(InputState, OutputState):pass
定义节点函数
def answer_node(state: InputState):# 示例答案和一个额外的键return {"answer": "bye", "question": state["question"]}
构建状态图
# 使用指定的输入和输出架构构建图
builder = StateGraph(OverallState, input=InputState, output=OutputState)
builder.add_node(answer_node) # 添加答案节点
builder.add_edge(START, "answer_node") # 定义起始边
builder.add_edge("answer_node", END) # 定义结束边
graph1 = builder.compile() # 编译图
可视化状态图
from IPython.display import Image, displaytry:display(Image(graph1.get_graph(xray=True).draw_mermaid_png()))
except Exception:# 这需要一些额外的依赖项,是可选的pass
测试状态图
print(graph1.invoke({"question": "hi"}))
输出结果:
{'answer': 'bye'}
4. 处理私有数据
最后,我们展示如何在节点间传递私有数据。
定义状态类
# 整个图的状态(这是在节点间共享的公共状态)
class OverallState(TypedDict):a: str# node_1的输出包含私有数据,不属于整体状态
class Node1Output(TypedDict):private_data: str
定义节点函数
def node_1(state: OverallState) -> Node1Output:output = {"private_data": "set by node_1"}print(f"Entered node `node_1`:\n\tInput: {state}.\n\tReturned: {output}")return output# node_2的输入仅请求node_1后可用的私有数据
class Node2Input(TypedDict):private_data: strdef node_2(state: Node2Input) -> OverallState:output = {"a": "set by node_2"}print(f"Entered node `node_2`:\n\tInput: {state}.\n\tReturned: {output}")return output# node_3仅能访问整体状态(无法访问node_1的私有数据)
def node_3(state: OverallState) -> OverallState:output = {"a": "set by node_3"}print(f"Entered node `node_3`:\n\tInput: {state}.\n\tReturned: {output}")return output
构建状态图
# 构建状态图
builder = StateGraph(OverallState)
builder.add_node(node_1) # node_1是第一个节点
builder.add_node(node_2) # node_2是第二个节点,接受node_1的私有数据
builder.add_node(node_3) # node_3是第三个节点,无法看到私有数据
builder.add_edge(START, "node_1") # 从node_1开始
builder.add_edge("node_1", "node_2") # 从node_1传递到node_2
builder.add_edge("node_2", "node_3") # 从node_2传递到node_3(仅共享整体状态)
builder.add_edge("node_3", END) # 在node_3后结束
graph = builder.compile()
可视化状态图
from IPython.display import Image, displaytry:display(Image(graph.get_graph(xray=True).draw_mermaid_png()))
except Exception:# 这需要一些额外的依赖项,是可选的pass
测试状态图
# 使用初始状态调用图
response = graph.invoke({"a": "set at start",}
)print()
print(f"Output of graph invocation: {response}")
输出结果:
Entered node `node_1`:Input: {'a': 'set at start'}.Returned: {'private_data': 'set by node_1'}
Entered node `node_2`:Input: {'private_data': 'set by node_1'}.Returned: {'a': 'set by node_2'}
Entered node `node_3`:Input: {'a': 'set by node_2'}.Returned: {'a': 'set by node_3'}Output of graph invocation: {'a': 'set by node_3'}
通过以上步骤,我们展示了如何使用LangGraphState管理,包括处理私有数据和状态验证。希望这个教程对你有所帮助!
参考链接:https://langchain-ai.github.io/langgraph/how-tos/state-model/
https://langchain-ai.github.io/langgraph/how-tos/input_output_schema/
https://langchain-ai.github.io/langgraph/how-tos/pass_private_state/
如果有任何问题,欢迎在评论区提问。