
Commit a844794

Catch langchain agent errors (nv-morpheus#1539)
* Catch any uncaught/unhandled langchain agent errors and return an error string so that we don't end up with a missing result.
* Fix version conflicts for nemollm and dgl
* Expand existing tests for the `LangChainAgentNode`

## By Submitting this PR I confirm:
- I am familiar with the [Contributing Guidelines](https://github.com/nv-morpheus/Morpheus/blob/main/docs/source/developer_guide/contributing.md).
- When the PR is ready for review, new or existing tests cover these changes.
- When the PR is ready for review, the documentation is up to date with these changes.

Authors:
- David Gardner (https://github.com/dagardner-nv)

Approvers:
- Christopher Harris (https://github.com/cwharris)
- Michael Demoret (https://github.com/mdemoret-nv)

URL: nv-morpheus#1539
1 parent 744ba79 commit a844794
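
In short: `LangChainAgentNode._run_single` now wraps the agent's async `arun` call in a try/except that logs the exception and returns an error string, instead of letting the exception propagate and leaving a missing result (see the morpheus/llm/nodes/langchain_agent_node.py diff below). A minimal standalone sketch of that pattern, using illustrative stand-in names (FakeAgentExecutor, run_single) rather than the actual Morpheus classes:

import asyncio
import logging
import typing

logger = logging.getLogger(__name__)


class FakeAgentExecutor:
    """Illustrative stand-in for a LangChain AgentExecutor whose arun() always raises."""

    async def arun(self, **kwargs: typing.Any) -> str:
        raise RuntimeError("unittest")


async def run_single(agent_executor, **kwargs: typing.Any) -> str:
    # Same idea as the change in this commit: never let an agent exception escape;
    # log it and return an error string so downstream consumers still get a result.
    try:
        return await agent_executor.arun(**kwargs)
    except Exception as e:
        error_msg = f"Error running agent: {e}"
        logger.exception(error_msg)
        return error_msg


if __name__ == "__main__":
    print(asyncio.run(run_single(FakeAgentExecutor(), input="input1")))
    # Prints: Error running agent: unittest

The new test_execute_error below asserts exactly this behaviour: a tool whose side_effect is RuntimeError("unittest") produces the result string "Error running agent: unittest" rather than an unhandled exception.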

File tree: 8 files changed (+139, -8 lines)

conda/environments/all_cuda-121_arch-x86_64.yaml

Lines changed: 3 additions & 1 deletion
@@ -93,7 +93,9 @@ dependencies:
 - pytorch=*=*cuda*
 - rapidjson=1.1.0
 - rdma-core>=48
+- requests
 - requests-cache=1.1
+- requests-toolbelt
 - s3fs=2023.12.2
 - scikit-build=0.17.6
 - scikit-learn=1.3.2
@@ -117,7 +119,7 @@ dependencies:
 - --find-links https://data.dgl.ai/wheels/cu121/repo.html
 - PyMuPDF==1.23.21
 - databricks-connect
-- dgl
+- dgl==2.0.0
 - dglgo
 - google-search-results==2.4
 - langchain==0.1.9

conda/environments/dev_cuda-121_arch-x86_64.yaml

Lines changed: 2 additions & 0 deletions
@@ -73,7 +73,9 @@ dependencies:
 - pytorch-cuda
 - pytorch=*=*cuda*
 - rapidjson=1.1.0
+- requests
 - requests-cache=1.1
+- requests-toolbelt
 - scikit-build=0.17.6
 - scikit-learn=1.3.2
 - sphinx

conda/environments/examples_cuda-121_arch-x86_64.yaml

Lines changed: 3 additions & 1 deletion
@@ -46,7 +46,9 @@ dependencies:
 - python=3.10
 - pytorch-cuda
 - pytorch=*=*cuda*
+- requests
 - requests-cache=1.1
+- requests-toolbelt
 - s3fs=2023.12.2
 - scikit-learn=1.3.2
 - sentence-transformers
@@ -61,7 +63,7 @@ dependencies:
 - --find-links https://data.dgl.ai/wheels/cu121/repo.html
 - PyMuPDF==1.23.21
 - databricks-connect
-- dgl
+- dgl==2.0.0
 - dglgo
 - google-search-results==2.4
 - langchain==0.1.9

conda/environments/runtime_cuda-121_arch-x86_64.yaml

Lines changed: 2 additions & 0 deletions
@@ -28,7 +28,9 @@ dependencies:
 - python=3.10
 - pytorch-cuda
 - pytorch=*=*cuda*
+- requests
 - requests-cache=1.1
+- requests-toolbelt
 - scikit-learn=1.3.2
 - sqlalchemy
 - tqdm=4

dependencies.yaml

Lines changed: 3 additions & 1 deletion
@@ -261,7 +261,9 @@ dependencies:
 - python-graphviz
 - pytorch-cuda
 - pytorch=*=*cuda*
+- requests
 - requests-cache=1.1
+- requests-toolbelt # Transitive dep needed by nemollm, specified here to ensure we get a compatible version
 - sqlalchemy
 - tqdm=4
 - typing_utils=0.1
@@ -311,7 +313,7 @@ dependencies:
 - pip:
 - --find-links https://data.dgl.ai/wheels/cu121/repo.html
 - --find-links https://data.dgl.ai/wheels-test/repo.html
-- dgl
+- dgl==2.0.0
 - dglgo

 example-llm-agents:

morpheus/llm/nodes/langchain_agent_node.py

Lines changed: 8 additions & 3 deletions
@@ -66,9 +66,14 @@ async def _run_single(self, **kwargs: dict[str, typing.Any]) -> dict[str, typing.Any]:
             return results
 
         # We are not dealing with a list, so run single
-        return await self._agent_executor.arun(**kwargs)
-
-    async def execute(self, context: LLMContext) -> LLMContext:
+        try:
+            return await self._agent_executor.arun(**kwargs)
+        except Exception as e:
+            error_msg = f"Error running agent: {e}"
+            logger.exception(error_msg)
+            return error_msg
+
+    async def execute(self, context: LLMContext) -> LLMContext:  # pylint: disable=invalid-overridden-method
 
         input_dict = context.get_inputs()

tests/_utils/llm.py

Lines changed: 30 additions & 2 deletions
@@ -87,7 +87,35 @@ def mk_mock_openai_response(messages: list[str]) -> mock.MagicMock:
     Creates a mocked openai.types.chat.chat_completion.ChatCompletion response with the given messages.
     """
     response = mock.MagicMock()
-    mock_choices = [_mk_mock_choice(message) for message in messages]
-    response.choices = mock_choices
+
+    response.choices = [_mk_mock_choice(message) for message in messages]
+    response.dict.return_value = {
+        "choices": [{
+            'message': {
+                'role': 'assistant', 'content': message
+            }
+        } for message in messages]
+    }
 
     return response
+
+
+def mk_mock_langchain_tool(responses: list[str]) -> mock.MagicMock:
+    """
+    Creates a mocked LangChainTestTool with the given responses.
+    """
+
+    # Langchain will call inspect.signature on the tool methods, typically mock objects don't have a signature,
+    # explicitly providing one here
+    async def _arun_spec(*_, **__):
+        pass
+
+    def run_spec(*_, **__):
+        pass
+
+    tool = mock.MagicMock()
+    tool.arun = mock.create_autospec(spec=_arun_spec)
+    tool.arun.side_effect = responses
+    tool.run = mock.create_autospec(run_spec)
+    tool.run.side_effect = responses
+    return tool
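
As the comment in mk_mock_langchain_tool notes, LangChain calls inspect.signature on a tool's callables, so the helper wraps small spec functions with mock.create_autospec (which gives each mock a real, inspectable signature) and queues the canned responses through side_effect. A short standalone sketch of that side_effect behaviour, not taken from the repo:

from unittest import mock


def run_spec(*_, **__):
    # Spec function whose only job is to donate a real, inspectable signature.
    pass


# Mirrors mk_mock_langchain_tool: an autospec'd callable with queued responses.
tool_run = mock.create_autospec(spec=run_spec)
tool_run.side_effect = ["lizard", "frog"]
print(tool_run("name a reptile"))  # -> lizard (first queued response)
print(tool_run("name a reptile"))  # -> frog

# An exception instance as side_effect makes every call raise; test_execute_error
# below relies on this to force the agent failure path.
tool_err = mock.create_autospec(spec=run_spec)
tool_err.side_effect = RuntimeError("unittest")
try:
    tool_err("anything")
except RuntimeError as err:
    print(f"raised: {err}")  # -> raised: unittest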

tests/llm/nodes/test_langchain_agent_node.py

Lines changed: 88 additions & 0 deletions
@@ -16,8 +16,14 @@
 from unittest import mock
 
 import pytest
+from langchain.agents import AgentType
+from langchain.agents import Tool
+from langchain.agents import initialize_agent
+from langchain.chat_models import ChatOpenAI  # pylint: disable=no-name-in-module
 
 from _utils.llm import execute_node
+from _utils.llm import mk_mock_langchain_tool
+from _utils.llm import mk_mock_openai_response
 from morpheus.llm import LLMNodeBase
 from morpheus.llm.nodes.langchain_agent_node import LangChainAgentNode
@@ -50,8 +56,90 @@ def test_execute(
         expected_output: list,
         expected_calls: list[mock.call],
 ):
+    # Tests the execute method of the LangChainAgentNode with a mocked agent_executor
     mock_agent_executor.arun.return_value = arun_return
 
     node = LangChainAgentNode(agent_executor=mock_agent_executor)
     assert execute_node(node, **values) == expected_output
     mock_agent_executor.arun.assert_has_calls(expected_calls)
+
+
+def test_execute_tools(mock_chat_completion: tuple[mock.MagicMock, mock.MagicMock]):
+    # Tests the execute method of the LangChainAgentNode with mocked tools and chat completion
+    (_, mock_async_client) = mock_chat_completion
+    chat_responses = [
+        'I should check Tool1\nAction: Tool1\nAction Input: "name a reptile"',
+        'I should check Tool2\nAction: Tool2\nAction Input: "name of a day of the week"',
+        'I should check Tool1\nAction: Tool1\nAction Input: "name a reptile"',
+        'I should check Tool2\nAction: Tool2\nAction Input: "name of a day of the week"',
+        'Observation: Answer: Yes!\nI now know the final answer.\nFinal Answer: Yes!'
+    ]
+    mock_responses = [mk_mock_openai_response([response]) for response in chat_responses]
+    mock_async_client.chat.completions.create.side_effect = mock_responses
+
+    llm_chat = ChatOpenAI(model="fake-model", openai_api_key="fake-key")
+
+    mock_tool1 = mk_mock_langchain_tool(["lizard", "frog"])
+    mock_tool2 = mk_mock_langchain_tool(["Tuesday", "Thursday"])
+
+    tools = [
+        Tool(name="Tool1",
+             func=mock_tool1.run,
+             coroutine=mock_tool1.arun,
+             description="useful for when you need to know the name of a reptile"),
+        Tool(name="Tool2",
+             func=mock_tool2.run,
+             coroutine=mock_tool2.arun,
+             description="useful for when you need to know the day of the week")
+    ]
+
+    agent = initialize_agent(tools,
+                             llm_chat,
+                             agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
+                             verbose=True,
+                             handle_parsing_errors=True,
+                             early_stopping_method="generate",
+                             return_intermediate_steps=False)
+
+    node = LangChainAgentNode(agent_executor=agent)
+
+    assert execute_node(node, input="input1") == "Yes!"
+
+
+def test_execute_error(mock_chat_completion: tuple[mock.MagicMock, mock.MagicMock]):
+    # Tests the execute method of the LangChainAgentNode with mocked tools and chat completion
+    (_, mock_async_client) = mock_chat_completion
+    chat_responses = [
+        'I should check Tool1\nAction: Tool1\nAction Input: "name a reptile"',
+        'I should check Tool2\nAction: Tool2\nAction Input: "name of a day of the week"',
+        'Observation: Answer: Yes!\nI now know the final answer.\nFinal Answer: Yes!'
+    ]
+    mock_responses = [mk_mock_openai_response([response]) for response in chat_responses]
+    mock_async_client.chat.completions.create.side_effect = mock_responses
+
+    llm_chat = ChatOpenAI(model="fake-model", openai_api_key="fake-key")
+
+    mock_tool1 = mk_mock_langchain_tool(["lizard"])
+    mock_tool2 = mk_mock_langchain_tool(RuntimeError("unittest"))
+
+    tools = [
+        Tool(name="Tool1",
+             func=mock_tool1.run,
+             coroutine=mock_tool1.arun,
+             description="useful for when you need to know the name of a reptile"),
+        Tool(name="Tool2",
+             func=mock_tool2.run,
+             coroutine=mock_tool2.arun,
+             description="useful for when you need to test tool errors")
+    ]
+
+    agent = initialize_agent(tools,
+                             llm_chat,
+                             agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
+                             verbose=True,
+                             handle_parsing_errors=True,
+                             early_stopping_method="generate",
+                             return_intermediate_steps=False)
+
+    node = LangChainAgentNode(agent_executor=agent)
+    assert execute_node(node, input="input1") == "Error running agent: unittest"
