Skip to content

Feature: Allow for templating of nested objects in instructions #574

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 75 additions & 23 deletions src/google/adk/flows/llm_flows/instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import re
from typing import AsyncGenerator
from typing import Generator
from typing import TYPE_CHECKING

from typing_extensions import override
Expand Down Expand Up @@ -77,7 +76,52 @@ async def _populate_values(
instruction_template: str,
context: InvocationContext,
) -> str:
"""Populates values in the instruction template, e.g. state, artifact, etc."""
"""Populates values in the instruction template, e.g. state, artifact, etc.

Supports nested dot-separated references like:
- state.user.name
- artifact.config.settings
- user.profile.email
- user?.profile?.name? (optional markers at any level)
"""

def _get_nested_value(
obj, paths: list[str], is_optional: bool = False
) -> str:
"""Gets a nested value from an object using a list of path segments.

Args:
obj: The object to get the value from
paths: List of path segments to traverse
is_optional: Whether the entire path is optional

Returns:
The value as a string

Raises:
KeyError: If the path doesn't exist and the reference is not optional
"""
if not paths:
return str(obj)

# Get current part and remaining paths
current_part = paths[0]

# Handle optional markers
is_current_optional = current_part.endswith('?') or is_optional
clean_part = current_part.removesuffix('?')

# Get value for current part
if isinstance(obj, dict) and clean_part in obj:
return _get_nested_value(obj[clean_part], paths[1:], is_current_optional)
elif hasattr(obj, clean_part):
return _get_nested_value(
getattr(obj, clean_part), paths[1:], is_current_optional
)
elif is_current_optional:
return ''
else:
raise KeyError(f'Key not found: {clean_part}')

async def _async_sub(pattern, repl_async_fn, string) -> str:
result = []
Expand All @@ -96,29 +140,37 @@ async def _replace_match(match) -> str:
if var_name.endswith('?'):
optional = True
var_name = var_name.removesuffix('?')
if var_name.startswith('artifact.'):
var_name = var_name.removeprefix('artifact.')
if context.artifact_service is None:
raise ValueError('Artifact service is not initialized.')
artifact = await context.artifact_service.load_artifact(
app_name=context.session.app_name,
user_id=context.session.user_id,
session_id=context.session.id,
filename=var_name,
)
if not var_name:
raise KeyError(f'Artifact {var_name} not found.')
return str(artifact)
else:
if not _is_valid_state_name(var_name):
return match.group()
if var_name in context.session.state:
return str(context.session.state[var_name])

try:
if var_name.startswith('artifact.'):
var_name = var_name.removeprefix('artifact.')
if context.artifact_service is None:
raise ValueError('Artifact service is not initialized.')
artifact = await context.artifact_service.load_artifact(
app_name=context.session.app_name,
user_id=context.session.user_id,
session_id=context.session.id,
filename=var_name,
)
if not var_name:
raise KeyError(f'Artifact {var_name} not found.')
return str(artifact)
else:
if optional:
return ''
else:
if not _is_valid_state_name(var_name.split('.')[0].removesuffix('?')):
return match.group()
# Try to resolve nested path
try:
return _get_nested_value(
context.session.state, var_name.split('.'), optional
)
except KeyError:
if not _is_valid_state_name(var_name):
return match.group()
raise KeyError(f'Context variable not found: `{var_name}`.')
except Exception as e:
if optional:
return ''
raise e

return await _async_sub(r'{+[^{}]*}+', _replace_match, instruction_template)

Expand Down
55 changes: 49 additions & 6 deletions tests/unittests/flows/llm_flows/test_instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from google.genai import types
import pytest

from google.adk.agents import Agent
from google.adk.agents.readonly_context import ReadonlyContext
from google.adk.flows.llm_flows import instructions
from google.adk.models import LlmRequest
from google.adk.sessions import Session
from google.genai import types
import pytest

from ... import utils

Expand All @@ -33,15 +34,21 @@ async def test_build_system_instruction():
model="gemini-1.5-flash",
name="agent",
instruction=("""Use the echo_info tool to echo { customerId }, \
{{customer_int }, { non-identifier-float}}, \
{'key1': 'value1'} and {{'key2': 'value2'}}."""),
{{customer_int }, {customer.profile.name}, {customer?.preferences.alias}, \
{ non-identifier-float}}, {'key1': 'value1'} and {{'key2': 'value2'}}."""),
)
invocation_context = utils.create_invocation_context(agent=agent)
invocation_context.session = Session(
app_name="test_app",
user_id="test_user",
id="test_id",
state={"customerId": "1234567890", "customer_int": 30},
state={
"customerId": "1234567890",
"customer_int": 30,
"customer": {
"profile": {"name": "Test User", "email": "[email protected]"}
},
},
)

async for _ in instructions.request_processor.run_async(
Expand All @@ -52,7 +59,7 @@ async def test_build_system_instruction():

assert request.config.system_instruction == (
"""Use the echo_info tool to echo 1234567890, 30, \
{ non-identifier-float}}, {'key1': 'value1'} and {{'key2': 'value2'}}."""
Test User, , { non-identifier-float}}, {'key1': 'value1'} and {{'key2': 'value2'}}."""
)


Expand Down Expand Up @@ -162,3 +169,39 @@ async def test_build_system_instruction_with_namespace():
assert request.config.system_instruction == (
"""Use the echo_info tool to echo 1234567890, app_value, user_value, {a:key}."""
)


@pytest.mark.asyncio
async def test_nested_templating():
request = LlmRequest(
model="gemini-1.5-flash",
config=types.GenerateContentConfig(system_instruction=""),
)
agent = Agent(
model="gemini-1.5-flash",
name="agent",
instruction=(
"""Echo the following: {user.profile.name}, {user.profile.email}, {user.settings?.preferences.theme}, {user.preferences.value}"""
),
)
invocation_context = utils.create_invocation_context(agent=agent)
invocation_context.session = Session(
app_name="test_app",
user_id="test_user",
id="test_id",
state={
"user": {
"profile": {"name": "Test User", "email": "[email protected]"}
}
},
)

async for _ in instructions.request_processor.run_async(
invocation_context,
request,
):
pass

assert request.config.system_instruction == (
"""Echo the following: Test User, [email protected], , {user.preferences.value}"""
)