diff --git a/example/agent/tool_agent_usage.py b/example/agent/tool_agent_usage.py
index 19e802ad..66352945 100644
--- a/example/agent/tool_agent_usage.py
+++ b/example/agent/tool_agent_usage.py
@@ -10,7 +10,8 @@ def main():
     model = pne.LLMFactory.build(model_name="gpt-4-1106-preview")
     agent = pne.ToolAgent(tools=tools, llm=model)
     prompt = """Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?"""  # noqa
-    agent.run(prompt)
+    for response in agent.run(prompt, stream=True):
+        print(response)
 
 
 if __name__ == "__main__":
diff --git a/promptulate/agents/base.py b/promptulate/agents/base.py
index eafeda2b..7cd4c702 100644
--- a/promptulate/agents/base.py
+++ b/promptulate/agents/base.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, Generator, List, Optional
 
 from promptulate.hook import Hook, HookTable
 from promptulate.llms import BaseLLM
@@ -25,6 +25,7 @@ def run(
         instruction: str,
         output_schema: Optional[type(BaseModel)] = None,
         examples: Optional[List[BaseModel]] = None,
+        stream: bool = False,
         *args,
         **kwargs,
     ) -> Any:
@@ -39,6 +40,9 @@
             **kwargs,
         )
 
+        if stream:
+            return self._run_stream(instruction, output_schema, examples, *args, **kwargs)
+
         # get original response from LLM
         result: str = self._run(instruction, *args, **kwargs)
 
@@ -60,6 +64,45 @@
         )
         return result
 
+    def _run_stream(
+        self,
+        instruction: str,
+        output_schema: Optional[type(BaseModel)] = None,
+        examples: Optional[List[BaseModel]] = None,
+        *args,
+        **kwargs,
+    ) -> Generator[Any, None, None]:
+        """Run the agent, including hooks, with streaming output."""
+        Hook.call_hook(
+            HookTable.ON_AGENT_START,
+            self,
+            instruction,
+            output_schema,
+            *args,
+            agent_type=self._agent_type,
+            **kwargs,
+        )
+
+        for result in self._run(instruction, *args, stream=True, **kwargs):
+            # TODO: need to optimize
+            # if output_schema:
+            #     formatter = OutputFormatter(output_schema, examples)
+            #     prompt = (
+            #         f"{formatter.get_formatted_instructions()}\n##User input:\n{result}"
+            #     )
+            #     json_response: str = self.get_llm()(prompt)
+            #     yield formatter.formatting_result(json_response)
+            # else:
+            yield result
+
+        Hook.call_hook(
+            HookTable.ON_AGENT_RESULT,
+            mounted_obj=self,
+            result=result,
+            agent_type=self._agent_type,
+            _from=self._from,
+        )
+
     @abstractmethod
     def _run(self, instruction: str, *args, **kwargs) -> str:
         """Run the detail agent, implemented by subclass."""
