Skip to content

Commit f22bac0

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: add Spanner execute sql query result mode
Add using the execute sql query return result as list of dictionaries. In each dictionary the key is the column name and the value is the value of the that column in a given row. PiperOrigin-RevId: 840909555
1 parent de841a4 commit f22bac0

File tree

7 files changed

+406
-11
lines changed

7 files changed

+406
-11
lines changed

contributing/samples/spanner/agent.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from google.adk.auth.auth_credential import AuthCredentialTypes
1919
from google.adk.tools.google_tool import GoogleTool
2020
from google.adk.tools.spanner.settings import Capabilities
21+
from google.adk.tools.spanner.settings import QueryResultMode
2122
from google.adk.tools.spanner.settings import SpannerToolSettings
2223
from google.adk.tools.spanner.spanner_credentials import SpannerCredentialsConfig
2324
from google.adk.tools.spanner.spanner_toolset import SpannerToolset
@@ -34,7 +35,10 @@
3435

3536

3637
# Define Spanner tool config with read capability set to allowed.
37-
tool_settings = SpannerToolSettings(capabilities=[Capabilities.DATA_READ])
38+
tool_settings = SpannerToolSettings(
39+
capabilities=[Capabilities.DATA_READ],
40+
query_result_mode=QueryResultMode.DICT_LIST,
41+
)
3842

3943
if CREDENTIALS_TYPE == AuthCredentialTypes.OAUTH2:
4044
# Initialize the tools to do interactive OAuth

src/google/adk/tools/spanner/query_tool.py

Lines changed: 130 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,16 @@
1414

1515
from __future__ import annotations
1616

17+
import functools
18+
import textwrap
19+
import types
20+
from typing import Callable
21+
1722
from google.auth.credentials import Credentials
1823

1924
from . import utils
2025
from ..tool_context import ToolContext
26+
from .settings import QueryResultMode
2127
from .settings import SpannerToolSettings
2228

2329

