-
Notifications
You must be signed in to change notification settings - Fork 33
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Related to #833 Add stream output support for ToolAgent. * **ToolAgent Class**: - Add `stream` parameter to the `run` method. - Implement logic to handle the `stream` parameter. - Modify the `_run` method to support streaming output. - Raise error if `stream=True` and `output_schema` is not provided. * **BaseAgent Class**: - Add `stream` parameter to the `run` method. - Implement `_run_stream` method to handle streaming output. * **Example**: - Update `example/agent/tool_agent_usage.py` to demonstrate the usage of `agent.run(..., stream=True)`. * **Tests**: - Add tests in `tests/agents/test_tool_agent.py` to verify the stream mode output functionality. --- For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/Undertone0809/promptulate/issues/833?shareId=XXXX-XXXX-XXXX-XXXX).
- Loading branch information
1 parent
11b21d8
commit a92d2e3
Showing
4 changed files
with
118 additions
and
58 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,55 +1,64 @@ | ||
from promptulate.agents.tool_agent.agent import ToolAgent | ||
from promptulate.llms.base import BaseLLM | ||
from promptulate.tools.base import BaseToolKit | ||
|
||
|
||
class FakeLLM(BaseLLM): | ||
def _predict(self, prompts, *args, **kwargs): | ||
pass | ||
|
||
def __call__(self, *args, **kwargs): | ||
return """## Output | ||
```json | ||
{ | ||
"city": "Shanghai", | ||
"temperature": 25 | ||
} | ||
```""" | ||
|
||
|
||
def fake_tool_1(): | ||
"""Fake tool 1""" | ||
return "Fake tool 1" | ||
|
||
|
||
def fake_tool_2(): | ||
"""Fake tool 2""" | ||
return "Fake tool 2" | ||
|
||
|
||
def test_init(): | ||
llm = FakeLLM() | ||
agent = ToolAgent(llm=llm) | ||
assert len(agent.tool_manager.tools) == 0 | ||
|
||
agent = ToolAgent(llm=llm, tools=[fake_tool_1, fake_tool_2]) | ||
assert len(agent.tool_manager.tools) == 2 | ||
assert agent.tool_manager.tools[0].name == "fake_tool_1" | ||
assert agent.tool_manager.tools[1].name == "fake_tool_2" | ||
|
||
|
||
class MockToolKit(BaseToolKit): | ||
def get_tools(self) -> list: | ||
return [fake_tool_1, fake_tool_2] | ||
|
||
|
||
def test_init_by_toolkits(): | ||
llm = FakeLLM() | ||
agent = ToolAgent(llm=llm, tools=[MockToolKit()]) | ||
assert len(agent.tool_manager.tools) == 2 | ||
|
||
|
||
def test_init_by_tool_and_kit(): | ||
llm = FakeLLM() | ||
agent = ToolAgent(llm=llm, tools=[MockToolKit(), fake_tool_1, fake_tool_2]) | ||
assert len(agent.tool_manager.tools) == 4 | ||
from promptulate.agents.tool_agent.agent import ToolAgent | ||
from promptulate.llms.base import BaseLLM | ||
from promptulate.tools.base import BaseToolKit | ||
|
||
|
||
class FakeLLM(BaseLLM): | ||
def _predict(self, prompts, *args, **kwargs): | ||
pass | ||
|
||
def __call__(self, *args, **kwargs): | ||
return """## Output | ||
```json | ||
{ | ||
"city": "Shanghai", | ||
"temperature": 25 | ||
} | ||
```""" | ||
|
||
|
||
def fake_tool_1(): | ||
"""Fake tool 1""" | ||
return "Fake tool 1" | ||
|
||
|
||
def fake_tool_2(): | ||
"""Fake tool 2""" | ||
return "Fake tool 2" | ||
|
||
|
||
def test_init(): | ||
llm = FakeLLM() | ||
agent = ToolAgent(llm=llm) | ||
assert len(agent.tool_manager.tools) == 0 | ||
|
||
agent = ToolAgent(llm=llm, tools=[fake_tool_1, fake_tool_2]) | ||
assert len(agent.tool_manager.tools) == 2 | ||
assert agent.tool_manager.tools[0].name == "fake_tool_1" | ||
assert agent.tool_manager.tools[1].name == "fake_tool_2" | ||
|
||
|
||
class MockToolKit(BaseToolKit): | ||
def get_tools(self) -> list: | ||
return [fake_tool_1, fake_tool_2] | ||
|
||
|
||
def test_init_by_toolkits(): | ||
llm = FakeLLM() | ||
agent = ToolAgent(llm=llm, tools=[MockToolKit()]) | ||
assert len(agent.tool_manager.tools) == 2 | ||
|
||
|
||
def test_init_by_tool_and_kit(): | ||
llm = FakeLLM() | ||
agent = ToolAgent(llm=llm, tools=[MockToolKit(), fake_tool_1, fake_tool_2]) | ||
assert len(agent.tool_manager.tools) == 4 | ||
|
||
|
||
def test_stream_mode(): | ||
llm = FakeLLM() | ||
agent = ToolAgent(llm=llm, tools=[fake_tool_1, fake_tool_2]) | ||
prompt = "What is the temperature in Shanghai?" | ||
responses = list(agent.run(prompt, stream=True)) | ||
assert len(responses) > 0 | ||
assert all(isinstance(response, str) for response in responses) |