diff --git a/promptulate/agents/tool_agent/agent.py b/promptulate/agents/tool_agent/agent.py
index b15b5cba..aec8d8dd 100644
--- a/promptulate/agents/tool_agent/agent.py
+++ b/promptulate/agents/tool_agent/agent.py
@@ -119,17 +119,21 @@ def current_date(self) -> str:
         return f"Current date: {time.strftime('%Y-%m-%d %H:%M:%S')}"
 
     def _run(
-        self, instruction: str, return_raw_data: bool = False, **kwargs
+        self, instruction: str, return_raw_data: bool = False, stream: bool = False, **kwargs
     ) -> Union[str, ActionResponse]:
         """Run the tool agent. The tool agent will interact with the LLM and the tool.
 
         Args:
             instruction(str): The instruction to the tool agent.
             return_raw_data(bool): Whether to return raw data. Default is False.
+            stream(bool): Whether to enable streaming output. Default is False.
 
         Returns:
             The output of the tool agent.
""" + if stream and "output_schema" not in kwargs: + raise ValueError("output_schema must be provided when stream=True") + self.conversation_prompt = self._build_system_prompt(instruction) logger.info(f"[pne] ToolAgent system prompt: {self.conversation_prompt}") @@ -174,6 +178,9 @@ def _run( ) self.conversation_prompt += f"Observation: {tool_result}\n" + if stream: + yield tool_result + iterations += 1 used_time += time.time() - start_time diff --git a/tests/agents/test_tool_agent.py b/tests/agents/test_tool_agent.py index 7b35230b..7fcf0524 100644 --- a/tests/agents/test_tool_agent.py +++ b/tests/agents/test_tool_agent.py @@ -1,55 +1,64 @@ -from promptulate.agents.tool_agent.agent import ToolAgent -from promptulate.llms.base import BaseLLM -from promptulate.tools.base import BaseToolKit - - -class FakeLLM(BaseLLM): - def _predict(self, prompts, *args, **kwargs): - pass - - def __call__(self, *args, **kwargs): - return """## Output - ```json - { - "city": "Shanghai", - "temperature": 25 - } - ```""" - - -def fake_tool_1(): - """Fake tool 1""" - return "Fake tool 1" - - -def fake_tool_2(): - """Fake tool 2""" - return "Fake tool 2" - - -def test_init(): - llm = FakeLLM() - agent = ToolAgent(llm=llm) - assert len(agent.tool_manager.tools) == 0 - - agent = ToolAgent(llm=llm, tools=[fake_tool_1, fake_tool_2]) - assert len(agent.tool_manager.tools) == 2 - assert agent.tool_manager.tools[0].name == "fake_tool_1" - assert agent.tool_manager.tools[1].name == "fake_tool_2" - - -class MockToolKit(BaseToolKit): - def get_tools(self) -> list: - return [fake_tool_1, fake_tool_2] - - -def test_init_by_toolkits(): - llm = FakeLLM() - agent = ToolAgent(llm=llm, tools=[MockToolKit()]) - assert len(agent.tool_manager.tools) == 2 - - -def test_init_by_tool_and_kit(): - llm = FakeLLM() - agent = ToolAgent(llm=llm, tools=[MockToolKit(), fake_tool_1, fake_tool_2]) - assert len(agent.tool_manager.tools) == 4 +from promptulate.agents.tool_agent.agent import ToolAgent +from promptulate.llms.base import BaseLLM +from promptulate.tools.base import BaseToolKit + + +class FakeLLM(BaseLLM): + def _predict(self, prompts, *args, **kwargs): + pass + + def __call__(self, *args, **kwargs): + return """## Output + ```json + { + "city": "Shanghai", + "temperature": 25 + } + ```""" + + +def fake_tool_1(): + """Fake tool 1""" + return "Fake tool 1" + + +def fake_tool_2(): + """Fake tool 2""" + return "Fake tool 2" + + +def test_init(): + llm = FakeLLM() + agent = ToolAgent(llm=llm) + assert len(agent.tool_manager.tools) == 0 + + agent = ToolAgent(llm=llm, tools=[fake_tool_1, fake_tool_2]) + assert len(agent.tool_manager.tools) == 2 + assert agent.tool_manager.tools[0].name == "fake_tool_1" + assert agent.tool_manager.tools[1].name == "fake_tool_2" + + +class MockToolKit(BaseToolKit): + def get_tools(self) -> list: + return [fake_tool_1, fake_tool_2] + + +def test_init_by_toolkits(): + llm = FakeLLM() + agent = ToolAgent(llm=llm, tools=[MockToolKit()]) + assert len(agent.tool_manager.tools) == 2 + + +def test_init_by_tool_and_kit(): + llm = FakeLLM() + agent = ToolAgent(llm=llm, tools=[MockToolKit(), fake_tool_1, fake_tool_2]) + assert len(agent.tool_manager.tools) == 4 + + +def test_stream_mode(): + llm = FakeLLM() + agent = ToolAgent(llm=llm, tools=[fake_tool_1, fake_tool_2]) + prompt = "What is the temperature in Shanghai?" + responses = list(agent.run(prompt, stream=True)) + assert len(responses) > 0 + assert all(isinstance(response, str) for response in responses)