Contents
- 1. Original code
- 2. Testing the code
- 3. How the code works
- 4. Core functionality of the UserProxyAgent class
- 5. Using the UserProxyAgent class
- 6. Runtime flow
- 7. Summary
1. Original code
import asyncio
from inspect import iscoroutinefunction
from typing import Awaitable, Callable, List, Optional, Sequence, Union, cast

from autogen_core.base import CancellationToken

from ..base import Response
from ..messages import ChatMessage, HandoffMessage, TextMessage
from ._base_chat_agent import BaseChatAgent

# Define input function types more precisely
SyncInputFunc = Callable[[str], str]
AsyncInputFunc = Callable[[str, Optional[CancellationToken]], Awaitable[str]]
InputFuncType = Union[SyncInputFunc, AsyncInputFunc]


class UserProxyAgent(BaseChatAgent):
    """An agent that can represent a human user through an input function.

    This agent can be used to represent a human user in a chat system by providing a custom input function.

    Args:
        name (str): The name of the agent.
        description (str, optional): A description of the agent.
        input_func (Optional[Callable[[str], str]], Callable[[str, Optional[CancellationToken]], Awaitable[str]]): A function that takes a prompt and returns a user input string.

    .. note::

        Using :class:`UserProxyAgent` puts a running team in a temporary blocked
        state until the user responds. So it is important to time out the user input
        function and cancel using the :class:`~autogen_core.base.CancellationToken` if the user does not respond.
        The input function should also handle exceptions and return a default response if needed.

        For typical use cases that involve
        slow human responses, it is recommended to use termination conditions
        such as :class:`~autogen_agentchat.task.HandoffTermination` or :class:`~autogen_agentchat.task.SourceMatchTermination`
        to stop the running team and return the control to the application.
        You can run the team again with the user input. This way, the state of the team
        can be saved and restored when the user responds.

        See `Pause for User Input <https://microsoft.github.io/autogen/dev/user-guide/agentchat-user-guide/tutorial/teams.html#pause-for-user-input>`_ for more information.

    """

    def __init__(
        self,
        name: str,
        *,
        description: str = "A human user",
        input_func: Optional[InputFuncType] = None,
    ) -> None:
        """Initialize the UserProxyAgent."""
        super().__init__(name=name, description=description)
        self.input_func = input_func or input
        self._is_async = iscoroutinefunction(self.input_func)

    @property
    def produced_message_types(self) -> List[type[ChatMessage]]:
        """Message types this agent can produce."""
        return [TextMessage, HandoffMessage]

    def _get_latest_handoff(self, messages: Sequence[ChatMessage]) -> Optional[HandoffMessage]:
        """Find the HandoffMessage in the message sequence that addresses this agent."""
        if len(messages) > 0 and isinstance(messages[-1], HandoffMessage):
            if messages[-1].target == self.name:
                return messages[-1]
            else:
                raise RuntimeError(f"Handoff message target does not match agent name: {messages[-1].source}")
        return None

    async def _get_input(self, prompt: str, cancellation_token: Optional[CancellationToken]) -> str:
        """Handle input based on function signature."""
        try:
            if self._is_async:
                # Cast to AsyncInputFunc for proper typing
                async_func = cast(AsyncInputFunc, self.input_func)
                return await async_func(prompt, cancellation_token)
            else:
                # Cast to SyncInputFunc for proper typing
                sync_func = cast(SyncInputFunc, self.input_func)
                loop = asyncio.get_event_loop()
                return await loop.run_in_executor(None, sync_func, prompt)
        except asyncio.CancelledError:
            raise
        except Exception as e:
            raise RuntimeError(f"Failed to get user input: {str(e)}") from e

    async def on_messages(
        self, messages: Sequence[ChatMessage], cancellation_token: Optional[CancellationToken] = None
    ) -> Response:
        """Handle incoming messages by requesting user input."""
        try:
            # Check for handoff first
            handoff = self._get_latest_handoff(messages)
            prompt = (
                f"Handoff received from {handoff.source}. Enter your response: " if handoff else "Enter your response: "
            )
            # print(prompt)
            user_input = await self._get_input(prompt, cancellation_token)
            # print(user_input)

            # Return appropriate message type based on handoff presence
            if handoff:
                return Response(chat_message=HandoffMessage(content=user_input, target=handoff.source, source=self.name))
            else:
                return Response(chat_message=TextMessage(content=user_input, source=self.name))
        except asyncio.CancelledError:
            raise
        except Exception as e:
            raise RuntimeError(f"Failed to get user input: {str(e)}") from e

    async def on_reset(self, cancellation_token: Optional[CancellationToken] = None) -> None:
        """Reset agent state."""
        pass
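The note in the class docstring recommends timing out the input function and returning a default response when the user does not answer. A minimal sketch of such an input function, using only the standard library, could look like the following. The names make_timed_input, timed_input, read_line and DEFAULT_REPLY are illustrative assumptions, not part of autogen, and the cancellation token is accepted only to match the (prompt, cancellation_token) shape; it is not used here.

```python
import asyncio
import time
from typing import Any, Awaitable, Callable, Optional

DEFAULT_REPLY = "(no response)"  # hypothetical fallback text, not an autogen constant


def make_timed_input(
    read_line: Callable[[str], str], timeout: float
) -> Callable[[str, Optional[Any]], Awaitable[str]]:
    """Wrap a blocking reader in an async input function that gives up after `timeout` seconds."""

    async def timed_input(prompt: str, cancellation_token: Optional[Any] = None) -> str:
        loop = asyncio.get_running_loop()
        # Run the blocking reader in a worker thread so the event loop stays responsive.
        future = loop.run_in_executor(None, read_line, prompt)
        try:
            return await asyncio.wait_for(future, timeout=timeout)
        except asyncio.TimeoutError:
            # We only stop waiting; the worker thread itself is not interrupted.
            return DEFAULT_REPLY

    return timed_input


def fast_reader(prompt: str) -> str:
    return "hello"


def slow_reader(prompt: str) -> str:
    time.sleep(0.3)  # simulates a user who answers too late
    return "too late"


async def demo() -> tuple:
    fast = await make_timed_input(fast_reader, timeout=1.0)("Enter your response: ")
    slow = await make_timed_input(slow_reader, timeout=0.05)("Enter your response: ")
    return fast, slow


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

Because timed_input is defined with async def, iscoroutinefunction reports it as asynchronous, so an agent built like the class above would await it directly. One caveat of this sketch: a reader that is still blocked (for example on the built-in input) keeps its worker thread alive after the timeout.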
2. Testing the code
import asyncio
from typing import Optional, Sequence

import pytest
from autogen_agentchat.agents import UserProxyAgent
from autogen_agentchat.base import Response
from autogen_agentchat.messages import ChatMessage, HandoffMessage, TextMessage
from autogen_core.base import CancellationToken


def custom_input(prompt: str) -> str:
    return "The height of the eiffel tower is 324 meters. Aloha!"


agent = UserProxyAgent(name="test_user", input_func=custom_input)
messages = [
    TextMessage(content="What is the height of the eiffel tower?", source="assistant")
]

response = asyncio.run(agent.on_messages(messages, CancellationToken()))

print(response)
Output
Response(chat_message=TextMessage(source='test_user', models_usage=None, content='The height of the eiffel tower is 324 meters. Aloha!'), inner_messages=None)
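3. How the code works
The code has two parts:
- The definition of the UserProxyAgent class: an agent class that stands in for a human user and obtains the user's input through input_func (the input function).
- An example of using UserProxyAgent: it creates a UserProxyAgent instance and calls its on_messages method to simulate a simple interaction.
Let's walk through the code step by step.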
4. Core functionality of the UserProxyAgent class
- __init__ method:

    def __init__(self, name: str, *, description: str = "A human user", input_func: Optional[InputFuncType] = None) -> None:
        super().__init__(name=name, description=description)
        self.input_func = input_func or input  # fall back to the built-in input function
        self._is_async = iscoroutinefunction(self.input_func)  # check whether input_func is an async function

  input_func: the function used to obtain the user's input. It can be synchronous or asynchronous.
  self._is_async: records whether input_func is synchronous or asynchronous. If input_func is a coroutine function (defined with async def), self._is_async is True.
- produced_message_types property:

    @property
    def produced_message_types(self) -> List[type[ChatMessage]]:
        return [TextMessage, HandoffMessage]

  This property declares the message types that UserProxyAgent can produce: TextMessage and HandoffMessage.
- _get_input method:

    async def _get_input(self, prompt: str, cancellation_token: Optional[CancellationToken]) -> str:
        try:
            if self._is_async:
                async_func = cast(AsyncInputFunc, self.input_func)
                return await async_func(prompt, cancellation_token)
            else:
                sync_func = cast(SyncInputFunc, self.input_func)
                loop = asyncio.get_event_loop()
                return await loop.run_in_executor(None, sync_func, prompt)
        except Exception as e:
            raise RuntimeError(f"Failed to get user input: {str(e)}") from e

  _get_input obtains the user input according to the kind of input_func (asynchronous or synchronous):
  - If it is asynchronous, it is awaited directly: await async_func(prompt, cancellation_token).
  - If it is synchronous, run_in_executor runs it in the default thread pool, so a blocking call does not stall the event loop.
  Any exception raised by the input function is wrapped in a RuntimeError.
- on_messages method:

    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: Optional[CancellationToken] = None) -> Response:
        try:
            # Check for handoff first
            handoff = self._get_latest_handoff(messages)
            prompt = (
                f"Handoff received from {handoff.source}. Enter your response: " if handoff else "Enter your response: "
            )
            user_input = await self._get_input(prompt, cancellation_token)
            if handoff:
                return Response(chat_message=HandoffMessage(content=user_input, target=handoff.source, source=self.name))
            else:
                return Response(chat_message=TextMessage(content=user_input, source=self.name))
        except Exception as e:
            raise RuntimeError(f"Failed to get user input: {str(e)}") from e

  on_messages is the core method: it processes the incoming messages and requests user input through input_func.
  - It first checks for a handoff with _get_latest_handoff, which looks only at the last message. A HandoffMessage passes the turn to a named target; if the last message is a handoff addressed to this agent it is returned, and if it is addressed to a different name a RuntimeError is raised.
  - Depending on whether a handoff was found, it builds the prompt and calls _get_input to obtain the user input.
  - Finally, it wraps the input in a HandoffMessage addressed back to the handoff's source if a handoff was received, or in a TextMessage otherwise, and returns it in a Response.
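The sync/async dispatch performed by _get_input can be tried without installing autogen. The sketch below reproduces only that idea with the standard library; get_input, sync_reply and async_reply are hypothetical names, and the token argument is a plain placeholder rather than a real CancellationToken. It also uses asyncio.get_running_loop() where the original code calls asyncio.get_event_loop().

```python
import asyncio
from inspect import iscoroutinefunction
from typing import Any, Awaitable, Callable, List, Optional, Union

SyncInputFunc = Callable[[str], str]
AsyncInputFunc = Callable[[str, Optional[Any]], Awaitable[str]]


async def get_input(
    func: Union[SyncInputFunc, AsyncInputFunc], prompt: str, token: Optional[Any] = None
) -> str:
    """Await coroutine functions directly; run plain functions in the default thread pool."""
    if iscoroutinefunction(func):
        return await func(prompt, token)
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, func, prompt)


def sync_reply(prompt: str) -> str:
    return f"sync saw {prompt!r}"


async def async_reply(prompt: str, token: Optional[Any] = None) -> str:
    await asyncio.sleep(0)  # yield to the event loop once, like real async I/O would
    return f"async saw {prompt!r}"


async def demo() -> List[str]:
    return [
        await get_input(sync_reply, "Enter your response: "),
        await get_input(async_reply, "Enter your response: "),
    ]


if __name__ == "__main__":
    for line in asyncio.run(demo()):
        print(line)
```

Note the different call shapes: the synchronous function receives only the prompt, while the asynchronous one also receives the token, which mirrors the two type aliases in the original code.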
5. Using the UserProxyAgent class
- Define a synchronous input function, custom_input:

    def custom_input(prompt: str) -> str:
        return "The height of the eiffel tower is 324 meters. Aloha!"

  This simple synchronous function simulates the user's input. Whatever prompt it receives, it always returns "The height of the eiffel tower is 324 meters. Aloha!".
- Create a UserProxyAgent instance:

    agent = UserProxyAgent(name="test_user", input_func=custom_input)

  This creates a UserProxyAgent whose name is "test_user" and whose input_func is the custom_input function defined above.
- Create the message list, messages:

    messages = [
        TextMessage(content="What is the height of the eiffel tower?", source="assistant")
    ]

  The list contains a single TextMessage from "assistant" asking "What is the height of the eiffel tower?".
- Call the on_messages method:

    response = asyncio.run(agent.on_messages(messages, CancellationToken()))

  asyncio.run() runs the asynchronous on_messages method. messages is passed to on_messages to simulate a conversation, and the agent obtains the user input through custom_input.
- Print the response:

    print(response)

  The printed response is the value returned by on_messages. It contains a HandoffMessage if the last incoming message was a handoff addressed to the agent, and a TextMessage otherwise; in this example it is a TextMessage.
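The choice between TextMessage and HandoffMessage can be illustrated in isolation. The sketch below mirrors the logic of _get_latest_handoff and on_messages using stand-in dataclasses; these are not the autogen classes, and latest_handoff and build_reply are hypothetical helper names introduced only for this illustration.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Union


@dataclass
class TextMessage:  # stand-in for illustration, not autogen's TextMessage
    content: str
    source: str


@dataclass
class HandoffMessage:  # stand-in for illustration, not autogen's HandoffMessage
    content: str
    source: str
    target: str


Message = Union[TextMessage, HandoffMessage]


def latest_handoff(messages: Sequence[Message], agent_name: str) -> Optional[HandoffMessage]:
    """Return the last message if it is a handoff addressed to agent_name."""
    if messages and isinstance(messages[-1], HandoffMessage):
        if messages[-1].target != agent_name:
            raise RuntimeError("Handoff message target does not match agent name")
        return messages[-1]
    return None


def build_reply(messages: Sequence[Message], agent_name: str, user_input: str) -> Message:
    """Answer a handoff with a handoff back to its source, anything else with plain text."""
    handoff = latest_handoff(messages, agent_name)
    if handoff is not None:
        return HandoffMessage(content=user_input, source=agent_name, target=handoff.source)
    return TextMessage(content=user_input, source=agent_name)


if __name__ == "__main__":
    question = TextMessage(content="What is the height of the eiffel tower?", source="assistant")
    handoff = HandoffMessage(content="Over to the user.", source="assistant", target="test_user")
    print(build_reply([question], "test_user", "324 meters"))
    print(build_reply([question, handoff], "test_user", "324 meters"))
```

Only the last message matters: a handoff that appears earlier in the list does not change the reply type.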
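6. Runtime flow
- on_messages is called, and the agent starts waiting for user input.
- Because the last message is not a handoff, prompt is set to "Enter your response: ", and custom_input is called to obtain the user input.
- The simulated user input is "The height of the eiffel tower is 324 meters. Aloha!".
- A TextMessage is created and returned, with the user's input as its content.
- Finally, response is a Response object that contains this TextMessage.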
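7. Summary
UserProxyAgent is an agent that stands in for a human user and obtains the user's answer through an input function, which may be synchronous or asynchronous.
The on_messages method handles messages from other agents or the system, waits for the user input, wraps it in the appropriate message type, and returns it.
The example simulates a simple conversation in which the user's text is processed and returned by UserProxyAgent.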
Reference: https://github.com/microsoft/autogen/blob/main/python/packages/autogen-agentchat/tests/test_userproxy_agent.py
If you have any questions, feel free to ask in the comments.