Skip to content

Commit 44940db

Browse files
author
Tirthanu Ghosh
committed
Add hybrid search support with overrideSearchType parameter
- Add overrideSearchType parameter to tool input schema with HYBRID/SEMANTIC enum validation - Implement parameter validation and configuration in retrieve function - Add comprehensive tests for hybrid search functionality including error cases - Support both HYBRID and SEMANTIC search types as per AWS Bedrock API documentation - Fix line length violations for code quality compliance - Preserve all existing tests including test_retrieve_via_agent_with_enable_metadata
1 parent 6322306 commit 44940db

File tree

2 files changed

+188
-0
lines changed

2 files changed

+188
-0
lines changed

src/strands_tools/retrieve.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,14 @@
166166
),
167167
"default": False,
168168
},
169+
"overrideSearchType": {
170+
"type": "string",
171+
"description": (
172+
"Override the search type for the knowledge base query. Supported values: 'HYBRID', "
173+
"'SEMANTIC'. Default behavior uses the knowledge base's configured search type."
174+
),
175+
"enum": ["HYBRID", "SEMANTIC"],
176+
},
169177
},
170178
"required": ["text"],
171179
}
@@ -306,6 +314,16 @@ def retrieve(tool: ToolUse, **kwargs: Any) -> ToolResult:
306314
min_score = tool_input.get("score", default_min_score)
307315
enable_metadata = tool_input.get("enableMetadata", default_enable_metadata)
308316
retrieve_filter = tool_input.get("retrieveFilter")
317+
override_search_type = tool_input.get("overrideSearchType")
318+
319+
# Validate overrideSearchType if provided
320+
if override_search_type and override_search_type not in ["HYBRID", "SEMANTIC"]:
321+
return {
322+
"toolUseId": tool_use_id,
323+
"status": "error",
324+
"content": [{"text": f"Invalid overrideSearchType: {override_search_type}. "
325+
f"Supported values: HYBRID, SEMANTIC"}],
326+
}
309327

310328
# Initialize Bedrock client with optional profile name
311329
profile_name = tool_input.get("profile_name")
@@ -321,6 +339,9 @@ def retrieve(tool: ToolUse, **kwargs: Any) -> ToolResult:
321339
# Default retrieval configuration
322340
retrieval_config = {"vectorSearchConfiguration": {"numberOfResults": number_of_results}}
323341

342+
if override_search_type:
343+
retrieval_config["vectorSearchConfiguration"]["overrideSearchType"] = override_search_type
344+
324345
if retrieve_filter:
325346
try:
326347
if _validate_filter(retrieve_filter):

tests/test_retrieve.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -656,6 +656,171 @@ def test_retrieve_with_environment_variable_default(mock_boto3_client):
656656
assert "test-source-1" not in result_text
657657

658658

