Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions src/google/adk/tools/discovery_engine_search_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@
r'search_result_mode.*DOCUMENTS', re.IGNORECASE
)

_LOCATION_PATTERN = re.compile(r"/locations/([^/]+)/")
_DEFAULT_ENDPOINT = "discoveryengine.googleapis.com"


class SearchResultMode(enum.Enum):
"""Search result mode for discovery engine search."""
Expand Down Expand Up @@ -102,10 +105,19 @@ def __init__(

credentials, _ = google.auth.default()
quota_project_id = getattr(credentials, 'quota_project_id', None)
options = (
client_options.ClientOptions(quota_project_id=quota_project_id)
if quota_project_id
else None

resource_id = data_store_id or search_engine_id or ""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The or "" is redundant here. The ValueError check on lines 60-65 ensures that either data_store_id or search_engine_id is a non-None string, so the expression data_store_id or search_engine_id will always evaluate to a string. Removing the fallback to an empty string makes the code slightly cleaner and relies on the existing validation.

Suggested change
resource_id = data_store_id or search_engine_id or ""
resource_id = data_store_id or search_engine_id

location_match = _LOCATION_PATTERN.search(resource_id)
location = location_match.group(1) if location_match else "global"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code/fix isn't crazy

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @kaligautier! Glad the approach looks reasonable. The core idea is straightforward — parse the location from the resource ID and route to the correct regional endpoint instead of always defaulting to global.

api_endpoint = (
f"{location}-{_DEFAULT_ENDPOINT}"
if location != "global"
else _DEFAULT_ENDPOINT
)

options = client_options.ClientOptions(
api_endpoint=api_endpoint,
quota_project_id=quota_project_id,
)
self._discovery_engine_client = discoveryengine.SearchServiceClient(
credentials=credentials, client_options=options
Expand Down
71 changes: 70 additions & 1 deletion tests/unittests/tools/test_discovery_engine_search_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ def test_discovery_engine_search_success(
assert result["results"][0]["content"] == "Test Content"
mock_auth.assert_called_once()
mock_client_options.ClientOptions.assert_called_once_with(
quota_project_id="test-quota-project"
api_endpoint="discoveryengine.googleapis.com",
quota_project_id="test-quota-project",
)
mock_search_client.assert_called_once_with(
credentials=mock_credentials,
Expand Down Expand Up @@ -320,3 +321,71 @@ def test_auto_detect_does_not_retry_on_unrelated_error(
assert result["status"] == "error"
assert "Permission denied" in result["error_message"]
assert mock_search_client.return_value.search.call_count == 1

@mock.patch.object(discovery_engine_search_tool, "client_options")
@mock.patch.object(discoveryengine, "SearchServiceClient")
def test_regional_endpoint_eu_data_store(
self, mock_search_client, mock_client_options
):
"""Test that an EU data store uses the EU regional endpoint."""
DiscoveryEngineSearchTool(
data_store_id="projects/my-project/locations/eu/collections/default_collection/dataStores/my-ds"
)
mock_client_options.ClientOptions.assert_called_once_with(
api_endpoint="eu-discoveryengine.googleapis.com",
quota_project_id=None,
)

@mock.patch.object(discovery_engine_search_tool, "client_options")
@mock.patch.object(discoveryengine, "SearchServiceClient")
def test_regional_endpoint_us_search_engine(
self, mock_search_client, mock_client_options
):
"""Test that a US search engine uses the US regional endpoint."""
DiscoveryEngineSearchTool(
search_engine_id="projects/my-project/locations/us/collections/default_collection/engines/my-engine"
)
mock_client_options.ClientOptions.assert_called_once_with(
api_endpoint="us-discoveryengine.googleapis.com",
quota_project_id=None,
)

@mock.patch.object(discovery_engine_search_tool, "client_options")
@mock.patch.object(discoveryengine, "SearchServiceClient")
def test_regional_endpoint_single_region(
self, mock_search_client, mock_client_options
):
"""Test that a single-region location uses the correct endpoint."""
DiscoveryEngineSearchTool(
data_store_id="projects/my-project/locations/europe-west1/collections/default_collection/dataStores/my-ds"
)
mock_client_options.ClientOptions.assert_called_once_with(
api_endpoint="europe-west1-discoveryengine.googleapis.com",
quota_project_id=None,
)

@mock.patch.object(discovery_engine_search_tool, "client_options")
@mock.patch.object(discoveryengine, "SearchServiceClient")
def test_global_endpoint_explicit(
self, mock_search_client, mock_client_options
):
"""Test that a global data store uses the default global endpoint."""
DiscoveryEngineSearchTool(
data_store_id="projects/my-project/locations/global/collections/default_collection/dataStores/my-ds"
)
mock_client_options.ClientOptions.assert_called_once_with(
api_endpoint="discoveryengine.googleapis.com",
quota_project_id=None,
)

@mock.patch.object(discovery_engine_search_tool, "client_options")
@mock.patch.object(discoveryengine, "SearchServiceClient")
def test_global_endpoint_no_location_in_id(
self, mock_search_client, mock_client_options
):
"""Test that a short ID without location falls back to global endpoint."""
DiscoveryEngineSearchTool(data_store_id="test_data_store")
mock_client_options.ClientOptions.assert_called_once_with(
api_endpoint="discoveryengine.googleapis.com",
quota_project_id=None,
)
Comment on lines +325 to +391
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

These five tests for endpoint selection are very similar. They can be consolidated into a single, more maintainable test by using pytest.mark.parametrize. This approach reduces code duplication and makes it clearer what is being tested across different inputs.

  @pytest.mark.parametrize(
      ("resource_id_key", "resource_id_value", "expected_endpoint"),
      [
          (
              "data_store_id",
              "projects/my-project/locations/eu/collections/default_collection/dataStores/my-ds",
              "eu-discoveryengine.googleapis.com",
          ),
          (
              "search_engine_id",
              "projects/my-project/locations/us/collections/default_collection/engines/my-engine",
              "us-discoveryengine.googleapis.com",
          ),
          (
              "data_store_id",
              "projects/my-project/locations/europe-west1/collections/default_collection/dataStores/my-ds",
              "europe-west1-discoveryengine.googleapis.com",
          ),
          (
              "data_store_id",
              "projects/my-project/locations/global/collections/default_collection/dataStores/my-ds",
              "discoveryengine.googleapis.com",
          ),
          ("data_store_id", "test_data_store", "discoveryengine.googleapis.com"),
      ],
  )
  @mock.patch.object(discovery_engine_search_tool, "client_options")
  @mock.patch.object(discoveryengine, "SearchServiceClient")
  def test_endpoint_selection(
      self,
      mock_search_client,
      mock_client_options,
      resource_id_key,
      resource_id_value,
      expected_endpoint,
  ):
    """Test that the correct API endpoint is selected based on resource ID."""
    kwargs = {resource_id_key: resource_id_value}
    DiscoveryEngineSearchTool(**kwargs)
    mock_client_options.ClientOptions.assert_called_once_with(
        api_endpoint=expected_endpoint,
        quota_project_id=None,
    )