Skip to content

Commit 321f8af

Browse files
committed
Merge remote-tracking branch 'origin/main' into fix/gemini-schema-const
2 parents 8285130 + f9097cb commit 321f8af

8 files changed

Lines changed: 385 additions & 10 deletions

File tree

src/google/adk/a2a/converters/to_adk_event.py

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@
4343
# Logger
4444
logger = logging.getLogger("google_adk." + __name__)
4545

46+
MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT = (
47+
"mock_function_call_for_required_user_input"
48+
)
49+
4650
A2AMessageToEventConverter = Callable[
4751
[
4852
Message,
@@ -276,6 +280,36 @@ def _merge_event_actions(
276280
return EventActions.model_validate(merged_actions_data)
277281

278282

283+
def _create_mock_function_call_for_required_user_input(
284+
state: TaskState,
285+
output_parts: list[genai_types.Part],
286+
long_running_function_ids: set[str],
287+
) -> tuple[list[genai_types.Part], set[str]]:
288+
"""Creates a mock function call for input/auth-required if applicable.
289+
290+
This solution allows to unblock the A2A integration with non-ADK agents from
291+
ADK side by replacing the last text part with a synthetic function call. All
292+
other parts are preserved.
293+
"""
294+
if (
295+
state == TaskState.input_required or state == TaskState.auth_required
296+
) and (not long_running_function_ids or len(long_running_function_ids) == 0):
297+
# Find the last text part from the bottom to replace it with a function call.
298+
# In case of input-required events, the LLM should stop the production of other parts.
299+
for i in range(len(output_parts) - 1, -1, -1):
300+
if output_parts[i].text:
301+
function_call = genai_types.FunctionCall(
302+
id=str(uuid.uuid4()),
303+
name=MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT,
304+
args={"input_required": output_parts[i].text},
305+
)
306+
long_running_function_ids = set()
307+
long_running_function_ids.add(function_call.id)
308+
output_parts[i] = genai_types.Part(function_call=function_call)
309+
break
310+
return output_parts, long_running_function_ids
311+
312+
279313
@a2a_experimental
280314
def convert_a2a_task_to_event(
281315
a2a_task: Task,
@@ -317,9 +351,9 @@ def convert_a2a_task_to_event(
317351
output_parts, _ = _convert_a2a_parts_to_adk_parts(
318352
artifact_parts, part_converter
319353
)
320-
if (
321-
a2a_task.status.message
322-
and a2a_task.status.state == TaskState.input_required
354+
if a2a_task.status.message and (
355+
a2a_task.status.state == TaskState.input_required
356+
or a2a_task.status.state == TaskState.auth_required
323357
):
324358
event_actions = _merge_event_actions(
325359
event_actions,
@@ -331,6 +365,12 @@ def convert_a2a_task_to_event(
331365
output_parts.extend(parts)
332366
long_running_function_ids.update(ids)
333367

368+
output_parts, long_running_function_ids = (
369+
_create_mock_function_call_for_required_user_input(
370+
a2a_task.status.state, output_parts, long_running_function_ids
371+
)
372+
)
373+
334374
return _create_event(
335375
output_parts,
336376
invocation_context,
@@ -422,6 +462,14 @@ def convert_a2a_status_update_to_event(
422462
output_parts.extend(parts)
423463
long_running_function_ids.update(ids)
424464

465+
output_parts, long_running_function_ids = (
466+
_create_mock_function_call_for_required_user_input(
467+
a2a_status_update.status.state,
468+
output_parts,
469+
long_running_function_ids,
470+
)
471+
)
472+
425473
return _create_event(
426474
output_parts,
427475
invocation_context,

src/google/adk/agents/remote_a2a_agent.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@
6565
from ..a2a.converters.part_converter import convert_a2a_part_to_genai_part
6666
from ..a2a.converters.part_converter import convert_genai_part_to_a2a_part
6767
from ..a2a.converters.part_converter import GenAIPartToA2APartConverter
68+
from ..a2a.converters.to_adk_event import _create_mock_function_call_for_required_user_input
69+
from ..a2a.converters.to_adk_event import MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT
6870
from ..a2a.converters.utils import _get_adk_metadata_key
6971
from ..a2a.experimental import a2a_experimental
7072
from ..a2a.logs.log_utils import build_a2a_request_log
@@ -105,6 +107,22 @@ class A2AClientError(Exception):
105107
pass
106108

107109

110+
def _add_mock_function_call(event: Event, state: TaskState) -> None:
111+
"""Generates a mock function call for input-required events if applicable."""
112+
if event.content is None:
113+
return
114+
115+
output_parts, long_running_tool_ids = (
116+
_create_mock_function_call_for_required_user_input(
117+
state,
118+
event.content.parts,
119+
event.long_running_tool_ids,
120+
)
121+
)
122+
event.content.parts = output_parts
123+
event.long_running_tool_ids = long_running_tool_ids
124+
125+
108126
@a2a_experimental
109127
class RemoteA2aAgent(BaseAgent):
110128
"""Agent that communicates with a remote A2A agent via A2A client.
@@ -360,8 +378,40 @@ def _create_a2a_request_for_user_function_response(
360378
if not function_call_event:
361379
return None
362380

381+
event = ctx.session.events[-1]
382+
# If the user function_response replies to a function_call for non-ADK
383+
# input-required events (fc.name = MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT),
384+
# the function_response part is replaced with text extracted from the
385+
# function response.
386+
# The implementation is based on the assumption that the user function_response
387+
# event will contain a function_response with the name
388+
# MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT and the response will
389+
# contain a "result" field with the user input as a string text.
390+
mock_function_call = [
391+
fc
392+
for fc in function_call_event.get_function_calls()
393+
if fc.name == MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT
394+
]
395+
if mock_function_call:
396+
new_parts = []
397+
for function_response in event.get_function_responses():
398+
if (
399+
function_response.name == MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT
400+
and function_response.response
401+
and "result" in function_response.response
402+
):
403+
text_value = function_response.response.get("result")
404+
new_parts.append(
405+
genai_types.Part(
406+
text=str(text_value),
407+
)
408+
)
409+
new_event = event.model_copy(deep=True)
410+
new_event.content.parts = new_parts
411+
event = new_event
412+
363413
a2a_message = convert_event_to_a2a_message(
364-
ctx.session.events[-1], ctx, Role.user, self._genai_part_converter
414+
event, ctx, Role.user, self._genai_part_converter
365415
)
366416
if function_call_event.custom_metadata:
367417
metadata = function_call_event.custom_metadata
@@ -472,6 +522,7 @@ async def _handle_a2a_response(
472522
):
473523
for part in event.content.parts:
474524
part.thought = True
525+
_add_mock_function_call(event, task.status.state)
475526
elif (
476527
isinstance(update, A2ATaskStatusUpdateEvent)
477528
and update.status
@@ -487,6 +538,7 @@ async def _handle_a2a_response(
487538
):
488539
for part in event.content.parts:
489540
part.thought = True
541+
_add_mock_function_call(event, update.status.state)
490542
elif isinstance(update, A2ATaskArtifactUpdateEvent) and (
491543
not update.append or update.last_chunk
492544
):

src/google/adk/integrations/agent_registry/agent_registry.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,10 @@ def _get_auth_headers(self) -> Dict[str, str]:
208208
"Authorization": f"Bearer {self._credentials.token}",
209209
"Content-Type": "application/json",
210210
}
211-
quota_project_id = getattr(self._credentials, "quota_project_id", None)
211+
quota_project_id = (
212+
getattr(self._credentials, "quota_project_id", None)
213+
or self.project_id
214+
)
212215
if quota_project_id:
213216
headers["x-goog-user-project"] = quota_project_id
214217
return headers

src/google/adk/tools/skill_toolset.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -895,6 +895,8 @@ def __init__(
895895
script_timeout: Timeout in seconds for shell script execution via
896896
subprocess.run. Defaults to 300 seconds. Does not apply to Python
897897
scripts executed via exec().
898+
additional_tools: Optional list of `BaseTool` or `BaseToolset` instances
899+
to be made available to the agent when certain skills are activated.
898900
"""
899901
super().__init__()
900902

@@ -911,6 +913,8 @@ def __init__(
911913
self._registry = registry
912914
self._code_executor = code_executor
913915
self._script_timeout = script_timeout
916+
# Needed for mid-turn reloading of skill tools.
917+
self._use_invocation_cache = False
914918
self._invocation_cache: dict[
915919
str,
916920
dict[str, models.Skill | asyncio.Future[models.Skill | None] | None],

tests/unittests/a2a/converters/test_to_adk.py

Lines changed: 89 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from google.adk.a2a.converters.to_adk_event import convert_a2a_message_to_event
3131
from google.adk.a2a.converters.to_adk_event import convert_a2a_status_update_to_event
3232
from google.adk.a2a.converters.to_adk_event import convert_a2a_task_to_event
33+
from google.adk.a2a.converters.to_adk_event import MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT
3334
from google.adk.a2a.converters.utils import _get_adk_metadata_key
3435
from google.adk.agents.invocation_context import InvocationContext
3536
from google.genai import types as genai_types
@@ -330,12 +331,95 @@ def test_convert_a2a_task_to_event_merges_status_and_artifact_actions(self):
330331
assert event.actions.state_delta == {"saved_key": "saved-value"}
331332
assert event.actions.transfer_to_agent == "agent-2"
332333
assert event.content is not None
333-
assert event.content.parts == [mock_genai_part]
334+
assert (
335+
event.content.parts[0].function_call.name
336+
== MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT
337+
)
338+
assert (
339+
event.content.parts[0].function_call.args["input_required"]
340+
== "need input"
341+
)
342+
343+
def test_convert_a2a_task_to_event_multiple_parts_replaces_last_text(self):
344+
"""Test converting A2A task with multiple text parts, only replacing the last text."""
345+
part1 = Mock(spec=A2APart)
346+
part1.root = Mock(spec=TextPart)
347+
part1.root.metadata = {}
348+
part2 = Mock(spec=A2APart)
349+
part2.root = Mock(spec=TextPart)
350+
part2.root.metadata = {}
351+
352+
task = Task(
353+
id="task-1",
354+
context_id="context-1",
355+
kind="task",
356+
status=TaskStatus(
357+
state=TaskState.input_required,
358+
timestamp="now",
359+
message=Message(
360+
message_id="m1",
361+
role="agent",
362+
parts=[part1, part2],
363+
),
364+
),
365+
)
366+
367+
mock_genai_part_1 = genai_types.Part.from_text(text="Part 1")
368+
mock_genai_part_2 = genai_types.Part.from_text(text="Part 2")
334369

335-
def test_convert_a2a_task_to_event_none(self):
336-
"""Test convert_a2a_task_to_event with None."""
337-
with pytest.raises(ValueError, match="A2A task cannot be None"):
338-
convert_a2a_task_to_event(None)
370+
part_converter_mock = Mock()
371+
part_converter_mock.side_effect = [[mock_genai_part_1], [mock_genai_part_2]]
372+
373+
event = convert_a2a_task_to_event(
374+
task,
375+
author="test-author",
376+
invocation_context=self.mock_context,
377+
part_converter=part_converter_mock,
378+
)
379+
380+
assert event is not None
381+
assert event.content is not None
382+
assert len(event.content.parts) == 2
383+
assert event.content.parts[0].text == "Part 1"
384+
assert (
385+
event.content.parts[1].function_call.name
386+
== MOCK_FUNCTION_CALL_FOR_REQUIRED_USER_INPUT
387+
)
388+
389+
def test_convert_a2a_task_to_event_no_text_parts(self):
390+
"""Test converting A2A task with no text parts should not inject function call."""
391+
part1 = Mock(spec=A2APart)
392+
part1.root = Mock() # Not a TextPart
393+
part1.root.metadata = {}
394+
395+
task = Task(
396+
id="task-1",
397+
context_id="context-1",
398+
kind="task",
399+
status=TaskStatus(
400+
state=TaskState.input_required,
401+
timestamp="now",
402+
message=Message(
403+
message_id="m1",
404+
role="agent",
405+
parts=[part1],
406+
),
407+
),
408+
)
409+
mock_image_part = genai_types.Part(
410+
inline_data=genai_types.Blob(mime_type="image/jpeg", data=b"fake")
411+
)
412+
413+
event = convert_a2a_task_to_event(
414+
task,
415+
author="test-author",
416+
invocation_context=self.mock_context,
417+
part_converter=Mock(return_value=[mock_image_part]),
418+
)
419+
420+
assert event is not None
421+
assert event.content is not None
422+
assert event.content.parts == [mock_image_part]
339423

340424
def test_convert_a2a_status_update_to_event_success(self):
341425
"""Test successful conversion of A2A status update to Event."""

0 commit comments

Comments
 (0)