Skip to content

Commit cabfdeb

Browse files
committed
fix: applies Responses API validation to Strands, renames Strands tests
Signed-off-by: RanjitR <[email protected]>
1 parent 5a23f2f commit cabfdeb

File tree

8 files changed

+90
-122
lines changed

8 files changed

+90
-122
lines changed

examples/frameworks/strands_demo/src/nat_strands_demo/configs/eval_config.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
# path-check-skip-file
1516

1617

1718
functions:

examples/frameworks/strands_demo/src/nat_strands_demo/configs/sizing_config.yml

Lines changed: 0 additions & 44 deletions
This file was deleted.

examples/frameworks/strands_demo/src/nat_strands_demo/configs/tracing_config.yml

Lines changed: 0 additions & 53 deletions
This file was deleted.

packages/nvidia_nat_strands/src/nat/plugins/strands/llm.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from nat.llm.utils.thinking import FunctionArgumentWrapper
3030
from nat.llm.utils.thinking import patch_with_thinking
3131
from nat.utils.exception_handlers.automatic_retries import patch_with_retry
32+
from nat.utils.responses_api import validate_no_responses_api
3233
from nat.utils.type_utils import override
3334

3435
ModelType = TypeVar("ModelType")
@@ -95,9 +96,11 @@ def inject(self, messages, *args, **kwargs) -> FunctionArgumentWrapper:
9596
@register_llm_client(config_type=OpenAIModelConfig, wrapper_type=LLMFrameworkEnum.STRANDS)
9697
async def openai_strands(llm_config: OpenAIModelConfig, _builder: Builder):
9798

99+
validate_no_responses_api(llm_config, LLMFrameworkEnum.STRANDS)
100+
98101
from strands.models.openai import OpenAIModel
99102