@@ -49,16 +55,29 @@ def execute_sql(
4955
query not returned in the result.
5056
5157
Examples:
52-
Fetch data or insights from a table:
58+
<Example>
59+
>>> execute_sql("my_project", "my_instance", "my_database",
60+
... "SELECT COUNT(*) AS count FROM my_table")
61+
{
62+
"status": "SUCCESS",
63+
"rows": [
64+
[100]
65+
]
66+
}
67+
</Example>
5368
54-
>>> execute_sql("my_project", "my_instance", "my_database",
55-
... "SELECT COUNT(*) AS count FROM my_table")
56-
{
57-
"status": "SUCCESS",
58-
"rows": [
59-
[100]
60-
]
61-
}
69+
<Example>
70+
>>> execute_sql("my_project", "my_instance", "my_database",
71+
... "SELECT name, rating, description FROM hotels_table")
72+
{
73+
"status": "SUCCESS",
74+
"rows": [
75+
["The Hotel", 4.1, "Modern hotel."],
76+
["Park Inn", 4.5, "Cozy hotel."],
77+
...
78+
]
79+
}
80+
</Example>
6281
6382
Note:
6483
This is running with Read-Only Transaction for query that only read data.
@@ -72,3 +91,105 @@ def execute_sql(
7291
settings,
7392
tool_context,
7493
)
94+
95+
96+
_EXECUTE_SQL_DICT_LIST_MODE_DOCSTRING = textwrap.dedent("""\
97+
Run a Spanner Read-Only query in the spanner database and return the result.
98+
99+
Args:
100+
project_id (str): The GCP project id in which the spanner database
101+
resides.
102+
instance_id (str): The instance id of the spanner database.
103+
database_id (str): The database id of the spanner database.
104+
query (str): The Spanner SQL query to be executed.
105+
credentials (Credentials): The credentials to use for the request.
106+
settings (SpannerToolSettings): The settings for the tool.
107+
tool_context (ToolContext): The context for the tool.
108+
109+
Returns:
110+
dict: Dictionary with the result of the query.
111+
If the result contains the key "result_is_likely_truncated" with
112+
value True, it means that there may be additional rows matching the
113+
query not returned in the result.
114+
115+
Examples:
116+
<Example>
117+
>>> execute_sql("my_project", "my_instance", "my_database",
118+
... "SELECT COUNT(*) AS count FROM my_table")
119+
{
120+
"status": "SUCCESS",
121+
"rows": [
122+
{
123+
"count": 100
124+
}
125+
]
126+
}
127+
</Example>
128+
129+
<Example>
130+
>>> execute_sql("my_project", "my_instance", "my_database",
131+
... "SELECT COUNT(*) FROM my_table")
132+
{
133+
"status": "SUCCESS",
134+
"rows": [
135+
{
136+
"": 100
137+
}
138+
]
139+
}
140+
</Example>
141+
142+
<Example>
143+
>>> execute_sql("my_project", "my_instance", "my_database",
144+
... "SELECT name, rating, description FROM hotels_table")
145+
{
146+
"status": "SUCCESS",
147+
"rows": [
148+
{
149+
"name": "The Hotel",
150+
"rating": 4.1,
151+
"description": "Modern hotel."
152+
},
153+
{
154+
"name": "Park Inn",
155+
"rating": 4.5,
156+
"description": "Cozy hotel."
157+
},
158+
...
159+
]
160+
}
161+
</Example>
162+
163+
Note:
164+
This is running with Read-Only Transaction for query that only read data.
165+
""")
166+
167+
168+
def get_execute_sql(settings: SpannerToolSettings) -> Callable[..., dict]:
169+
"""Get the execute_sql tool customized as per the given tool settings.
170+
171+
Args:
172+
settings: Spanner tool settings indicating the behavior of the execute_sql
173+
tool.
174+
175+
Returns:
176+
callable[..., dict]: A version of the execute_sql tool respecting the tool
177+
settings.
178+
"""
179+
180+
if settings and settings.query_result_mode is QueryResultMode.DICT_LIST:
181+
182+
execute_sql_wrapper = types.FunctionType(
183+
execute_sql.__code__,
184+
execute_sql.__globals__,
185+
execute_sql.__name__,
186+
execute_sql.__defaults__,
187+
execute_sql.__closure__,
188+
)
189+
functools.update_wrapper(execute_sql_wrapper, execute_sql)
190+
# Update with the new docstring
191+
execute_sql_wrapper.__doc__ = _EXECUTE_SQL_DICT_LIST_MODE_DOCSTRING
192+
return execute_sql_wrapper
193+
194+
# Return the default execute_sql function.
195+
return execute_sql

src/google/adk/tools/spanner/settings.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,20 @@ class Capabilities(Enum):
4040
"""Read only data operations tools are allowed."""
4141

4242

43+
class QueryResultMode(Enum):
44+
"""Settings for Spanner execute sql query result."""
45+
46+
DEFAULT = "default"
47+
"""Return the result of a query as a list of rows data."""
48+
49+
DICT_LIST = "dict_list"
50+
"""Return the result of a query as a list of dictionaries.
51+
52+
In each dictionary the key is the column name and the value is the value of
53+
the that column in a given row.
54+
"""
55+
56+
4357
class SpannerVectorStoreSettings(BaseModel):
4458
"""Settings for Spanner Vector Store.
4559
@@ -140,5 +154,8 @@ class SpannerToolSettings(BaseModel):
140154
max_executed_query_result_rows: int = 50
141155
"""Maximum number of rows to return from a query result."""
142156

157+
query_result_mode: QueryResultMode = QueryResultMode.DEFAULT
158+
"""Mode for Spanner execute sql query result."""
159+
143160
vector_store_settings: Optional[SpannerVectorStoreSettings] = None
144161
"""Settings for Spanner vector store and vector similarity search."""

src/google/adk/tools/spanner/spanner_toolset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ async def get_tools(
111111
):
112112
all_tools.append(
113113
GoogleTool(
114-
func=query_tool.execute_sql,
114+
func=query_tool.get_execute_sql(self._tool_settings),
115115
credentials_config=self._credentials_config,
116116
tool_settings=self._tool_settings,
117117
)

src/google/adk/tools/spanner/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from . import client
2424
from ..tool_context import ToolContext
25+
from .settings import QueryResultMode
2526
from .settings import SpannerToolSettings
2627

2728
DEFAULT_MAX_EXECUTED_QUERY_RESULT_ROWS = 50
@@ -84,6 +85,9 @@ def execute_sql(
8485
if settings and settings.max_executed_query_result_rows > 0
8586
else DEFAULT_MAX_EXECUTED_QUERY_RESULT_ROWS
8687
)
88+
if settings and settings.query_result_mode is QueryResultMode.DICT_LIST:
89+
result_set = result_set.to_dict_list()
90+
8791
for row in result_set:
8892
try:
8993
# if the json serialization of the row succeeds, use it as is

0 commit comments

Comments
 (0)