659+
def test_retrieve_with_override_search_type_hybrid(mock_boto3_client):
660+
"""Test retrieve with overrideSearchType set to HYBRID."""
661+
tool_use = {
662+
"toolUseId": "test-tool-use-id",
663+
"input": {
664+
"text": "test query",
665+
"knowledgeBaseId": "test-kb-id",
666+
"overrideSearchType": "HYBRID",
667+
},
668+
}
669+
670+
result = retrieve.retrieve(tool=tool_use)
671+
672+
# Verify the result is successful
673+
assert result["status"] == "success"
674+
assert "Retrieved 2 results with score >= 0.4" in result["content"][0]["text"]
675+
676+
# Verify that boto3 client was called with overrideSearchType
677+
mock_boto3_client.return_value.retrieve.assert_called_once_with(
678+
retrievalQuery={"text": "test query"},
679+
knowledgeBaseId="test-kb-id",
680+
retrievalConfiguration={
681+
"vectorSearchConfiguration": {
682+
"numberOfResults": 10,
683+
"overrideSearchType": "HYBRID"
684+
}
685+
},
686+
)
687+
688+
689+
def test_retrieve_with_override_search_type_semantic(mock_boto3_client):
690+
"""Test retrieve with overrideSearchType set to SEMANTIC."""
691+
tool_use = {
692+
"toolUseId": "test-tool-use-id",
693+
"input": {
694+
"text": "test query",
695+
"knowledgeBaseId": "test-kb-id",
696+
"overrideSearchType": "SEMANTIC",
697+
},
698+
}
699+
700+
result = retrieve.retrieve(tool=tool_use)
701+
702+
# Verify the result is successful
703+
assert result["status"] == "success"
704+
705+
# Verify that boto3 client was called with overrideSearchType
706+
mock_boto3_client.return_value.retrieve.assert_called_once_with(
707+
retrievalQuery={"text": "test query"},
708+
knowledgeBaseId="test-kb-id",
709+
retrievalConfiguration={
710+
"vectorSearchConfiguration": {
711+
"numberOfResults": 10,
712+
"overrideSearchType": "SEMANTIC"
713+
}
714+
},
715+
)
716+
717+
718+
def test_retrieve_with_invalid_override_search_type(mock_boto3_client):
719+
"""Test retrieve with invalid overrideSearchType."""
720+
tool_use = {
721+
"toolUseId": "test-tool-use-id",
722+
"input": {
723+
"text": "test query",
724+
"knowledgeBaseId": "test-kb-id",
725+
"overrideSearchType": "INVALID_TYPE",
726+
},
727+
}
728+
729+
result = retrieve.retrieve(tool=tool_use)
730+
731+
# Verify the result is an error
732+
assert result["status"] == "error"
733+
assert "Invalid overrideSearchType: INVALID_TYPE" in result["content"][0]["text"]
734+
assert "Supported values: HYBRID, SEMANTIC" in result["content"][0]["text"]
735+
736+
# Verify that boto3 client was not called
737+
mock_boto3_client.return_value.retrieve.assert_not_called()
738+
739+
740+
def test_retrieve_without_override_search_type(mock_boto3_client):
741+
"""Test retrieve without overrideSearchType (default behavior)."""
742+
tool_use = {
743+
"toolUseId": "test-tool-use-id",
744+
"input": {
745+
"text": "test query",
746+
"knowledgeBaseId": "test-kb-id",
747+
},
748+
}
749+
750+
result = retrieve.retrieve(tool=tool_use)
751+
752+
# Verify the result is successful
753+
assert result["status"] == "success"
754+
755+
# Verify that boto3 client was called without overrideSearchType
756+
mock_boto3_client.return_value.retrieve.assert_called_once_with(
757+
retrievalQuery={"text": "test query"},
758+
knowledgeBaseId="test-kb-id",
759+
retrievalConfiguration={
760+
"vectorSearchConfiguration": {
761+
"numberOfResults": 10
762+
}
763+
},
764+
)
765+
766+
767+
def test_retrieve_with_override_search_type_and_filter(mock_boto3_client):
768+
"""Test retrieve with both overrideSearchType and retrieveFilter."""
769+
tool_use = {
770+
"toolUseId": "test-tool-use-id",
771+
"input": {
772+
"text": "test query",
773+
"knowledgeBaseId": "test-kb-id",
774+
"overrideSearchType": "HYBRID",
775+
"retrieveFilter": {"equals": {"key": "category", "value": "security"}},
776+
},
777+
}
778+
779+
result = retrieve.retrieve(tool=tool_use)
780+
781+
# Verify the result is successful
782+
assert result["status"] == "success"
783+
784+
# Verify that boto3 client was called with both overrideSearchType and filter
785+
mock_boto3_client.return_value.retrieve.assert_called_once_with(
786+
retrievalQuery={"text": "test query"},
787+
knowledgeBaseId="test-kb-id",
788+
retrievalConfiguration={
789+
"vectorSearchConfiguration": {
790+
"numberOfResults": 10,
791+
"overrideSearchType": "HYBRID",
792+
"filter": {"equals": {"key": "category", "value": "security"}}
793+
}
794+
},
795+
)
796+
797+
798+
def test_retrieve_via_agent_with_override_search_type(agent, mock_boto3_client):
799+
"""Test retrieving via the agent interface with overrideSearchType."""
800+
with mock.patch.dict(os.environ, {"KNOWLEDGE_BASE_ID": "agent-kb-id"}):
801+
result = agent.tool.retrieve(
802+
text="agent query",
803+
knowledgeBaseId="test-kb-id",
804+
overrideSearchType="HYBRID"
805+
)
806+
807+
result_text = extract_result_text(result)
808+
assert "Retrieved" in result_text
809+
assert "results with score >=" in result_text
810+
811+
# Verify the boto3 client was called with overrideSearchType
812+
mock_boto3_client.return_value.retrieve.assert_called_once_with(
813+
retrievalQuery={"text": "agent query"},
814+
knowledgeBaseId="test-kb-id",
815+
retrievalConfiguration={
816+
"vectorSearchConfiguration": {
817+
"numberOfResults": 10,
818+
"overrideSearchType": "HYBRID"
819+
}
820+
},
821+
)
822+
823+
659824
def test_retrieve_via_agent_with_enable_metadata(agent, mock_boto3_client):
660825
"""Test retrieving via the agent interface with enableMetadata."""
661826
with mock.patch.dict(os.environ, {"KNOWLEDGE_BASE_ID": "agent-kb-id"}):
@@ -677,3 +842,5 @@ def test_retrieve_via_agent_with_enable_metadata(agent, mock_boto3_client):
677842
assert "results with score >=" in result_text
678843
assert "Metadata:" not in result_text
679844
assert "test-source" not in result_text
845+
846+

0 commit comments

Comments
 (0)