Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@

GITHUB_BASE_URL = "https://api.github.com"
GITHUB_TOKEN = os.getenv("GITHUB_TOKEN")
if not GITHUB_TOKEN:
raise ValueError("GITHUB_TOKEN environment variable not set")

OWNER = os.getenv("OWNER", "google")
REPO = os.getenv("REPO", "adk-python")
Expand Down
27 changes: 19 additions & 8 deletions contributing/samples/adk_team/adk_issue_monitoring_agent/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,24 +48,35 @@ def _increment_api_call_count() -> None:
allowed_methods=["GET", "DELETE"],
)
adapter = HTTPAdapter(max_retries=retry_strategy)
_session = requests.Session()
_session.mount("https://", adapter)
_session.headers.update({
"Authorization": f"token {GITHUB_TOKEN}",
"Accept": "application/vnd.github.v3+json",
})
_session = None


def _get_session() -> requests.Session:
global _session
if _session is not None:
return _session
if not GITHUB_TOKEN:
raise ValueError("GITHUB_TOKEN environment variable not set")
session = requests.Session()
session.mount("https://", adapter)
session.headers.update({
"Authorization": f"token {GITHUB_TOKEN}",
"Accept": "application/vnd.github.v3+json",
})
_session = session
return _session


def get_request(url: str, params: dict[str, Any] | None = None) -> Any:
_increment_api_call_count()
response = _session.get(url, params=params or {}, timeout=60)
response = _get_session().get(url, params=params or {}, timeout=60)
response.raise_for_status()
return response.json()


def post_request(url: str, payload: Any) -> Any:
_increment_api_call_count()
response = _session.post(url, json=payload, timeout=60)
response = _get_session().post(url, json=payload, timeout=60)
response.raise_for_status()
return response.json()

Expand Down
26 changes: 22 additions & 4 deletions src/google/adk/models/lite_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1982,6 +1982,7 @@ async def _get_completion_inputs(
Optional[List[Dict]],
Optional[Dict[str, Any]],
Optional[Dict],
Optional[str],
]:
"""Converts an LlmRequest to litellm inputs and extracts generation params.

Expand All @@ -1990,8 +1991,8 @@ async def _get_completion_inputs(
model: The model string to use for determining provider-specific behavior.

Returns:
The litellm inputs (message list, tool dictionary, response format and
generation params).
The litellm inputs (message list, tool dictionary, response format,
generation params, and tool_choice).
"""
_ensure_litellm_imported()

Expand Down Expand Up @@ -2066,7 +2067,21 @@ async def _get_completion_inputs(
if not generation_params:
generation_params = None

return messages, tools, response_format, generation_params
# 5. Extract tool_choice from tool_config
tool_choice: Optional[str] = None
if (
llm_request.config
and llm_request.config.tool_config
and llm_request.config.tool_config.function_calling_config
):
mode = llm_request.config.tool_config.function_calling_config.mode
if mode == types.FunctionCallingConfigMode.ANY:
tool_choice = "required"
elif mode == types.FunctionCallingConfigMode.NONE:
tool_choice = "none"
# AUTO → None (provider default)

return messages, tools, response_format, generation_params, tool_choice


def _build_function_declaration_log(
Expand Down Expand Up @@ -2327,7 +2342,7 @@ async def generate_content_async(
logger.debug(_build_request_log(llm_request))

effective_model = llm_request.model or self.model
messages, tools, response_format, generation_params = (
messages, tools, response_format, generation_params, tool_choice = (
await _get_completion_inputs(llm_request, effective_model)
)
normalized_messages = _normalize_ollama_chat_messages(
Expand Down Expand Up @@ -2359,6 +2374,9 @@ async def generate_content_async(
if generation_params:
completion_args.update(generation_params)

if tool_choice is not None:
completion_args["tool_choice"] = tool_choice

if stream:
text = ""
reasoning_parts: List[types.Part] = []
Expand Down
32 changes: 24 additions & 8 deletions src/google/adk/tools/agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,14 +217,30 @@ async def run_async(
input_schema = _get_input_schema(self.agent)
if input_schema:
input_value = input_schema.model_validate(args)
content = types.Content(
role='user',
parts=[
types.Part.from_text(
text=input_value.model_dump_json(exclude_none=True)
)
],
)
json_payload = input_value.model_dump_json(exclude_none=True)
output_schema = _get_output_schema(self.agent)
if output_schema:
# Single-shot structured output mode: pass raw JSON, no ReAct wrapper.
content = types.Content(
role='user',
parts=[types.Part.from_text(text=json_payload)],
)
else:
# Tool-calling mode: wrap with ReAct-style prompt.
content = types.Content(
role='user',
parts=[
types.Part.from_text(
text=(
'Process the following structured request. Use your'
' available tools as needed to gather information or'
' perform actions before producing the final'
' response.\n\nRequest:\n'
+ json_payload
)
)
],
)
else:
content = types.Content(
role='user',
Expand Down
Loading
Loading