100-
params = llm_config.model_dump(exclude={"type", "api_key", "base_url", "model_name"},
103+
params = llm_config.model_dump(exclude={"type", "api_type", "api_key", "base_url", "model_name"},
101104
by_alias=True,
102105
exclude_none=True)
103106
# Remove NAT-specific and retry-specific keys not accepted by OpenAI chat.create
@@ -119,6 +122,8 @@ async def openai_strands(llm_config: OpenAIModelConfig, _builder: Builder):
119122
@register_llm_client(config_type=NIMModelConfig, wrapper_type=LLMFrameworkEnum.STRANDS)
120123
async def nim_strands(llm_config: NIMModelConfig, _builder: Builder):
121124

125+
validate_no_responses_api(llm_config, LLMFrameworkEnum.STRANDS)
126+
122127
# NIM is OpenAI compatible; use OpenAI model with NIM base_url and api_key
123128
from strands.models.openai import OpenAIModel
124129

@@ -151,7 +156,7 @@ def format_request_messages(cls, messages, system_prompt=None):
151156

152157
return formatted_messages
153158

154-
params = llm_config.model_dump(exclude={"type", "api_key", "base_url", "model_name"},
159+
params = llm_config.model_dump(exclude={"type", "api_type", "api_key", "base_url", "model_name"},
155160
by_alias=True,
156161
exclude_none=True)
157162
# Remove NAT-specific and retry-specific keys not accepted by OpenAI
@@ -173,15 +178,26 @@ def format_request_messages(cls, messages, system_prompt=None):
173178
@register_llm_client(config_type=AWSBedrockModelConfig, wrapper_type=LLMFrameworkEnum.STRANDS)
174179
async def bedrock_strands(llm_config: AWSBedrockModelConfig, _builder: Builder):
175180

181+
validate_no_responses_api(llm_config, LLMFrameworkEnum.STRANDS)
182+
176183
from strands.models.bedrock import BedrockModel
177184

178-
client = BedrockModel(
179-
model_id=llm_config.model_name,
180-
max_tokens=llm_config.max_tokens,
181-
temperature=llm_config.temperature,
182-
top_p=llm_config.top_p,
183-
region_name=llm_config.region_name,
184-
endpoint_url=llm_config.base_url,
185-
)
185+
params = llm_config.model_dump(exclude={"type", "api_type", "model_name", "region_name", "base_url"},
186+
by_alias=True,
187+
exclude_none=True)
188+
189+
for k in ("max_retries",
190+
"num_retries",
191+
"retry_on_status_codes",
192+
"retry_on_errors",
193+
"thinking",
194+
"context_size",
195+
"credentials_profile_name"):
196+
params.pop(k, None)
197+
198+
client = BedrockModel(model_id=llm_config.model_name,
199+
region_name=llm_config.region_name,
200+
endpoint_url=llm_config.base_url,
201+
**params)
186202

187203
yield _patch_llm_based_on_config(client, llm_config)

packages/nvidia_nat_strands/src/nat/plugins/strands/strands_callback_handler.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -276,13 +276,12 @@ def wrapped_init(agent_self, *args, **kwargs):
276276
try:
277277
# Import hook event types
278278
# pylint: disable=import-outside-toplevel
279-
from strands.experimental.hooks import AfterToolInvocationEvent
280-
from strands.experimental.hooks import BeforeToolInvocationEvent
279+
from strands.hooks import AfterToolCallEvent
280+
from strands.hooks import BeforeToolCallEvent
281281

282282
# Register tool hooks on this agent instance
283-
agent_self.hooks.add_callback(BeforeToolInvocationEvent,
284-
handler.tool_hook.on_before_tool_invocation)
285-
agent_self.hooks.add_callback(AfterToolInvocationEvent, handler.tool_hook.on_after_tool_invocation)
283+
agent_self.hooks.add_callback(BeforeToolCallEvent, handler.tool_hook.on_before_tool_invocation)
284+
agent_self.hooks.add_callback(AfterToolCallEvent, handler.tool_hook.on_after_tool_invocation)
286285

287286
logger.debug("Strands tool hooks registered on Agent instance")
288287

packages/nvidia_nat_strands/tests/test_callback_handler.py renamed to packages/nvidia_nat_strands/tests/test_strands_callback_handler.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -582,20 +582,20 @@ def __init__(self, *args, **kwargs):
582582
def import_side_effect(module_name):
583583
if "agent.agent" in module_name:
584584
return mock_agent_mod
585-
elif "experimental.hooks" in module_name:
585+
elif "hooks" in module_name and "strands" in module_name:
586586
# Import the actual hook classes for testing
587587
try:
588-
from strands.experimental.hooks import AfterToolInvocationEvent
589-
from strands.experimental.hooks import BeforeToolInvocationEvent
588+
from strands.hooks import AfterToolCallEvent
589+
from strands.hooks import BeforeToolCallEvent
590590
hook_mod = MagicMock()
591-
hook_mod.BeforeToolInvocationEvent = BeforeToolInvocationEvent
592-
hook_mod.AfterToolInvocationEvent = AfterToolInvocationEvent
591+
hook_mod.BeforeToolCallEvent = BeforeToolCallEvent
592+
hook_mod.AfterToolCallEvent = AfterToolCallEvent
593593
return hook_mod
594594
except ImportError:
595595
# Fallback to mocks if strands not available
596596
hook_mod = MagicMock()
597-
hook_mod.BeforeToolInvocationEvent = MagicMock()
598-
hook_mod.AfterToolInvocationEvent = MagicMock()
597+
hook_mod.BeforeToolCallEvent = MagicMock()
598+
hook_mod.AfterToolCallEvent = MagicMock()
599599
return hook_mod
600600
raise ImportError(f"No module named {module_name}")
601601

@@ -663,10 +663,10 @@ def __init__(self, *args, **kwargs):
663663
def import_side_effect(module_name):
664664
if "agent.agent" in module_name:
665665
return mock_agent_mod
666-
elif "experimental.hooks" in module_name:
666+
elif "hooks" in module_name and "strands" in module_name:
667667
hook_mod = MagicMock()
668-
hook_mod.BeforeToolInvocationEvent = MagicMock()
669-
hook_mod.AfterToolInvocationEvent = MagicMock()
668+
hook_mod.BeforeToolCallEvent = MagicMock()
669+
hook_mod.AfterToolCallEvent = MagicMock()
670670
return hook_mod
671671
raise ImportError(f"No module named {module_name}")
672672

packages/nvidia_nat_strands/tests/test_llm.py renamed to packages/nvidia_nat_strands/tests/test_strands_llm.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import pytest
2020

2121
from nat.builder.builder import Builder
22+
from nat.data_models.llm import APITypeEnum
2223
from nat.llm.aws_bedrock_llm import AWSBedrockModelConfig
2324
from nat.llm.nim_llm import NIMModelConfig
2425
from nat.llm.openai_llm import OpenAIModelConfig
@@ -41,6 +42,11 @@ def openai_config(self):
4142
"""Create an OpenAIModelConfig instance."""
4243
return OpenAIModelConfig(model_name="gpt-4")
4344

45+
@pytest.fixture
46+
def openai_config_wrong_api(self):
47+
"""Create an OpenAIModelConfig with wrong API type."""
48+
return OpenAIModelConfig(model_name="gpt-4", api_type=APITypeEnum.RESPONSES)
49+
4450
@patch("strands.models.openai.OpenAIModel")
4551
async def test_openai_strands_basic(self, mock_model, openai_config, mock_builder):
4652
"""Test that openai_strands as async context manager."""
@@ -64,6 +70,14 @@ async def test_openai_strands_with_params(self, mock_model, openai_config, mock_
6470
async with openai_strands(openai_config, mock_builder):
6571
mock_model.assert_called_once()
6672

73+
@patch("strands.models.openai.OpenAIModel")
74+
async def test_api_type_validation(self, mock_model, openai_config_wrong_api, mock_builder):
75+
"""Non-chat-completion API types must raise a ValueError."""
76+
with pytest.raises(ValueError):
77+
async with openai_strands(openai_config_wrong_api, mock_builder):
78+
pass
79+
mock_model.assert_not_called()
80+
6781

6882
class TestBedrockStrands:
6983
"""Tests for the bedrock_strands function."""
@@ -81,6 +95,15 @@ def bedrock_config(self):
8195
region_name="us-east-1",
8296
)
8397

98+
@pytest.fixture
99+
def bedrock_config_wrong_api(self):
100+
"""Create an AWSBedrockModelConfig with wrong API type."""
101+
return AWSBedrockModelConfig(
102+
model_name="anthropic.claude-3-sonnet-20240229-v1:0",
103+
region_name="us-east-1",
104+
api_type=APITypeEnum.RESPONSES,
105+
)
106+
84107
@patch("strands.models.bedrock.BedrockModel")
85108
async def test_bedrock_strands_basic(self, mock_model, bedrock_config, mock_builder):
86109
"""Test that bedrock_strands creates a BedrockModel."""
@@ -91,6 +114,14 @@ async def test_bedrock_strands_basic(self, mock_model, bedrock_config, mock_buil
91114
async with bedrock_strands(bedrock_config, mock_builder):
92115
mock_model.assert_called_once()
93116

117+
@patch("strands.models.bedrock.BedrockModel")
118+
async def test_api_type_validation(self, mock_model, bedrock_config_wrong_api, mock_builder):
119+
"""Non-chat-completion API types must raise a ValueError."""
120+
with pytest.raises(ValueError):
121+
async with bedrock_strands(bedrock_config_wrong_api, mock_builder):
122+
pass
123+
mock_model.assert_not_called()
124+
94125

95126
class TestNIMStrands:
96127
"""Tests for the nim_strands function."""
@@ -109,6 +140,16 @@ def nim_config(self):
109140
base_url="https://integrate.api.nvidia.com/v1",
110141
)
111142

143+
@pytest.fixture
144+
def nim_config_wrong_api(self):
145+
"""Create a NIMModelConfig with wrong API type."""
146+
return NIMModelConfig(
147+
model_name="meta/llama-3.1-8b-instruct",
148+
api_key="test-api-key",
149+
base_url="https://integrate.api.nvidia.com/v1",
150+
api_type=APITypeEnum.RESPONSES,
151+
)
152+
112153
async def test_nim_strands_basic(self, nim_config, mock_builder):
113154
"""Test that nim_strands creates a NIMCompatibleOpenAIModel."""
114155
# Patch OpenAIModel.__init__ to track the call
@@ -220,6 +261,14 @@ async def test_nim_strands_excludes_nat_specific_params(self, mock_builder):
220261
assert "thinking" not in params
221262
assert "retry_on_status_codes" not in params
222263

264+
async def test_api_type_validation(self, nim_config_wrong_api, mock_builder):
265+
"""Non-chat-completion API types must raise a ValueError."""
266+
with patch("strands.models.openai.OpenAIModel.__init__", return_value=None) as mock_init:
267+
with pytest.raises(ValueError):
268+
async with nim_strands(nim_config_wrong_api, mock_builder):
269+
pass
270+
mock_init.assert_not_called()
271+
223272

224273
class TestPatchLLMBasedOnConfig:
225274
"""Tests for _patch_llm_based_on_config function."""

0 commit comments

Comments (0)