From 64562ac343b1272dc4a3e4373964849cac960c77 Mon Sep 17 00:00:00 2001 From: Yee Sian Ng Date: Thu, 21 May 2026 21:17:36 +0000 Subject: [PATCH 01/18] chore: Deploy ADK API server directly --- pyproject.toml | 208 +- src/google/adk/cli/adk_web_server.py | 2216 ++++++++++++++++- src/google/adk/cli/cli_deploy.py | 316 +-- src/google/adk/cli/cli_tools_click.py | 459 ++-- src/google/adk/cli/fast_api.py | 822 +++--- src/google/adk/cli/utils/_telemetry.py | 106 + src/google/adk/telemetry/_agent_engine.py | 106 + src/google/adk/telemetry/google_cloud.py | 139 +- tests/unittests/cli/test_fast_api.py | 435 +++- tests/unittests/cli/utils/test_cli_deploy.py | 199 +- .../unittests/telemetry/test_google_cloud.py | 13 +- 11 files changed, 3793 insertions(+), 1226 deletions(-) create mode 100644 src/google/adk/cli/utils/_telemetry.py create mode 100644 src/google/adk/telemetry/_agent_engine.py diff --git a/pyproject.toml b/pyproject.toml index 43ee273184..5ba4338903 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,84 +33,70 @@ classifiers = [ dynamic = [ "version" ] dependencies = [ - "aiosqlite>=0.21", - "authlib>=1.6.6,<2", - "click>=8.1.8,<9", - "fastapi>=0.124.1,<1", - "google-auth[pyopenssl]>=2.47", - "google-genai>=1.72,<2", - "graphviz>=0.20.2,<1", - "httpx>=0.27,<1", - "jsonschema>=4.23,<5", - "opentelemetry-api>=1.36,<=1.41.1", + "aiosqlite>=0.21", # For SQLite database + "anyio>=4.9,<5", # For MCP Session Manager + "authlib>=1.6.6,<2", # For RestAPI Tool + "click>=8.1.8,<9", # For CLI tools + "fastapi>=0.124.1,<1", # FastAPI framework + "google-api-python-client>=2.157,<3", # Google API client discovery + "google-auth[pyopenssl]>=2.47", # Google Auth library + "google-cloud-aiplatform[agent-engines]>=1.148.1,<2", # For VertexAI integrations, e.g. example store. + "google-cloud-bigquery>=2.2", + "google-cloud-bigquery-storage>=2", + "google-cloud-bigtable>=2.32", # For Bigtable database + "google-cloud-dataplex>=1.7,<3", # For Dataplex Catalog Search tool + "google-cloud-discoveryengine>=0.13.12,<0.14", # For Discovery Engine Search Tool + "google-cloud-pubsub>=2,<3", # For Pub/Sub Tool + "google-cloud-secret-manager>=2.22,<3", # Fetching secrets in RestAPI Tool + "google-cloud-spanner>=3.56,<4", # For Spanner database + "google-cloud-speech>=2.30,<3", # For Audio Transcription + "google-cloud-storage>=2.18,<4", # For GCS Artifact service + "google-genai>=1.72,<2", # Google GenAI SDK + "graphviz>=0.20.2,<1", # Graphviz for graph rendering + "httpx>=0.27,<1", # HTTP client library + "jsonschema>=4.23,<5", # Agent Builder config validation + "mcp>=1.24,<2", # For MCP Toolset + "opentelemetry-api>=1.36,<=1.41.1", # OpenTelemetry - allow 1.39+ up to latest published at time of resolution. + "opentelemetry-exporter-gcp-logging>=1.9.0a0,<=1.12.0a0", + "opentelemetry-exporter-gcp-monitoring>=1.9.0a0,<2", + "opentelemetry-exporter-gcp-trace>=1.9,<2", + "opentelemetry-exporter-otlp-proto-http>=1.36", + "opentelemetry-resourcedetector-gcp>=1.9.0a0,<2", "opentelemetry-sdk>=1.36,<=1.41.1", - "packaging>=21", - "pydantic>=2.12,<3", - "python-dotenv>=1,<2", - # go/keep-sorted start - "pyyaml>=6.0.2,<7", + "pyarrow>=14", + "pydantic>=2.12,<3", # For data validation/models + "python-dateutil>=2.9.0.post0,<3", # For Vertext AI Session Service + "python-dotenv>=1,<2", # To manage environment variables + "pyyaml>=6.0.2,<7", # For APIHubToolset. "requests>=2.32.4,<3", - "starlette>=0.49.1,<1", - "tenacity>=9,<10", + "sqlalchemy>=2,<3", # SQL database ORM + "sqlalchemy-spanner>=1.14", # Spanner database session service + "starlette>=0.49.1,<1", # For FastAPI CLI + "tenacity>=9,<10", # For Retry management "typing-extensions>=4.5,<5", - "tzlocal>=5.3,<6", - "uvicorn>=0.34,<1", - "watchdog>=6,<7", - "websockets>=15.0.1,<16", - # go/keep-sorted end + "tzlocal>=5.3,<6", # Time zone utilities + "uvicorn>=0.34,<1", # ASGI server for FastAPI + "watchdog>=6,<7", # For file change detection and hot reload + "websockets>=15.0.1,<16", # For BaseLlmFlow + "yarl<1.24", ] - optional-dependencies.a2a = [ "a2a-sdk>=0.3.4,<0.4", ] optional-dependencies.agent-identity = [ "google-cloud-iamconnectorcredentials>=0.1,<0.2", ] -optional-dependencies.all = [ - "anyio>=4.9,<5", - "google-api-python-client>=2.157,<3", - "google-cloud-aiplatform[agent-engines]>=1.148.1,<2", - "google-cloud-bigquery>=2.2", - "google-cloud-bigquery-storage>=2", - "google-cloud-bigtable>=2.32", - "google-cloud-dataplex>=1.7,<3", - "google-cloud-discoveryengine>=0.13.12,<0.14", - "google-cloud-pubsub>=2,<3", - "google-cloud-resource-manager>=1.12,<2", - "google-cloud-secret-manager>=2.22,<3", - "google-cloud-spanner>=3.56,<4", - "google-cloud-speech>=2.30,<3", - "google-cloud-storage>=2.18,<4", - "mcp>=1.24,<2", - "opentelemetry-exporter-gcp-logging>=1.9.0a0,<=1.12.0a0", - "opentelemetry-exporter-gcp-monitoring>=1.9.0a0,<2", - "opentelemetry-exporter-gcp-trace>=1.9,<2", - "opentelemetry-exporter-otlp-proto-http>=1.36", - "opentelemetry-resourcedetector-gcp>=1.9.0a0,<2", - "pyarrow>=14", - "python-dateutil>=2.9.0.post0,<3", - "sqlalchemy>=2,<3", - "sqlalchemy-spanner>=1.14", -] - optional-dependencies.community = [ "google-adk-community", ] -optional-dependencies.db = [ - "sqlalchemy>=2,<3", - "sqlalchemy-spanner>=1.14", -] - optional-dependencies.dev = [ "flit>=3.10", + "isort>=6", "mypy>=1.15", "pre-commit>=4", "pyink>=25.12", "pylint>=2.6", - "tox>=4.23.2", - "tox-uv>=1.33.2", ] - optional-dependencies.docs = [ "autodoc-pydantic", "furo", @@ -120,7 +106,6 @@ optional-dependencies.docs = [ "sphinx-click", "sphinx-rtd-theme", ] - optional-dependencies.eval = [ "gepa>=0.1", "google-cloud-aiplatform[evaluation]>=1.148", @@ -129,7 +114,7 @@ optional-dependencies.eval = [ "rouge-score>=0.1.2", "tabulate>=0.9", ] - +# Optional extensions optional-dependencies.extensions = [ "anthropic>=0.78", # For anthropic model support; 0.78 introduced ThinkingConfigAdaptiveParam (required for Claude Opus 4.7). "beautifulsoup4>=3.2.2", # For load_web_page tool. @@ -137,105 +122,47 @@ optional-dependencies.extensions = [ "docker>=7", # For ContainerCodeExecutor "google-cloud-firestore>=2.11,<3", # For Firestore services "google-cloud-parametermanager>=0.4,<1", - "k8s-agent-sandbox>=0.1.1.post3", - "kubernetes>=29", - "langgraph>=0.2.60,<0.4.8", - "litellm>=1.83.7,<=1.83.14", - "llama-index-embeddings-google-genai>=0.3", - "llama-index-readers-file>=0.4", - "lxml>=5.3", - "pypika>=0.50", - "toolbox-adk>=1,<2", + "k8s-agent-sandbox>=0.1.1.post3", # For GkeCodeExecutor sandbox mode + "kubernetes>=29", # For GkeCodeExecutor + "langgraph>=0.2.60,<0.4.8", # For LangGraphAgent + "litellm>=1.83.7,<=1.83.14", # For LiteLlm class. Lower bound: 5 CVE patches (2026-04). Upper bound pinned to current latest; bump deliberately. See #5488. + "llama-index-embeddings-google-genai>=0.3", # For files retrieval using LlamaIndex. + "llama-index-readers-file>=0.4", # For retrieval using LlamaIndex. + "lxml>=5.3", # For load_web_page tool. + "pypika>=0.50", # For crewai->chromadb dependency + "toolbox-adk>=1,<2", # For tools.toolbox_toolset.ToolboxToolset +] +optional-dependencies.otel-gcp = [ + "opentelemetry-instrumentation-google-genai>=0.6b0,<1", + "opentelemetry-instrumentation-grpc>=0.43b0,<1", + "opentelemetry-instrumentation-httpx>=0.54b0,<1", ] - -optional-dependencies.gcp = [ - "google-cloud-aiplatform[agent-engines]>=1.148.1,<2", - "google-cloud-bigquery>=2.2", - "google-cloud-bigquery-storage>=2", - "google-cloud-bigtable>=2.32", - "google-cloud-dataplex>=1.7,<3", - "google-cloud-discoveryengine>=0.13.12,<0.14", - "google-cloud-pubsub>=2,<3", - "google-cloud-resource-manager>=1.12,<2", - "google-cloud-secret-manager>=2.22,<3", - "google-cloud-spanner>=3.56,<4", - "google-cloud-speech>=2.30,<3", - "google-cloud-storage>=2.18,<4", - "opentelemetry-exporter-gcp-logging>=1.9.0a0,<=1.12.0a0", - "opentelemetry-exporter-gcp-monitoring>=1.9.0a0,<2", - "opentelemetry-exporter-gcp-trace>=1.9,<2", - "opentelemetry-exporter-otlp-proto-http>=1.36", - "opentelemetry-resourcedetector-gcp>=1.9.0a0,<2", - "pyarrow>=14", - "python-dateutil>=2.9.0.post0,<3", -] - -optional-dependencies.mcp = [ - "anyio>=4.9,<5", - "mcp>=1.24,<2", -] - -optional-dependencies.otel-gcp = [ "opentelemetry-instrumentation-google-genai>=0.6b0,<1" ] optional-dependencies.slack = [ "slack-bolt>=1.22" ] optional-dependencies.test = [ "a2a-sdk>=0.3,<0.4", "anthropic>=0.78", # For anthropic model tests; 0.78 introduced ThinkingConfigAdaptiveParam (required for Claude Opus 4.7). - "anyio>=4.9,<5", "crewai[tools]; python_version>='3.11' and python_version<'3.12'", # For CrewaiTool tests; chromadb/pypika fail on 3.12+ - "gepa>=0.1", - "google-api-python-client>=2.157,<3", - "google-cloud-aiplatform[agent-engines,evaluation]>=1.148.1,<2", - "google-cloud-bigquery>=2.2", - "google-cloud-bigquery-storage>=2", - "google-cloud-bigtable>=2.32", - "google-cloud-dataplex>=1.7,<3", - "google-cloud-discoveryengine>=0.13.12,<0.14", "google-cloud-firestore>=2.11,<3", "google-cloud-iamconnectorcredentials>=0.1,<0.2", "google-cloud-parametermanager>=0.4,<1", - "google-cloud-pubsub>=2,<3", - "google-cloud-resource-manager>=1.12,<2", - "google-cloud-secret-manager>=2.22,<3", - "google-cloud-spanner>=3.56,<4", - "google-cloud-speech>=2.30,<3", - "google-cloud-storage>=2.18,<4", - "jinja2>=3.1.4,<4", - "kubernetes>=29", + "kubernetes>=29", # For GkeCodeExecutor "langchain-community>=0.3.17", - "langgraph>=0.2.60,<0.4.8", - "litellm>=1.83.7,<=1.83.14", - "llama-index-readers-file>=0.4", - "mcp>=1.24,<2", - "openai>=1.100.2", - "opentelemetry-exporter-gcp-logging>=1.9.0a0,<=1.12.0a0", - "opentelemetry-exporter-gcp-monitoring>=1.9.0a0,<2", - "opentelemetry-exporter-gcp-trace>=1.9,<2", - "opentelemetry-exporter-otlp-proto-http>=1.36", + "langgraph>=0.2.60,<0.4.8", # For LangGraphAgent + "litellm>=1.83.7,<=1.83.14", # For LiteLLM tests. Lower bound: 5 CVE patches (2026-04). Upper bound pinned to current latest; bump deliberately. See #5488. + "llama-index-readers-file>=0.4", # For retrieval tests + "openai>=1.100.2", # For LiteLLM "opentelemetry-instrumentation-google-genai>=0.3b0,<1", - "opentelemetry-resourcedetector-gcp>=1.9.0a0,<2", - "pandas>=2.2.3", - "pyarrow>=14", - "pypika>=0.50", + "pypika>=0.50", # For crewai->chromadb dependency "pytest>=9,<10", "pytest-asyncio>=0.25", "pytest-mock>=3.14", "pytest-xdist>=3.6.1", - "python-dateutil>=2.9.0.post0,<3", "python-multipart>=0.0.9", "rouge-score>=0.1.2", "slack-bolt>=1.22", - "sqlalchemy>=2,<3", - "sqlalchemy-spanner>=1.14", "tabulate>=0.9", - "tomli>=2,<3; python_version<'3.11'", ] - optional-dependencies.toolbox = [ "toolbox-adk>=1,<2" ] - -optional-dependencies.tools = [ - "google-api-python-client>=2.157,<3", -] - urls.changelog = "https://github.com/google/adk-python/blob/main/CHANGELOG.md" urls.documentation = "https://google.github.io/adk-docs/" urls.homepage = "https://google.github.io/adk-docs/" @@ -253,7 +180,7 @@ include = [ "py.typed" ] [tool.isort] profile = "google" single_line_exclusions = [ ] -line_length = 200 +line_length = 200 # Prevent line wrap flickering. known_third_party = [ "google.adk", "a2a" ] [tool.pytest.ini_options] @@ -262,7 +189,7 @@ asyncio_default_fixture_loop_scope = "function" asyncio_mode = "auto" [tool.mypy] -python_version = "3.11" +python_version = "3.10" exclude = [ "tests/", "contributing/samples/" ] plugins = [ "pydantic.mypy" ] strict = true @@ -270,6 +197,7 @@ disable_error_code = [ "import-not-found", "import-untyped", "unused-ignore" ] follow_imports = "skip" [tool.pyink] +# Format py files following Google style-guide line-length = 80 unstable = true pyink-indentation = 2 diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py index b567ce949d..bced893407 100644 --- a/src/google/adk/cli/adk_web_server.py +++ b/src/google/adk/cli/adk_web_server.py @@ -14,22 +14,2218 @@ from __future__ import annotations +import asyncio +from contextlib import asynccontextmanager +import importlib +import json import logging +import os +import re +import sys +import time +import traceback +import typing +from typing import Any +from typing import Callable +from typing import List +from typing import Literal +from typing import Optional +from fastapi import FastAPI +from fastapi import HTTPException +from fastapi import Query +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import RedirectResponse +from fastapi.responses import StreamingResponse +from fastapi.staticfiles import StaticFiles +from fastapi.websockets import WebSocket +from fastapi.websockets import WebSocketDisconnect +from google.genai import types +import graphviz +from opentelemetry import trace +import opentelemetry.sdk.environment_variables as otel_env +from opentelemetry.sdk.trace import export as export_lib +from opentelemetry.sdk.trace import ReadableSpan +from opentelemetry.sdk.trace import SpanProcessor +from opentelemetry.sdk.trace import TracerProvider +from pydantic import Field +from pydantic import ValidationError +from starlette.types import Lifespan from typing_extensions import deprecated +from typing_extensions import override +from watchdog.observers import Observer +import yaml -from .api_server import _parse_cors_origins -from .api_server import RunAgentRequest -from .dev_server import DevServer +from . import agent_graph +from ..agents.base_agent import BaseAgent +from ..agents.live_request_queue import LiveRequest +from ..agents.live_request_queue import LiveRequestQueue +from ..agents.llm_agent import LlmAgent +from ..agents.run_config import RunConfig +from ..agents.run_config import StreamingMode +from ..apps.app import App +from ..artifacts.base_artifact_service import ArtifactVersion +from ..artifacts.base_artifact_service import BaseArtifactService +from ..auth.credential_service.base_credential_service import BaseCredentialService +from ..errors.already_exists_error import AlreadyExistsError +from ..errors.input_validation_error import InputValidationError +from ..errors.not_found_error import NotFoundError +from ..errors.session_not_found_error import SessionNotFoundError +from ..evaluation.base_eval_service import InferenceConfig +from ..evaluation.base_eval_service import InferenceRequest +from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE +from ..evaluation.eval_case import EvalCase +from ..evaluation.eval_case import SessionInput +from ..evaluation.eval_metrics import EvalMetric +from ..evaluation.eval_metrics import EvalMetricResult +from ..evaluation.eval_metrics import EvalMetricResultPerInvocation +from ..evaluation.eval_metrics import EvalStatus +from ..evaluation.eval_metrics import MetricInfo +from ..evaluation.eval_result import EvalSetResult +from ..evaluation.eval_set import EvalSet +from ..evaluation.eval_set_results_manager import EvalSetResultsManager +from ..evaluation.eval_sets_manager import EvalSetsManager +from ..events.event import Event +from ..memory.base_memory_service import BaseMemoryService +from ..plugins.base_plugin import BasePlugin +from ..runners import Runner +from ..sessions.base_session_service import BaseSessionService +from ..sessions.session import Session +from ..utils.agent_info import AgentInfo +from ..utils.agent_info import get_agents_dict +from ..utils.context_utils import Aclosing +from ..utils.feature_decorator import experimental +from ..version import __version__ +from .cli_eval import EVAL_SESSION_ID_PREFIX +from .utils import cleanup +from .utils import common +from .utils import envs +from .utils import evals +from .utils.base_agent_loader import BaseAgentLoader +from .utils.shared_value import SharedValue +from .utils.state import create_empty_state logger = logging.getLogger("google_adk." + __name__) +_EVAL_SET_FILE_EXTENSION = ".evalset.json" -@deprecated( - "AdkWebServer is deprecated and has been refactored into ApiServer and" - " DevServer. Use DevServer instead." -) -class AdkWebServer(DevServer): - """Deprecated wrapper class around DevServer for backward compatibility.""" +TAG_DEBUG = "Debug" +TAG_EVALUATION = "Evaluation" - pass +_REGEX_PREFIX = "regex:" + + +def _parse_cors_origins( + allow_origins: list[str], +) -> tuple[list[str], Optional[str]]: + """Parse allow_origins into literal origins and a combined regex pattern. + + Args: + allow_origins: List of origin strings. Entries prefixed with 'regex:' are + treated as regex patterns; all others are treated as literal origins. + + Returns: + A tuple of (literal_origins, combined_regex) where combined_regex is None + if no regex patterns were provided, or a single pattern joining all regex + patterns with '|'. + """ + literal_origins = [] + regex_patterns = [] + for origin in allow_origins: + if origin.startswith(_REGEX_PREFIX): + pattern = origin[len(_REGEX_PREFIX) :] + if pattern: + regex_patterns.append(pattern) + else: + literal_origins.append(origin) + + combined_regex = "|".join(regex_patterns) if regex_patterns else None + return literal_origins, combined_regex + + +def _is_origin_allowed( + origin: str, + allowed_literal_origins: list[str], + allowed_origin_regex: Optional[re.Pattern[str]], +) -> bool: + """Check whether the given origin matches the allowed origins.""" + if "*" in allowed_literal_origins: + return True + if origin in allowed_literal_origins: + return True + if allowed_origin_regex is not None: + return allowed_origin_regex.fullmatch(origin) is not None + return False + + +def _normalize_origin_scheme(scheme: str) -> str: + """Normalize request schemes to the browser Origin scheme space.""" + if scheme == "ws": + return "http" + if scheme == "wss": + return "https" + return scheme + + +def _strip_optional_quotes(value: str) -> str: + """Strip a single pair of wrapping quotes from a header value.""" + if len(value) >= 2 and value[0] == '"' and value[-1] == '"': + return value[1:-1] + return value + + +def _get_scope_header( + scope: dict[str, Any], header_name: bytes +) -> Optional[str]: + """Return the first matching header value from an ASGI scope.""" + for candidate_name, candidate_value in scope.get("headers", []): + if candidate_name == header_name: + return candidate_value.decode("latin-1").split(",", 1)[0].strip() + return None + + +def _get_request_origin(scope: dict[str, Any]) -> Optional[str]: + """Compute the effective origin for the current HTTP/WebSocket request.""" + forwarded = _get_scope_header(scope, b"forwarded") + if forwarded is not None: + proto = None + host = None + for element in forwarded.split(",", 1)[0].split(";"): + if "=" not in element: + continue + name, value = element.split("=", 1) + if name.strip().lower() == "proto": + proto = _strip_optional_quotes(value.strip()) + elif name.strip().lower() == "host": + host = _strip_optional_quotes(value.strip()) + if proto is not None and host is not None: + return f"{_normalize_origin_scheme(proto)}://{host}" + + host = _get_scope_header(scope, b"x-forwarded-host") + if host is None: + host = _get_scope_header(scope, b"host") + if host is None: + return None + + proto = _get_scope_header(scope, b"x-forwarded-proto") + if proto is None: + proto = scope.get("scheme", "http") + return f"{_normalize_origin_scheme(proto)}://{host}" + + +def _is_request_origin_allowed( + origin: str, + scope: dict[str, Any], + allowed_literal_origins: list[str], + allowed_origin_regex: Optional[re.Pattern[str]], + has_configured_allowed_origins: bool, +) -> bool: + """Validate an Origin header against explicit config or same-origin.""" + if has_configured_allowed_origins and _is_origin_allowed( + origin, allowed_literal_origins, allowed_origin_regex + ): + return True + + request_origin = _get_request_origin(scope) + if request_origin is None: + return False + return origin == request_origin + + +_SAFE_HTTP_METHODS = frozenset({"GET", "HEAD", "OPTIONS"}) + + +class _OriginCheckMiddleware: + """ASGI middleware that blocks cross-origin state-changing requests.""" + + def __init__( + self, + app: Any, + has_configured_allowed_origins: bool, + allowed_origins: list[str], + allowed_origin_regex: Optional[re.Pattern[str]], + ) -> None: + self._app = app + self._has_configured_allowed_origins = has_configured_allowed_origins + self._allowed_origins = allowed_origins + self._allowed_origin_regex = allowed_origin_regex + + async def __call__( + self, + scope: dict[str, Any], + receive: Any, + send: Any, + ) -> None: + if scope["type"] != "http": + await self._app(scope, receive, send) + return + + method = scope.get("method", "GET") + if method in _SAFE_HTTP_METHODS: + await self._app(scope, receive, send) + return + + origin = _get_scope_header(scope, b"origin") + if origin is None: + await self._app(scope, receive, send) + return + + if _is_request_origin_allowed( + origin, + scope, + self._allowed_origins, + self._allowed_origin_regex, + self._has_configured_allowed_origins, + ): + await self._app(scope, receive, send) + return + + response_body = b"Forbidden: origin not allowed" + await send({ + "type": "http.response.start", + "status": 403, + "headers": [ + (b"content-type", b"text/plain"), + (b"content-length", str(len(response_body)).encode()), + ], + }) + await send({ + "type": "http.response.body", + "body": response_body, + }) + + +class ApiServerSpanExporter(export_lib.SpanExporter): + + def __init__(self, trace_dict): + self.trace_dict = trace_dict + + def export( + self, spans: typing.Sequence[ReadableSpan] + ) -> export_lib.SpanExportResult: + for span in spans: + if ( + span.name == "call_llm" + or span.name == "send_data" + or span.name.startswith("execute_tool") + ): + attributes = dict(span.attributes) + attributes["trace_id"] = span.get_span_context().trace_id + attributes["span_id"] = span.get_span_context().span_id + if attributes.get("gcp.vertex.agent.event_id", None): + self.trace_dict[attributes["gcp.vertex.agent.event_id"]] = attributes + return export_lib.SpanExportResult.SUCCESS + + def force_flush(self, timeout_millis: int = 30000) -> bool: + return True + + +class InMemoryExporter(export_lib.SpanExporter): + + def __init__(self, trace_dict): + super().__init__() + self._spans = [] + self.trace_dict = trace_dict + + @override + def export( + self, spans: typing.Sequence[ReadableSpan] + ) -> export_lib.SpanExportResult: + for span in spans: + trace_id = span.context.trace_id + if span.name == "call_llm": + attributes = dict(span.attributes) + session_id = attributes.get("gcp.vertex.agent.session_id", None) + if session_id: + if session_id not in self.trace_dict: + self.trace_dict[session_id] = [trace_id] + else: + self.trace_dict[session_id] += [trace_id] + self._spans.extend(spans) + return export_lib.SpanExportResult.SUCCESS + + @override + def force_flush(self, timeout_millis: int = 30000) -> bool: + return True + + def get_finished_spans(self, session_id: str): + trace_ids = self.trace_dict.get(session_id, None) + if trace_ids is None or not trace_ids: + return [] + return [x for x in self._spans if x.context.trace_id in trace_ids] + + def clear(self): + self._spans.clear() + + +class RunAgentRequest(common.BaseModel): + app_name: str + user_id: str + session_id: str + new_message: Optional[types.Content] = None + streaming: bool = False + state_delta: Optional[dict[str, Any]] = None + # for long-running function resume requests (e.g., OAuth callback) + function_call_event_id: Optional[str] = None + # for resume long-running functions + invocation_id: Optional[str] = None + + +class CreateSessionRequest(common.BaseModel): + session_id: Optional[str] = Field( + default=None, + description=( + "The ID of the session to create. If not provided, a random session" + " ID will be generated." + ), + ) + state: Optional[dict[str, Any]] = Field( + default=None, description="The initial state of the session." + ) + events: Optional[list[Event]] = Field( + default=None, + description="A list of events to initialize the session with.", + ) + + +class SaveArtifactRequest(common.BaseModel): + """Request payload for saving a new artifact.""" + + filename: str = Field(description="Artifact filename.") + artifact: types.Part = Field( + description="Artifact payload encoded as google.genai.types.Part." + ) + custom_metadata: Optional[dict[str, Any]] = Field( + default=None, + description="Optional metadata to associate with the artifact version.", + ) + + +class AddSessionToEvalSetRequest(common.BaseModel): + eval_id: str + session_id: str + user_id: str + + +class RunEvalRequest(common.BaseModel): + eval_ids: list[str] = Field( + deprecated=True, + default_factory=list, + description="This field is deprecated, use eval_case_ids instead.", + ) + eval_case_ids: list[str] = Field( + default_factory=list, + description=( + "List of eval case ids to evaluate. if empty, then all eval cases in" + " the eval set are run." + ), + ) + eval_metrics: list[EvalMetric] + + +class UpdateMemoryRequest(common.BaseModel): + """Request to add a session to the memory service.""" + + session_id: str + """The ID of the session to add to memory.""" + + +class UpdateSessionRequest(common.BaseModel): + """Request to update session state without running the agent.""" + + state_delta: dict[str, Any] + """The state changes to apply to the session.""" + + +class RunEvalResult(common.BaseModel): + eval_set_file: str + eval_set_id: str + eval_id: str + final_eval_status: EvalStatus + eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]] = Field( + deprecated=True, + default=[], + description=( + "This field is deprecated, use overall_eval_metric_results instead." + ), + ) + overall_eval_metric_results: list[EvalMetricResult] + eval_metric_result_per_invocation: list[EvalMetricResultPerInvocation] + user_id: str + session_id: str + + +class RunEvalResponse(common.BaseModel): + run_eval_results: list[RunEvalResult] + + +class GetEventGraphResult(common.BaseModel): + dot_src: str + + +class CreateEvalSetRequest(common.BaseModel): + eval_set: EvalSet + + +class ListEvalSetsResponse(common.BaseModel): + eval_set_ids: list[str] + + +class EvalResult(EvalSetResult): + """This class has no field intentionally. + + The goal here is to just give a new name to the class to align with the API + endpoint. + """ + + +class ListEvalResultsResponse(common.BaseModel): + eval_result_ids: list[str] + + +class ListMetricsInfoResponse(common.BaseModel): + metrics_info: list[MetricInfo] + + +class AppInfo(common.BaseModel): + name: str + root_agent_name: str + description: str + language: Literal["yaml", "python"] + is_computer_use: bool = False + agents: Optional[dict[str, AgentInfo]] = None + + +class ListAppsResponse(common.BaseModel): + apps: list[AppInfo] + + +def _setup_telemetry( + otel_to_cloud: bool = False, + internal_exporters: Optional[list[SpanProcessor]] = None, +): + # TODO - remove the else branch here once maybe_set_otel_providers is no + # longer experimental. + if otel_to_cloud: + _setup_gcp_telemetry(internal_exporters=internal_exporters) + elif _otel_env_vars_enabled(): + _setup_telemetry_from_env(internal_exporters=internal_exporters) + else: + # Old logic - to be removed when above leaves experimental. + tracer_provider = TracerProvider() + if internal_exporters is not None: + for exporter in internal_exporters: + tracer_provider.add_span_processor(exporter) + trace.set_tracer_provider(tracer_provider=tracer_provider) + + +def _otel_env_vars_enabled() -> bool: + return any([ + os.getenv(endpoint_var) + for endpoint_var in [ + otel_env.OTEL_EXPORTER_OTLP_ENDPOINT, + otel_env.OTEL_EXPORTER_OTLP_TRACES_ENDPOINT, + otel_env.OTEL_EXPORTER_OTLP_METRICS_ENDPOINT, + otel_env.OTEL_EXPORTER_OTLP_LOGS_ENDPOINT, + ] + ]) + + +def _setup_gcp_telemetry( + internal_exporters: list[SpanProcessor] = None, +): + if typing.TYPE_CHECKING: + from ..telemetry.setup import OTelHooks + + otel_hooks_to_add: list[OTelHooks] = [] + + if internal_exporters: + from ..telemetry.setup import OTelHooks + + # Register ADK-specific exporters in trace provider. + otel_hooks_to_add.append(OTelHooks(span_processors=internal_exporters)) + + import google.auth + + from ..telemetry.google_cloud import get_gcp_exporters + from ..telemetry.google_cloud import get_gcp_resource + from ..telemetry.setup import maybe_set_otel_providers + + credentials, project_id = google.auth.default() + + otel_hooks_to_add.append( + get_gcp_exporters( + # TODO - use trace_to_cloud here as well once otel_to_cloud is no + # longer experimental. + enable_cloud_tracing=True, + # TODO - re-enable metrics once errors during shutdown are fixed. + enable_cloud_metrics=False, + enable_cloud_logging=True, + google_auth=(credentials, project_id), + ) + ) + otel_resource = get_gcp_resource(project_id) + + maybe_set_otel_providers( + otel_hooks_to_setup=otel_hooks_to_add, + otel_resource=otel_resource, + ) + _setup_instrumentation_lib_if_installed() + + +def _setup_telemetry_from_env( + internal_exporters: list[SpanProcessor] = None, +): + from ..telemetry.setup import maybe_set_otel_providers + + otel_hooks_to_add = [] + + if internal_exporters: + from ..telemetry.setup import OTelHooks + + # Register ADK-specific exporters in trace provider. + otel_hooks_to_add.append(OTelHooks(span_processors=internal_exporters)) + + maybe_set_otel_providers(otel_hooks_to_setup=otel_hooks_to_add) + _setup_instrumentation_lib_if_installed() + + +def _setup_instrumentation_lib_if_installed(): + # Set instrumentation to enable emitting OTel data from GenAISDK + # Currently the instrumentation lib is in extras dependencies, make sure to + # warn the user if it's not installed. + try: + from opentelemetry.instrumentation.google_genai import GoogleGenAiSdkInstrumentor + + GoogleGenAiSdkInstrumentor().instrument() + except (ImportError, AttributeError): + logger.warning( + "Unable to import GoogleGenAiSdkInstrumentor - some" + " telemetry will be disabled. Make sure to install google-adk[otel-gcp]" + ) + if os.getenv("GOOGLE_CLOUD_AGENT_ENGINE_ID"): + # Set up HTTPX and gRPC instrumentation for A2A multi-agent observability. + try: + from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor + + HTTPXClientInstrumentor().instrument() + except (ImportError, AttributeError): + logger.warning( + "telemetry enabled but proceeding without HTTPX instrumentation," + " because google-adk[otel-gcp] has not been installed" + ) + try: + from opentelemetry.instrumentation.grpc import GrpcInstrumentorClient + + GrpcInstrumentorClient().instrument() + except (ImportError, AttributeError): + logger.warning( + "telemetry enabled but proceeding without gRPC instrumentation," + " because google-adk[otel-gcp] has not been installed" + ) + + +class AdkWebServer: + """Helper class for setting up and running the ADK web server on FastAPI. + + You construct this class with all the Services required to run ADK agents and + can then call the get_fast_api_app method to get a FastAPI app instance that + can will use your provided service instances, static assets, and agent loader. + If you pass in a web_assets_dir, the static assets will be served under + /dev-ui in addition to the API endpoints created by default. + + You can add additional API endpoints by modifying the FastAPI app + instance returned by get_fast_api_app as this class exposes the agent runners + and most other bits of state retained during the lifetime of the server. + + Attributes: + agent_loader: An instance of BaseAgentLoader for loading agents. + session_service: An instance of BaseSessionService for managing sessions. + memory_service: An instance of BaseMemoryService for managing memory. + artifact_service: An instance of BaseArtifactService for managing + artifacts. + credential_service: An instance of BaseCredentialService for managing + credentials. + eval_sets_manager: An instance of EvalSetsManager for managing evaluation + sets. + eval_set_results_manager: An instance of EvalSetResultsManager for + managing evaluation set results. + agents_dir: Root directory containing subdirs for agents with those + containing resources (e.g. .env files, eval sets, etc.) for the agents. + extra_plugins: A list of fully qualified names of extra plugins to load. + logo_text: Text to display in the logo of the UI. + logo_image_url: URL of an image to display as logo of the UI. + runners_to_clean: Set of runner names marked for cleanup. + current_app_name_ref: A shared reference to the latest ran app name. + runner_dict: A dict of instantiated runners for each app. + """ + + def __init__( + self, + *, + agent_loader: BaseAgentLoader, + session_service: BaseSessionService, + memory_service: BaseMemoryService, + artifact_service: BaseArtifactService, + credential_service: BaseCredentialService, + eval_sets_manager: EvalSetsManager, + eval_set_results_manager: EvalSetResultsManager, + agents_dir: str, + extra_plugins: Optional[list[str]] = None, + logo_text: Optional[str] = None, + logo_image_url: Optional[str] = None, + url_prefix: Optional[str] = None, + auto_create_session: bool = False, + trigger_sources: Optional[list[str]] = None, + ): + self.agent_loader = agent_loader + self.session_service = session_service + self.memory_service = memory_service + self.artifact_service = artifact_service + self.credential_service = credential_service + self.eval_sets_manager = eval_sets_manager + self.eval_set_results_manager = eval_set_results_manager + self.agents_dir = agents_dir + self.extra_plugins = extra_plugins or [] + self.logo_text = logo_text + self.logo_image_url = logo_image_url + # Internal properties we want to allow being modified from callbacks. + self.runners_to_clean: set[str] = set() + self.current_app_name_ref: SharedValue[str] = SharedValue(value="") + self.runner_dict = {} + self.url_prefix = url_prefix + self.auto_create_session = auto_create_session + self.trigger_sources = trigger_sources + + async def get_runner_async(self, app_name: str) -> Runner: + """Returns the cached runner for the given app.""" + # Handle cleanup + if app_name in self.runners_to_clean: + self.runners_to_clean.remove(app_name) + runner = self.runner_dict.pop(app_name, None) + await cleanup.close_runners(list([runner])) + + # Return cached runner if exists + if app_name in self.runner_dict: + return self.runner_dict[app_name] + + # Create new runner + envs.load_dotenv_for_agent(os.path.basename(app_name), self.agents_dir) + agent_or_app = self.agent_loader.load_agent(app_name) + + # Instantiate extra plugins if configured + extra_plugins_instances = self._instantiate_extra_plugins() + + plugins_yaml_path = os.path.join(self.agents_dir, app_name, "plugins.yaml") + bq_analytics_config = None + if os.path.exists(plugins_yaml_path): + with open(plugins_yaml_path, "r", encoding="utf-8") as f: + plugins_config = yaml.safe_load(f) + if plugins_config and isinstance(plugins_config, dict): + bq_analytics_config = plugins_config.get("bigquery_agent_analytics") + + # All YAML agents are treated as visual builder agents. + is_visual_builder_agent = os.path.exists( + os.path.join(self.agents_dir, app_name, "root_agent.yaml") + ) + + if isinstance(agent_or_app, BaseAgent): + plugins = extra_plugins_instances + + # Handle BigQuery Analytics Plugin injection + if bq_analytics_config and all([ + bq_analytics_config.get("project_id"), + bq_analytics_config.get("dataset_id"), + bq_analytics_config.get("dataset_location"), + ]): + from ..plugins.bigquery_agent_analytics_plugin import BigQueryAgentAnalyticsPlugin + + plugins.append( + BigQueryAgentAnalyticsPlugin( + project_id=bq_analytics_config.get("project_id"), + dataset_id=bq_analytics_config.get("dataset_id"), + table_id=bq_analytics_config.get("table_id"), + location=bq_analytics_config.get("dataset_location"), + ) + ) + + agentic_app = App( + name=app_name, + root_agent=agent_or_app, + plugins=plugins, + ) + else: + # Combine existing plugins with extra plugins + agent_or_app.plugins = agent_or_app.plugins + extra_plugins_instances + agentic_app = agent_or_app + + # If the root agent was loaded from YAML, we treat it as being from Visual Builder + if is_visual_builder_agent: + object.__setattr__(agentic_app, "_is_visual_builder_app", True) + + runner = self._create_runner(agentic_app) + self.runner_dict[app_name] = runner + return runner + + def _get_root_agent(self, agent_or_app: BaseAgent | App) -> BaseAgent: + """Extract root agent from either a BaseAgent or App object.""" + if isinstance(agent_or_app, App): + return agent_or_app.root_agent + return agent_or_app + + def _create_runner(self, agentic_app: App) -> Runner: + """Create a runner with common services.""" + return Runner( + app=agentic_app, + artifact_service=self.artifact_service, + session_service=self.session_service, + memory_service=self.memory_service, + credential_service=self.credential_service, + auto_create_session=self.auto_create_session, + ) + + def _instantiate_extra_plugins(self) -> list[BasePlugin]: + """Instantiate extra plugins from the configured list. + + Returns: + List of instantiated BasePlugin objects. + """ + extra_plugins_instances = [] + for qualified_name in self.extra_plugins: + try: + plugin_obj = self._import_plugin_object(qualified_name) + if isinstance(plugin_obj, BasePlugin): + extra_plugins_instances.append(plugin_obj) + elif issubclass(plugin_obj, BasePlugin): + extra_plugins_instances.append(plugin_obj(name=qualified_name)) + except Exception as e: + logger.error("Failed to load plugin %s: %s", qualified_name, e) + return extra_plugins_instances + + def _import_plugin_object(self, qualified_name: str) -> Any: + """Import a plugin object (class or instance) from a fully qualified name. + + Args: + qualified_name: Fully qualified name (e.g., + 'my_package.my_plugin.MyPlugin') + + Returns: + The imported object, which can be either a class or an instance. + + Raises: + ImportError: If the module cannot be imported. + AttributeError: If the object doesn't exist in the module. + """ + module_name, obj_name = qualified_name.rsplit(".", 1) + module = importlib.import_module(module_name) + return getattr(module, obj_name) + + def _setup_runtime_config(self, web_assets_dir: str): + """Sets up the runtime config for the web server.""" + # Read existing runtime config file. + runtime_config_path = os.path.join( + web_assets_dir, "assets", "config", "runtime-config.json" + ) + runtime_config = {} + try: + with open(runtime_config_path, "r") as f: + runtime_config = json.load(f) + except FileNotFoundError: + logger.info( + "File not found: %s. A new runtime config file will be created.", + runtime_config_path, + ) + except json.JSONDecodeError: + logger.warning( + "Failed to decode JSON from %s. The file content will be" + " overwritten.", + runtime_config_path, + ) + runtime_config["backendUrl"] = self.url_prefix if self.url_prefix else "" + + # Set custom logo config. + if self.logo_text or self.logo_image_url: + if not self.logo_text or not self.logo_image_url: + raise ValueError( + "Both --logo-text and --logo-image-url must be defined when using" + " logo config." + ) + runtime_config["logo"] = { + "text": self.logo_text, + "imageUrl": self.logo_image_url, + } + elif "logo" in runtime_config: + del runtime_config["logo"] + + # Write the runtime config file. + try: + os.makedirs(os.path.dirname(runtime_config_path), exist_ok=True) + with open(runtime_config_path, "w") as f: + json.dump(runtime_config, f, indent=2) + except IOError as e: + logger.error( + "Failed to write runtime config file %s: %s", runtime_config_path, e + ) + + async def _create_session( + self, + *, + app_name: str, + user_id: str, + session_id: Optional[str] = None, + state: Optional[dict[str, Any]] = None, + ) -> Session: + try: + session = await self.session_service.create_session( + app_name=app_name, + user_id=user_id, + state=state, + session_id=session_id, + ) + logger.info("New session created: %s", session.id) + return session + except AlreadyExistsError as e: + raise HTTPException( + status_code=409, detail=f"Session already exists: {session_id}" + ) from e + except Exception as e: + logger.error( + "Internal server error during session creation: %s", e, exc_info=True + ) + raise HTTPException(status_code=500, detail=str(e)) from e + + def get_fast_api_app( + self, + lifespan: Optional[Lifespan[FastAPI]] = None, + allow_origins: Optional[list[str]] = None, + web_assets_dir: Optional[str] = None, + setup_observer: Callable[ + [Observer, "AdkWebServer"], None + ] = lambda o, s: None, + tear_down_observer: Callable[ + [Observer, "AdkWebServer"], None + ] = lambda o, s: None, + register_processors: Callable[[TracerProvider], None] = lambda o: None, + otel_to_cloud: bool = False, + with_ui: bool = False, + ): + """Creates a FastAPI app for the ADK web server. + + By default it'll just return a FastAPI instance with the API server + endpoints, + but if you specify a web_assets_dir, it'll also serve the static web assets + from that directory. + + Args: + lifespan: The lifespan of the FastAPI app. + allow_origins: The origins that are allowed to make cross-origin requests. + Entries can be literal origins (e.g., 'https://example.com') or regex + patterns prefixed with 'regex:' (e.g., + 'regex:https://.*\\.example\\.com'). + web_assets_dir: The directory containing the web assets to serve. + setup_observer: Callback for setting up the file system observer. + tear_down_observer: Callback for cleaning up the file system observer. + register_processors: Callback for additional Span processors to be added + to the TracerProvider. + otel_to_cloud: Whether to enable Cloud Trace and Cloud Logging + integrations. + + Returns: + A FastAPI app instance. + """ + # Properties we don't need to modify from callbacks + trace_dict = {} + session_trace_dict = {} + # Set up a file system watcher to detect changes in the agents directory. + observer = Observer() + setup_observer(observer, self) + + @asynccontextmanager + async def internal_lifespan(app: FastAPI): + try: + if lifespan: + async with lifespan(app) as lifespan_context: + yield lifespan_context + else: + yield + finally: + tear_down_observer(observer, self) + # Create tasks for all runner closures to run concurrently + await cleanup.close_runners(list(self.runner_dict.values())) + + memory_exporter = InMemoryExporter(session_trace_dict) + + _setup_telemetry( + otel_to_cloud=otel_to_cloud, + internal_exporters=[ + export_lib.SimpleSpanProcessor(ApiServerSpanExporter(trace_dict)), + export_lib.SimpleSpanProcessor(memory_exporter), + ], + ) + if web_assets_dir: + self._setup_runtime_config(web_assets_dir) + + # TODO - register_processors to be removed once --otel_to_cloud is no + # longer experimental. + tracer_provider = trace.get_tracer_provider() + register_processors(tracer_provider) + + # Run the FastAPI server. + app = FastAPI(lifespan=internal_lifespan) + + has_configured_allowed_origins = bool(allow_origins) + if allow_origins: + literal_origins, combined_regex = _parse_cors_origins(allow_origins) + compiled_origin_regex = ( + re.compile(combined_regex) if combined_regex is not None else None + ) + app.add_middleware( + CORSMiddleware, + allow_origins=literal_origins, + allow_origin_regex=combined_regex, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + else: + literal_origins = [] + compiled_origin_regex = None + + app.add_middleware( + _OriginCheckMiddleware, + has_configured_allowed_origins=has_configured_allowed_origins, + allowed_origins=literal_origins, + allowed_origin_regex=compiled_origin_regex, + ) + + @app.get("/health") + async def health() -> dict[str, str]: + return {"status": "ok"} + + @app.get("/version") + async def version() -> dict[str, str]: + return { + "version": __version__, + "language": "python", + "language_version": ( + f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" + ), + } + + @app.get("/list-apps") + async def list_apps( + detailed: bool = Query( + default=False, description="Return detailed app information" + ) + ) -> list[str] | ListAppsResponse: + if detailed: + apps_info = self.agent_loader.list_agents_detailed() + return ListAppsResponse(apps=[AppInfo(**app) for app in apps_info]) + return self.agent_loader.list_agents() + + @experimental + @app.get("/apps/{app_name}/app-info", response_model_exclude_none=True) + async def get_adk_app_info(app_name: str) -> AppInfo: + """Returns the detailed info for a given ADK app.""" + agent_or_app = self.agent_loader.load_agent(app_name) + root_agent = self._get_root_agent(agent_or_app) + if isinstance(root_agent, LlmAgent): + return AppInfo( + name=app_name, + root_agent_name=root_agent.name, + description=root_agent.description, + language="python", + agents=get_agents_dict(root_agent), + ) + else: + raise HTTPException( + status_code=400, detail="Root agent is not an LlmAgent" + ) + + @app.get("/debug/trace/{event_id}", tags=[TAG_DEBUG]) + async def get_trace_dict(event_id: str) -> Any: + event_dict = trace_dict.get(event_id, None) + if event_dict is None: + raise HTTPException(status_code=404, detail="Trace not found") + return event_dict + + if web_assets_dir: + + @app.get("/dev/build_graph/{app_name}") + async def get_app_info(app_name: str) -> Any: + runner = await self.get_runner_async(app_name) + + if not runner.app: + raise HTTPException( + status_code=404, detail=f"App not found: {app_name}" + ) + + def serialize_agent(agent: BaseAgent) -> dict[str, Any]: + """Recursively serialize an agent, excluding non-serializable fields.""" + agent_dict = {} + + for field_name, field_info in agent.__class__.model_fields.items(): + # Skip non-serializable fields + if field_name in [ + "parent_agent", + "before_agent_callback", + "after_agent_callback", + "before_model_callback", + "after_model_callback", + "on_model_error_callback", + "before_tool_callback", + "after_tool_callback", + "on_tool_error_callback", + ]: + continue + + value = getattr(agent, field_name, None) + + # Handle sub_agents recursively + if field_name == "sub_agents" and value: + agent_dict[field_name] = [ + serialize_agent(sub_agent) for sub_agent in value + ] + elif value is None or field_name == "tools": + continue + else: + try: + if isinstance(value, (str, int, float, bool, list, dict)): + agent_dict[field_name] = value + elif hasattr(value, "model_dump"): + agent_dict[field_name] = value.model_dump( + mode="python", exclude_none=True + ) + else: + agent_dict[field_name] = str(value) + except Exception: + pass + + return agent_dict + + app_info = { + "name": runner.app.name, + "root_agent": serialize_agent(runner.app.root_agent), + } + + # Add optional fields if present + if runner.app.plugins: + app_info["plugins"] = [ + {"name": getattr(plugin, "name", type(plugin).__name__)} + for plugin in runner.app.plugins + ] + + if runner.app.context_cache_config: + try: + app_info["context_cache_config"] = ( + runner.app.context_cache_config.model_dump( + mode="python", exclude_none=True + ) + ) + except Exception: + pass + + if runner.app.resumability_config: + try: + app_info["resumability_config"] = ( + runner.app.resumability_config.model_dump( + mode="python", exclude_none=True + ) + ) + except Exception: + pass + + return app_info + + @app.get("/debug/trace/session/{session_id}", tags=[TAG_DEBUG]) + async def get_session_trace(session_id: str) -> Any: + spans = memory_exporter.get_finished_spans(session_id) + if not spans: + return [] + return [ + { + "name": s.name, + "span_id": s.context.span_id, + "trace_id": s.context.trace_id, + "start_time": s.start_time, + "end_time": s.end_time, + "attributes": dict(s.attributes), + "parent_span_id": s.parent.span_id if s.parent else None, + } + for s in spans + ] + + @app.get( + "/apps/{app_name}/users/{user_id}/sessions/{session_id}", + response_model_exclude_none=True, + ) + async def get_session( + app_name: str, user_id: str, session_id: str + ) -> Session: + session = await self.session_service.get_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + if not session: + raise HTTPException(status_code=404, detail="Session not found") + self.current_app_name_ref.value = app_name + return session + + @app.get( + "/apps/{app_name}/users/{user_id}/sessions", + response_model_exclude_none=True, + ) + async def list_sessions(app_name: str, user_id: str) -> list[Session]: + list_sessions_response = await self.session_service.list_sessions( + app_name=app_name, user_id=user_id + ) + return [ + session + for session in list_sessions_response.sessions + # Remove sessions that were generated as a part of Eval. + if not session.id.startswith(EVAL_SESSION_ID_PREFIX) + ] + + @deprecated( + "Please use create_session instead. This will be removed in future" + " releases." + ) + @app.post( + "/apps/{app_name}/users/{user_id}/sessions/{session_id}", + response_model_exclude_none=True, + ) + async def create_session_with_id( + app_name: str, + user_id: str, + session_id: str, + state: Optional[dict[str, Any]] = None, + ) -> Session: + return await self._create_session( + app_name=app_name, + user_id=user_id, + state=state, + session_id=session_id, + ) + + @app.post( + "/apps/{app_name}/users/{user_id}/sessions", + response_model_exclude_none=True, + ) + async def create_session( + app_name: str, + user_id: str, + req: Optional[CreateSessionRequest] = None, + ) -> Session: + if not req: + return await self._create_session(app_name=app_name, user_id=user_id) + + session = await self._create_session( + app_name=app_name, + user_id=user_id, + state=req.state, + session_id=req.session_id, + ) + + if req.events: + for event in req.events: + await self.session_service.append_event(session=session, event=event) + + return session + + @app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}") + async def delete_session( + app_name: str, user_id: str, session_id: str + ) -> None: + await self.session_service.delete_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + + @app.patch( + "/apps/{app_name}/users/{user_id}/sessions/{session_id}", + response_model_exclude_none=True, + ) + async def update_session( + app_name: str, + user_id: str, + session_id: str, + req: UpdateSessionRequest, + ) -> Session: + """Updates session state without running the agent. + + Args: + app_name: The name of the application. + user_id: The ID of the user. + session_id: The ID of the session to update. + req: The patch request containing state changes. + + Returns: + The updated session. + + Raises: + HTTPException: If the session is not found. + """ + session = await self.session_service.get_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + if not session: + raise HTTPException(status_code=404, detail="Session not found") + + # Create an event to record the state change + import uuid + + from ..events.event import Event + from ..events.event import EventActions + + state_update_event = Event( + invocation_id="p-" + str(uuid.uuid4()), + author="user", + actions=EventActions(state_delta=req.state_delta), + ) + + # Append the event to the session + # This will automatically update the session state through __update_session_state + await self.session_service.append_event( + session=session, event=state_update_event + ) + + return session + + @app.post( + "/apps/{app_name}/eval-sets", + response_model_exclude_none=True, + tags=[TAG_EVALUATION], + ) + async def create_eval_set( + app_name: str, create_eval_set_request: CreateEvalSetRequest + ) -> EvalSet: + try: + return self.eval_sets_manager.create_eval_set( + app_name=app_name, + eval_set_id=create_eval_set_request.eval_set.eval_set_id, + ) + except ValueError as ve: + raise HTTPException( + status_code=400, + detail=str(ve), + ) from ve + + # TODO - remove after migration + @deprecated( + "Please use create_eval_set instead. This will be removed in future" + " releases." + ) + @app.post( + "/apps/{app_name}/eval_sets/{eval_set_id}", + response_model_exclude_none=True, + tags=[TAG_EVALUATION], + ) + async def create_eval_set_legacy( + app_name: str, + eval_set_id: str, + ): + """Creates an eval set, given the id.""" + await create_eval_set( + app_name=app_name, + create_eval_set_request=CreateEvalSetRequest( + eval_set=EvalSet(eval_set_id=eval_set_id, eval_cases=[]) + ), + ) + + # TODO - remove after migration + @deprecated( + "Please use list_eval_sets instead. This will be removed in future" + " releases." + ) + @app.get( + "/apps/{app_name}/eval_sets", + response_model_exclude_none=True, + tags=[TAG_EVALUATION], + ) + async def list_eval_sets_legacy(app_name: str) -> list[str]: + list_eval_sets_response = await list_eval_sets(app_name) + return list_eval_sets_response.eval_set_ids + + # TODO - remove after migration + @deprecated( + "Please use run_eval instead. This will be removed in future releases." + ) + @app.post( + "/apps/{app_name}/eval_sets/{eval_set_id}/run_eval", + response_model_exclude_none=True, + tags=[TAG_EVALUATION], + ) + async def run_eval_legacy( + app_name: str, eval_set_id: str, req: RunEvalRequest + ) -> list[RunEvalResult]: + run_eval_response = await run_eval( + app_name=app_name, eval_set_id=eval_set_id, req=req + ) + return run_eval_response.run_eval_results + + # TODO - remove after migration + @deprecated( + "Please use get_eval_result instead. This will be removed in future" + " releases." + ) + @app.get( + "/apps/{app_name}/eval_results/{eval_result_id}", + response_model_exclude_none=True, + tags=[TAG_EVALUATION], + ) + async def get_eval_result_legacy( + app_name: str, + eval_result_id: str, + ) -> EvalSetResult: + try: + return self.eval_set_results_manager.get_eval_set_result( + app_name, eval_result_id + ) + except ValueError as ve: + raise HTTPException(status_code=404, detail=str(ve)) from ve + except ValidationError as ve: + raise HTTPException(status_code=500, detail=str(ve)) from ve + + # TODO - remove after migration + @deprecated( + "Please use list_eval_results instead. This will be removed in future" + " releases." + ) + @app.get( + "/apps/{app_name}/eval_results", + response_model_exclude_none=True, + tags=[TAG_EVALUATION], + ) + async def list_eval_results_legacy(app_name: str) -> list[str]: + list_eval_results_response = await list_eval_results(app_name) + return list_eval_results_response.eval_result_ids + + @app.get( + "/apps/{app_name}/eval-sets", + response_model_exclude_none=True, + tags=[TAG_EVALUATION], + ) + async def list_eval_sets(app_name: str) -> ListEvalSetsResponse: + """Lists all eval sets for the given app.""" + eval_sets = [] + try: + eval_sets = self.eval_sets_manager.list_eval_sets(app_name) + except NotFoundError as e: + logger.warning(e) + + return ListEvalSetsResponse(eval_set_ids=eval_sets) + + @app.post( + "/apps/{app_name}/eval-sets/{eval_set_id}/add-session", + response_model_exclude_none=True, + tags=[TAG_EVALUATION], + ) + @app.post( + "/apps/{app_name}/eval_sets/{eval_set_id}/add_session", + response_model_exclude_none=True, + tags=[TAG_EVALUATION], + ) + async def add_session_to_eval_set( + app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest + ): + # Get the session + session = await self.session_service.get_session( + app_name=app_name, user_id=req.user_id, session_id=req.session_id + ) + assert session, "Session not found." + + # Convert the session data to eval invocations + invocations = evals.convert_session_to_eval_invocations(session) + + # Populate the session with initial session state. + agent_or_app = self.agent_loader.load_agent(app_name) + root_agent = self._get_root_agent(agent_or_app) + initial_session_state = create_empty_state(root_agent) + + new_eval_case = EvalCase( + eval_id=req.eval_id, + conversation=invocations, + session_input=SessionInput( + app_name=app_name, + user_id=req.user_id, + state=initial_session_state, + ), + creation_timestamp=time.time(), + ) + + try: + self.eval_sets_manager.add_eval_case( + app_name, eval_set_id, new_eval_case + ) + except ValueError as ve: + raise HTTPException(status_code=400, detail=str(ve)) from ve + + @app.get( + "/apps/{app_name}/eval_sets/{eval_set_id}/evals", + response_model_exclude_none=True, + tags=[TAG_EVALUATION], + ) + async def list_evals_in_eval_set( + app_name: str, + eval_set_id: str, + ) -> list[str]: + """Lists all evals in an eval set.""" + eval_set_data = self.eval_sets_manager.get_eval_set(app_name, eval_set_id) + + if not eval_set_data: + raise HTTPException( + status_code=400, detail=f"Eval set `{eval_set_id}` not found." + ) + + return sorted([x.eval_id for x in eval_set_data.eval_cases]) + + @app.get( + "/apps/{app_name}/eval-sets/{eval_set_id}/eval-cases/{eval_case_id}", + response_model_exclude_none=True, + tags=[TAG_EVALUATION], + ) + @app.get( + "/apps/{app_name}/eval_sets/{eval_set_id}/evals/{eval_case_id}", + response_model_exclude_none=True, + tags=[TAG_EVALUATION], + ) + async def get_eval( + app_name: str, eval_set_id: str, eval_case_id: str + ) -> EvalCase: + """Gets an eval case in an eval set.""" + eval_case_to_find = self.eval_sets_manager.get_eval_case( + app_name, eval_set_id, eval_case_id + ) + + if eval_case_to_find: + return eval_case_to_find + + raise HTTPException( + status_code=404, + detail=( + f"Eval set `{eval_set_id}` or Eval `{eval_case_id}` not found." + ), + ) + + @app.put( + "/apps/{app_name}/eval-sets/{eval_set_id}/eval-cases/{eval_case_id}", + response_model_exclude_none=True, + tags=[TAG_EVALUATION], + ) + @app.put( + "/apps/{app_name}/eval_sets/{eval_set_id}/evals/{eval_case_id}", + response_model_exclude_none=True, + tags=[TAG_EVALUATION], + ) + async def update_eval( + app_name: str, + eval_set_id: str, + eval_case_id: str, + updated_eval_case: EvalCase, + ): + if ( + updated_eval_case.eval_id + and updated_eval_case.eval_id != eval_case_id + ): + raise HTTPException( + status_code=400, + detail=( + "Eval id in EvalCase should match the eval id in the API route." + ), + ) + + # Overwrite the value. We are either overwriting the same value or an empty + # field. + updated_eval_case.eval_id = eval_case_id + try: + self.eval_sets_manager.update_eval_case( + app_name, eval_set_id, updated_eval_case + ) + except NotFoundError as nfe: + raise HTTPException(status_code=404, detail=str(nfe)) from nfe + + @app.delete( + "/apps/{app_name}/eval-sets/{eval_set_id}/eval-cases/{eval_case_id}", + tags=[TAG_EVALUATION], + ) + @app.delete( + "/apps/{app_name}/eval_sets/{eval_set_id}/evals/{eval_case_id}", + tags=[TAG_EVALUATION], + ) + async def delete_eval( + app_name: str, eval_set_id: str, eval_case_id: str + ) -> None: + try: + self.eval_sets_manager.delete_eval_case( + app_name, eval_set_id, eval_case_id + ) + except NotFoundError as nfe: + raise HTTPException(status_code=404, detail=str(nfe)) from nfe + + @app.post( + "/apps/{app_name}/eval-sets/{eval_set_id}/run", + response_model_exclude_none=True, + tags=[TAG_EVALUATION], + ) + async def run_eval( + app_name: str, eval_set_id: str, req: RunEvalRequest + ) -> RunEvalResponse: + """Runs an eval given the details in the eval request.""" + # Create a mapping from eval set file to all the evals that needed to be + # run. + try: + from ..evaluation.local_eval_service import LocalEvalService + from .cli_eval import _collect_eval_results + from .cli_eval import _collect_inferences + + eval_set = self.eval_sets_manager.get_eval_set(app_name, eval_set_id) + + if not eval_set: + raise HTTPException( + status_code=400, detail=f"Eval set `{eval_set_id}` not found." + ) + + agent_or_app = self.agent_loader.load_agent(app_name) + root_agent = self._get_root_agent(agent_or_app) + + eval_case_results = [] + + eval_service = LocalEvalService( + root_agent=root_agent, + eval_sets_manager=self.eval_sets_manager, + eval_set_results_manager=self.eval_set_results_manager, + session_service=self.session_service, + artifact_service=self.artifact_service, + ) + inference_request = InferenceRequest( + app_name=app_name, + eval_set_id=eval_set.eval_set_id, + eval_case_ids=req.eval_case_ids or req.eval_ids, + inference_config=InferenceConfig(), + ) + inference_results = await _collect_inferences( + inference_requests=[inference_request], eval_service=eval_service + ) + + eval_case_results = await _collect_eval_results( + inference_results=inference_results, + eval_service=eval_service, + eval_metrics=req.eval_metrics, + ) + except ModuleNotFoundError as e: + logger.exception("%s", e) + raise HTTPException( + status_code=400, detail=MISSING_EVAL_DEPENDENCIES_MESSAGE + ) from e + + run_eval_results = [] + for eval_case_result in eval_case_results: + run_eval_results.append( + RunEvalResult( + eval_set_file=eval_case_result.eval_set_file, + eval_set_id=eval_set_id, + eval_id=eval_case_result.eval_id, + final_eval_status=eval_case_result.final_eval_status, + overall_eval_metric_results=eval_case_result.overall_eval_metric_results, + eval_metric_result_per_invocation=eval_case_result.eval_metric_result_per_invocation, + user_id=eval_case_result.user_id, + session_id=eval_case_result.session_id, + ) + ) + + return RunEvalResponse(run_eval_results=run_eval_results) + + @app.get( + "/apps/{app_name}/eval-results/{eval_result_id}", + response_model_exclude_none=True, + tags=[TAG_EVALUATION], + ) + async def get_eval_result( + app_name: str, + eval_result_id: str, + ) -> EvalResult: + """Gets the eval result for the given eval id.""" + try: + eval_set_result = self.eval_set_results_manager.get_eval_set_result( + app_name, eval_result_id + ) + return EvalResult(**eval_set_result.model_dump()) + except ValueError as ve: + raise HTTPException(status_code=404, detail=str(ve)) from ve + except ValidationError as ve: + raise HTTPException(status_code=500, detail=str(ve)) from ve + + @app.get( + "/apps/{app_name}/eval-results", + response_model_exclude_none=True, + tags=[TAG_EVALUATION], + ) + async def list_eval_results(app_name: str) -> ListEvalResultsResponse: + """Lists all eval results for the given app.""" + eval_result_ids = self.eval_set_results_manager.list_eval_set_results( + app_name + ) + return ListEvalResultsResponse(eval_result_ids=eval_result_ids) + + @app.get( + "/apps/{app_name}/metrics-info", + response_model_exclude_none=True, + tags=[TAG_EVALUATION], + ) + async def list_metrics_info(app_name: str) -> ListMetricsInfoResponse: + """Lists all eval metrics for the given app.""" + try: + from ..evaluation.metric_evaluator_registry import DEFAULT_METRIC_EVALUATOR_REGISTRY + + # Right now we ignore the app_name as eval metrics are not tied to the + # app_name, but they could be moving forward. + metrics_info = ( + DEFAULT_METRIC_EVALUATOR_REGISTRY.get_registered_metrics() + ) + return ListMetricsInfoResponse(metrics_info=metrics_info) + except ModuleNotFoundError as e: + logger.exception("%s\n%s", MISSING_EVAL_DEPENDENCIES_MESSAGE, e) + raise HTTPException( + status_code=400, detail=MISSING_EVAL_DEPENDENCIES_MESSAGE + ) from e + + @app.get( + "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}", + response_model_exclude_none=True, + ) + async def load_artifact( + app_name: str, + user_id: str, + session_id: str, + artifact_name: str, + version: Optional[int] = Query(None), + ) -> Optional[types.Part]: + artifact = await self.artifact_service.load_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=artifact_name, + version=version, + ) + if not artifact: + raise HTTPException(status_code=404, detail="Artifact not found") + return artifact + + @app.get( + "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions/metadata", + response_model=list[ArtifactVersion], + response_model_exclude_none=True, + ) + async def list_artifact_versions_metadata( + app_name: str, + user_id: str, + session_id: str, + artifact_name: str, + ) -> list[ArtifactVersion]: + return await self.artifact_service.list_artifact_versions( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=artifact_name, + ) + + @app.get( + "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions/{version_id}", + response_model_exclude_none=True, + ) + async def load_artifact_version( + app_name: str, + user_id: str, + session_id: str, + artifact_name: str, + version_id: int, + ) -> Optional[types.Part]: + artifact = await self.artifact_service.load_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=artifact_name, + version=version_id, + ) + if not artifact: + raise HTTPException(status_code=404, detail="Artifact not found") + return artifact + + @app.post( + "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts", + response_model=ArtifactVersion, + response_model_exclude_none=True, + ) + async def save_artifact( + app_name: str, + user_id: str, + session_id: str, + req: SaveArtifactRequest, + ) -> ArtifactVersion: + try: + version = await self.artifact_service.save_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=req.filename, + artifact=req.artifact, + custom_metadata=req.custom_metadata, + ) + except InputValidationError as ive: + raise HTTPException(status_code=400, detail=str(ive)) from ive + except Exception as exc: # pylint: disable=broad-exception-caught + logger.error( + "Internal error while saving artifact %s for app=%s user=%s" + " session=%s: %s", + req.filename, + app_name, + user_id, + session_id, + exc, + exc_info=True, + ) + raise HTTPException(status_code=500, detail=str(exc)) from exc + artifact_version = await self.artifact_service.get_artifact_version( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=req.filename, + version=version, + ) + if artifact_version is None: + raise HTTPException( + status_code=500, detail="Artifact metadata unavailable" + ) + return artifact_version + + @app.get( + "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions/{version_id}/metadata", + response_model=ArtifactVersion, + response_model_exclude_none=True, + ) + async def get_artifact_version_metadata( + app_name: str, + user_id: str, + session_id: str, + artifact_name: str, + version_id: int, + ) -> ArtifactVersion: + artifact_version = await self.artifact_service.get_artifact_version( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=artifact_name, + version=version_id, + ) + if not artifact_version: + raise HTTPException( + status_code=404, detail="Artifact version not found" + ) + return artifact_version + + @app.get( + "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts", + response_model_exclude_none=True, + ) + async def list_artifact_names( + app_name: str, user_id: str, session_id: str + ) -> list[str]: + return await self.artifact_service.list_artifact_keys( + app_name=app_name, user_id=user_id, session_id=session_id + ) + + @app.get( + "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions", + response_model_exclude_none=True, + ) + async def list_artifact_versions( + app_name: str, user_id: str, session_id: str, artifact_name: str + ) -> list[int]: + return await self.artifact_service.list_versions( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=artifact_name, + ) + + @app.delete( + "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}", + ) + async def delete_artifact( + app_name: str, user_id: str, session_id: str, artifact_name: str + ) -> None: + await self.artifact_service.delete_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=artifact_name, + ) + + @app.patch("/apps/{app_name}/users/{user_id}/memory") + async def patch_memory( + app_name: str, user_id: str, update_memory_request: UpdateMemoryRequest + ) -> None: + """Adds all events from a given session to the memory service. + + Args: + app_name: The name of the application. + user_id: The ID of the user. + update_memory_request: The memory request for the update + + Raises: + HTTPException: If the memory service is not configured or the request + is invalid. + """ + if not self.memory_service: + raise HTTPException( + status_code=400, detail="Memory service is not configured." + ) + if ( + update_memory_request is None + or update_memory_request.session_id is None + ): + raise HTTPException( + status_code=400, detail="Update memory request is invalid." + ) + + session = await self.session_service.get_session( + app_name=app_name, + user_id=user_id, + session_id=update_memory_request.session_id, + ) + if not session: + raise HTTPException(status_code=404, detail="Session not found") + await self.memory_service.add_session_to_memory(session) + + def _set_telemetry_context_if_needed(runner: Runner): + """Helper to set contextvars for the current request task.""" + app = getattr(runner, "app", None) + from ..utils._telemetry_context import _is_visual_builder + + if app and getattr(app, "_is_visual_builder_app", False): + _is_visual_builder.set(True) + else: + _is_visual_builder.set(False) + + @app.post("/run", response_model_exclude_none=True) + async def run_agent(req: RunAgentRequest) -> list[Event]: + self.current_app_name_ref.value = req.app_name + runner = await self.get_runner_async(req.app_name) + _set_telemetry_context_if_needed(runner) + try: + async with Aclosing( + runner.run_async( + user_id=req.user_id, + session_id=req.session_id, + new_message=req.new_message, + state_delta=req.state_delta, + invocation_id=req.invocation_id, + ) + ) as agen: + events = [event async for event in agen] + except SessionNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + logger.info("Generated %s events in agent run", len(events)) + logger.debug("Events generated: %s", events) + return events + + @app.post("/run_sse") + async def run_agent_sse(req: RunAgentRequest) -> StreamingResponse: + self.current_app_name_ref.value = req.app_name + stream_mode = StreamingMode.SSE if req.streaming else StreamingMode.NONE + runner = await self.get_runner_async(req.app_name) + _set_telemetry_context_if_needed(runner) + + # Validate session existence before starting the stream. + # We check directly here instead of eagerly advancing the + # runner's async generator with anext(), because splitting + # generator consumption across two asyncio Tasks (request + # handler vs StreamingResponse) breaks OpenTelemetry context + # detachment. + if not runner.auto_create_session: + session = await self.session_service.get_session( + app_name=req.app_name, + user_id=req.user_id, + session_id=req.session_id, + ) + if not session: + raise HTTPException( + status_code=404, + detail=f"Session not found: {req.session_id}", + ) + + # Convert the events to properly formatted SSE + async def event_generator(): + async with Aclosing( + runner.run_async( + user_id=req.user_id, + session_id=req.session_id, + new_message=req.new_message, + state_delta=req.state_delta, + run_config=RunConfig(streaming_mode=stream_mode), + invocation_id=req.invocation_id, + ) + ) as agen: + try: + async for event in agen: + # ADK Web renders artifacts from `actions.artifactDelta` + # during part processing *and* during action processing + # 1) the original event with `artifactDelta` cleared (content) + # 2) a content-less "action-only" event carrying `artifactDelta` + events_to_stream = [event] + if ( + not req.function_call_event_id + and event.actions.artifact_delta + and event.content + and event.content.parts + ): + content_event = event.model_copy(deep=True) + content_event.actions.artifact_delta = {} + artifact_event = event.model_copy(deep=True) + artifact_event.content = None + events_to_stream = [content_event, artifact_event] + + for event_to_stream in events_to_stream: + sse_event = event_to_stream.model_dump_json( + exclude_none=True, + by_alias=True, + ) + logger.debug( + "Generated event in agent run streaming: %s", sse_event + ) + yield f"data: {sse_event}\n\n" + except Exception as e: + logger.exception("Error in event_generator: %s", e) + yield f"data: {json.dumps({'error': str(e)})}\n\n" + + # Returns a streaming response with the proper media type for SSE + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + ) + + @app.get( + "/dev/{app_name}/graph", + response_model_exclude_none=True, + tags=[TAG_DEBUG], + ) + async def get_app_graph_dot( + app_name: str, dark_mode: bool = False + ) -> GetEventGraphResult | dict: + """Returns the base agent graph in DOT format without any highlights. + + This endpoint allows the frontend to fetch the graph structure once + and compute highlights client-side for better performance. + + Args: + app_name: The name of the agent/app + dark_mode: Whether to use dark theme background color + """ + agent_or_app = self.agent_loader.load_agent(app_name) + root_agent = self._get_root_agent(agent_or_app) + + # Get graph with NO highlights (empty list) and specified theme + dot_graph = await agent_graph.get_agent_graph( + root_agent, [], dark_mode=dark_mode + ) + + if dot_graph and isinstance(dot_graph, graphviz.Digraph): + return GetEventGraphResult(dot_src=dot_graph.source) + else: + return {} + + # TODO: This endpoint can be removed once we update adk web to stop consuming it + @app.get( + "/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}/graph", + response_model_exclude_none=True, + tags=[TAG_DEBUG], + ) + async def get_event_graph( + app_name: str, user_id: str, session_id: str, event_id: str + ): + session = await self.session_service.get_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + session_events = session.events if session else [] + event = next((x for x in session_events if x.id == event_id), None) + if not event: + return {} + + function_calls = event.get_function_calls() + function_responses = event.get_function_responses() + agent_or_app = self.agent_loader.load_agent(app_name) + root_agent = self._get_root_agent(agent_or_app) + dot_graph = None + if function_calls: + function_call_highlights = [] + for function_call in function_calls: + from_name = event.author + to_name = function_call.name + function_call_highlights.append((from_name, to_name)) + dot_graph = await agent_graph.get_agent_graph( + root_agent, function_call_highlights + ) + elif function_responses: + function_responses_highlights = [] + for function_response in function_responses: + from_name = function_response.name + to_name = event.author + function_responses_highlights.append((from_name, to_name)) + dot_graph = await agent_graph.get_agent_graph( + root_agent, function_responses_highlights + ) + else: + from_name = event.author + to_name = "" + dot_graph = await agent_graph.get_agent_graph( + root_agent, [(from_name, to_name)] + ) + if dot_graph and isinstance(dot_graph, graphviz.Digraph): + return GetEventGraphResult(dot_src=dot_graph.source) + else: + return {} + + @app.websocket("/run_live") + async def run_agent_live( + websocket: WebSocket, + app_name: str, + user_id: str, + session_id: str, + modalities: List[Literal["TEXT", "AUDIO"]] = Query( + default=["AUDIO"] + ), # Only allows "TEXT" or "AUDIO" + proactive_audio: bool | None = Query(default=None), + enable_affective_dialog: bool | None = Query(default=None), + enable_session_resumption: bool | None = Query(default=None), + save_live_blob: bool = Query(default=False), + ) -> None: + ws_origin = websocket.headers.get("origin") + if ws_origin is not None and not _is_request_origin_allowed( + ws_origin, + websocket.scope, + literal_origins, + compiled_origin_regex, + has_configured_allowed_origins, + ): + await websocket.close(code=1008, reason="Origin not allowed") + return + + await websocket.accept() + self.current_app_name_ref.value = app_name + runner_for_context = await self.get_runner_async(app_name) + _set_telemetry_context_if_needed(runner_for_context) + + session = await self.session_service.get_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + if not session: + await websocket.close(code=1002, reason="Session not found") + return + + live_request_queue = LiveRequestQueue() + + async def forward_events(): + runner = await self.get_runner_async(app_name) + run_config = RunConfig( + response_modalities=modalities, + proactivity=( + types.ProactivityConfig(proactive_audio=proactive_audio) + if proactive_audio is not None + else None + ), + enable_affective_dialog=enable_affective_dialog, + session_resumption=( + types.SessionResumptionConfig( + transparent=enable_session_resumption + ) + if enable_session_resumption is not None + else None + ), + save_live_blob=save_live_blob, + ) + async with Aclosing( + runner.run_live( + session=session, + live_request_queue=live_request_queue, + run_config=run_config, + ) + ) as agen: + async for event in agen: + await websocket.send_text( + event.model_dump_json(exclude_none=True, by_alias=True) + ) + + async def process_messages(): + try: + while True: + data = await websocket.receive_text() + # Validate and send the received message to the live queue. + live_request_queue.send(LiveRequest.model_validate_json(data)) + except ValidationError as ve: + logger.error("Validation error in process_messages: %s", ve) + + # Run both tasks concurrently and cancel all if one fails. + tasks = [ + asyncio.create_task(forward_events()), + asyncio.create_task(process_messages()), + ] + done, pending = await asyncio.wait( + tasks, return_when=asyncio.FIRST_EXCEPTION + ) + try: + # This will re-raise any exception from the completed tasks. + for task in done: + task.result() + except WebSocketDisconnect: + # Disconnection could happen when receive or send text via websocket + logger.info("Client disconnected during live session.") + except Exception as e: + logger.exception("Error during live websocket communication: %s", e) + traceback.print_exc() + WEBSOCKET_INTERNAL_ERROR_CODE = 1011 + WEBSOCKET_MAX_BYTES_FOR_REASON = 123 + await websocket.close( + code=WEBSOCKET_INTERNAL_ERROR_CODE, + reason=str(e)[:WEBSOCKET_MAX_BYTES_FOR_REASON], + ) + finally: + for task in pending: + task.cancel() + + # Register /trigger/* endpoints when enabled. + if self.trigger_sources: + from .trigger_routes import TriggerRouter + + trigger_router = TriggerRouter(self, trigger_sources=self.trigger_sources) + trigger_router.register(app) + + if web_assets_dir: + import mimetypes + + mimetypes.add_type("application/javascript", ".js", True) + mimetypes.add_type("text/javascript", ".js", True) + + redirect_dev_ui_url = ( + self.url_prefix + "/dev-ui/" if self.url_prefix else "/dev-ui/" + ) + + @app.get("/dev-ui/config") + async def get_ui_config(): + return { + "logo_text": self.logo_text, + "logo_image_url": self.logo_image_url, + } + + @app.get("/") + async def redirect_root_to_dev_ui(): + return RedirectResponse(redirect_dev_ui_url) + + @app.get("/dev-ui") + async def redirect_dev_ui_add_slash(): + return RedirectResponse(redirect_dev_ui_url) + + app.mount( + "/dev-ui/", + StaticFiles(directory=web_assets_dir, html=True, follow_symlink=True), + name="static", + ) + + return app diff --git a/src/google/adk/cli/cli_deploy.py b/src/google/adk/cli/cli_deploy.py index 1e8ce5dfd1..8ee0e16b92 100644 --- a/src/google/adk/cli/cli_deploy.py +++ b/src/google/adk/cli/cli_deploy.py @@ -29,7 +29,6 @@ import click from packaging.version import parse -from ..version import __version__ from .utils import _onboarding _IS_WINDOWS = os.name == 'nt' @@ -63,14 +62,18 @@ def _ensure_agent_engine_dependency(requirements_txt_path: str) -> None: with open(requirements_txt_path, 'a', encoding='utf-8') as f: if requirements and not requirements.endswith('\n'): f.write('\n') - f.write('google-cloud-aiplatform[agent_engines]\n') - f.write(f'google-adk=={__version__}\n') + f.write(_AGENT_ENGINE_REQUIREMENT + '\n') _DOCKERFILE_TEMPLATE: Final[str] = """ FROM python:3.11-slim WORKDIR /app +RUN apt-get update && \ + apt-get upgrade -y && \ + apt-get install -y git && \ + apt -y autoremove + # Create a non-root user RUN adduser --disabled-password --gecos "" myuser @@ -87,7 +90,7 @@ def _ensure_agent_engine_dependency(requirements_txt_path: str) -> None: # Set up environment variables - End # Install ADK - Start -RUN pip install google-adk=={adk_version} +# RUN pip install google-adk=={adk_version} # Install ADK - End # Copy agent - Start @@ -103,34 +106,7 @@ def _ensure_agent_engine_dependency(requirements_txt_path: str) -> None: EXPOSE {port} -CMD adk {command} --port={port} {host_option} {service_option} {trace_to_cloud_option} {otel_to_cloud_option} {allow_origins_option} {a2a_option} {trigger_sources_option} "/app/agents" -""" - -_AGENT_ENGINE_APP_TEMPLATE: Final[str] = """ -import os -import vertexai -from vertexai.agent_engines import AdkApp -{extra_imports} - -if {is_config_agent}: - from google.adk.agents import config_agent_utils - config_path = os.path.join(os.path.dirname(__file__), "root_agent.yaml") - root_agent = config_agent_utils.from_config(config_path) -else: - from .agent import {adk_app_object} - -if {express_mode}: # Whether or not to use Express Mode - vertexai.init(api_key=os.environ.get("GOOGLE_API_KEY")) -else: - vertexai.init( - project=os.environ.get("GOOGLE_CLOUD_PROJECT"), - location=os.environ.get("GOOGLE_CLOUD_LOCATION"), - ) - -adk_app = AdkApp( - {app_instantiation}, - enable_tracing={trace_to_cloud_option}, -) +CMD adk {command} --port={port} {host_option} {service_option} {trace_to_cloud_option} {otel_to_cloud_option} {allow_origins_option} {a2a_option} {trigger_sources_option} {gemini_enterprise_option} "/app/agents" """ _AGENT_ENGINE_CLASS_METHODS = [ @@ -413,6 +389,13 @@ def _ensure_agent_engine_dependency(requirements_txt_path: str) -> None: ] +def _resolve_adk_version() -> str: + """Returns the default ADK version.""" + from google.adk.version import __version__ + + return __version__ + + def _resolve_project(project_in_option: Optional[str]) -> str: if project_in_option: return project_in_option @@ -723,7 +706,7 @@ def to_cloud_run( gcp_region=region, app_name=app_name, port=port, - command='api_server --with_ui' if with_ui else 'api_server', + command='web' if with_ui else 'api_server', install_agent_deps=install_agent_deps, service_option=_get_service_option_by_adk_version( adk_version, @@ -739,6 +722,7 @@ def to_cloud_run( host_option=host_option, a2a_option=a2a_option, trigger_sources_option=trigger_sources_option, + gemini_enterprise_option='', ) dockerfile_path = os.path.join(temp_folder, 'Dockerfile') os.makedirs(temp_folder, exist_ok=True) @@ -829,7 +813,7 @@ def to_agent_engine( *, agent_folder: str, temp_folder: Optional[str] = None, - adk_app: str, + adk_app: Optional[str] = None, staging_bucket: Optional[str] = None, trace_to_cloud: Optional[bool] = None, otel_to_cloud: Optional[bool] = None, @@ -845,6 +829,9 @@ def to_agent_engine( env_file: Optional[str] = None, agent_engine_config_file: Optional[str] = None, skip_agent_import_validation: bool = True, + trigger_sources: Optional[str] = None, + artifact_service_uri: Optional[str] = None, + adk_version: Optional[str] = None, ): """Deploys an agent to Vertex AI Agent Engine. @@ -852,44 +839,31 @@ def to_agent_engine( - __init__.py - agent.py - - .py (optional, for customization; will be autogenerated otherwise) - requirements.txt (optional, for additional dependencies) - .env (optional, for environment variables) - ... (other required source files) - The contents of `adk_app` should look something like: - - ``` - from agent import - from vertexai.agent_engines import AdkApp - - adk_app = AdkApp( - agent=, # or `app=` - ) - ``` - Args: agent_folder (str): The folder (absolute path) containing the agent source code. temp_folder (str): The temp folder for the generated Agent Engine source files. It will be replaced with the generated files if it already exists. - adk_app (str): The name of the file (without .py) containing the AdkApp - instance. + adk_app (str): Deprecated. This argument is no longer required or used. staging_bucket (str): Deprecated. This argument is no longer required or used. - trace_to_cloud (bool): Whether to enable Cloud Trace. + trace_to_cloud (bool): Deprecated. This argument is no longer required or + used. otel_to_cloud (bool): Whether to enable exporting OpenTelemetry signals to Google Cloud. api_key (str): Optional. The API key to use for Express Mode. If not provided, the API key from the GOOGLE_API_KEY environment variable will be used. It will only be used if GOOGLE_GENAI_USE_VERTEXAI is true. - adk_app_object (str): Optional. The Python object corresponding to the root - ADK agent or app. Defaults to `root_agent` if not specified. + adk_app_object (str): Deprecated. This argument is no longer required or + used. agent_engine_id (str): Optional. The ID of the Agent Engine instance to update. If not specified, a new Agent Engine instance will be created. - absolutize_imports (bool): Optional. Default is True. Whether to absolutize - imports. If True, all relative imports will be converted to absolute - import statements. + absolutize_imports (bool): Deprecated. This argument is no longer required + or used. project (str): Optional. Google Cloud project id for the deployed agent. If not specified, the project from the `GOOGLE_CLOUD_PROJECT` environment variable will be used. It will be ignored if `api_key` is specified. @@ -898,9 +872,8 @@ def to_agent_engine( variable will be used. It will be ignored if `api_key` is specified. display_name (str): Optional. The display name of the Agent Engine. description (str): Optional. The description of the Agent Engine. - requirements_file (str): Optional. The filepath to the `requirements.txt` - file to use. If not specified, the `requirements.txt` file in the - `agent_folder` will be used. + requirements_file (str): Deprecated. This argument is no longer required or + used. env_file (str): Optional. The filepath to the `.env` file for environment variables. If not specified, the `.env` file in the `agent_folder` will be used. The values of `GOOGLE_CLOUD_PROJECT` and `GOOGLE_CLOUD_LOCATION` @@ -908,21 +881,33 @@ def to_agent_engine( agent_engine_config_file (str): The filepath to the agent engine config file to use. If not specified, the `.agent_engine_config.json` file in the `agent_folder` will be used. - skip_agent_import_validation (bool): Optional. Default is True. If True, - skip the pre-deployment import validation of `agent.py`. This can be - useful when the local environment does not have the same dependencies as - the deployment environment. + skip_agent_import_validation (bool): Deprecated. This argument is no longer + required or used. + trigger_sources (str): Optional. Comma-separated list of trigger sources to + enable (e.g., 'pubsub,eventarc'). Registers /trigger/* endpoints for + batch and event-driven agent invocations. + artifact_service_uri (str): Optional. The URI of the artifact service. + adk_version (str): Optional. The ADK version to use in Agent Engine + deployment. If not specified, the version in the dev environment will be + used. """ app_name = os.path.basename(agent_folder) display_name = display_name or app_name parent_folder = os.path.dirname(agent_folder) - adk_app_object = adk_app_object or 'root_agent' - if adk_app_object not in ['root_agent', 'app']: - click.echo( - f'Invalid adk_app_object: {adk_app_object}. Please use "root_agent"' - ' or "app".' + if adk_app_object: + warnings.warn( + 'WARNING: `--adk_app_object` is deprecated and will be removed in the' + ' future. Please drop it from the list of arguments.', + DeprecationWarning, + stacklevel=2, + ) + if adk_app: + warnings.warn( + 'WARNING: `adk_app` is deprecated and will be removed in a future' + ' release. Please drop it from the list of arguments.', + DeprecationWarning, + stacklevel=2, ) - return if staging_bucket: warnings.warn( 'WARNING: `staging_bucket` is deprecated and will be removed in a' @@ -930,6 +915,9 @@ def to_agent_engine( DeprecationWarning, stacklevel=2, ) + if not adk_version: + adk_version = _resolve_adk_version() + click.echo(f'Using default ADK version: {adk_version}') original_cwd = os.getcwd() did_change_cwd = False @@ -942,7 +930,7 @@ def to_agent_engine( did_change_cwd = True tmp_app_name = app_name + '_tmp' + datetime.now().strftime('%Y%m%d_%H%M%S') temp_folder = temp_folder or tmp_app_name - agent_src_path = os.path.join(parent_folder, temp_folder) + agent_src_path = os.path.join(parent_folder, temp_folder, 'agents', app_name) click.echo(f'Staging all files in: {agent_src_path}') # remove agent_src_path if it exists if os.path.exists(agent_src_path): @@ -965,6 +953,7 @@ def to_agent_engine( ignore=ignore_patterns, dirs_exist_ok=True, ) + os.chdir(os.path.join(parent_folder, temp_folder)) click.echo('Copying agent source code complete.') project = _resolve_project(project) @@ -1001,32 +990,20 @@ def to_agent_engine( ) agent_config['description'] = description - requirements_txt_path = os.path.join(agent_src_path, 'requirements.txt') if requirements_file: - if os.path.exists(requirements_txt_path): - click.echo( - f'Overwriting {requirements_txt_path} with {requirements_file}' - ) - shutil.copyfile(requirements_file, requirements_txt_path) - elif 'requirements_file' in agent_config: - if os.path.exists(requirements_txt_path): - click.echo( - f'Overwriting {requirements_txt_path} with' - f' {agent_config["requirements_file"]}' - ) - shutil.copyfile(agent_config['requirements_file'], requirements_txt_path) - else: - # Attempt to read requirements from requirements.txt in the dir (if any). - if not os.path.exists(requirements_txt_path): - click.echo(f'Creating {requirements_txt_path}...') - with open(requirements_txt_path, 'w', encoding='utf-8') as f: - f.write('google-cloud-aiplatform[agent_engines]\n') - f.write(f'google-adk=={__version__}\n') - click.echo(f'Using google-adk=={__version__} in requirements') - click.echo(f'Created {requirements_txt_path}') - _ensure_agent_engine_dependency(requirements_txt_path) - agent_config['requirements_file'] = f'{temp_folder}/requirements.txt' - + warnings.warn( + 'WARNING: `--requirements_file` is deprecated and will be removed in' + ' the future. Please define `requirements.txt` in the agent folder.', + DeprecationWarning, + stacklevel=2, + ) + if trace_to_cloud: + warnings.warn( + 'WARNING: `--trace_to_cloud` is deprecated and will be removed in the' + ' future. Please use `--otel_to_cloud` instead.', + DeprecationWarning, + stacklevel=2, + ) env_vars = {} if not env_file: # Attempt to read the env variables from .env in the dir (if any). @@ -1082,6 +1059,18 @@ def to_agent_engine( fg='yellow', ) env_vars['GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY'] = 'true' + if 'ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS' not in env_vars: + env_vars['ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS'] = 'false' + else: + enable_telemetry = env_vars.get( + 'GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY', + ) + if enable_telemetry in ['true', '1']: + otel_to_cloud = True + click.echo( + '`--otel_to_cloud` is set to True by' + f' GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY in {env_file}' + ) if env_vars: if 'env_vars' in agent_config: click.echo( @@ -1109,80 +1098,90 @@ def to_agent_engine( ) click.echo('Vertex AI initialized.') - is_config_agent = False - config_root_agent_file = os.path.join(agent_src_path, 'root_agent.yaml') - if os.path.exists(config_root_agent_file): - click.echo(f'Config agent detected: {config_root_agent_file}') - is_config_agent = True - - # Validate that the agent module can be imported before deployment. - if not skip_agent_import_validation: - click.echo('Validating agent module...') - _validate_agent_import(agent_src_path, adk_app_object, is_config_agent) - - adk_app_file = os.path.join(temp_folder, f'{adk_app}.py') - if adk_app_object == 'root_agent': - adk_app_type = 'agent' - elif adk_app_object == 'app': - adk_app_type = 'app' - else: - click.echo( - f'Invalid adk_app_object: {adk_app_object}. Please use "root_agent"' - ' or "app".' + if skip_agent_import_validation: + warnings.warn( + 'WARNING: `--skip-agent-import-validation` is deprecated and will be' + ' removed in the future. Please drop it from the list of arguments.', + DeprecationWarning, + stacklevel=2, ) - return - extra_imports = '' - app_instantiation = f'{adk_app_type}={adk_app_object}' - if adk_app_type == 'agent': - extra_imports = 'from google.adk.apps import App' - app_instantiation = ( - f"app=App(name='{app_name}', root_agent={adk_app_object})" + + def create_dockerfile_for_agent_engine(resource_name: str): + requirements_txt_path = os.path.join(agent_src_path, 'requirements.txt') + install_agent_deps = ( + f'RUN pip install -r "/app/agents/{app_name}/requirements.txt"' + if os.path.exists(requirements_txt_path) + else '# No requirements.txt found.' + ) + trigger_sources_option = ( + f'--trigger_sources={trigger_sources}' if trigger_sources else '' ) + agent_engine_uri = f'agentengine://{resource_name}' + dockerfile_content = _DOCKERFILE_TEMPLATE.format( + gcp_project_id=project, + gcp_region=region, + app_name=app_name, + port=8080, + command='api_server', + install_agent_deps=install_agent_deps, + service_option=_get_service_option_by_adk_version( + adk_version, + agent_engine_uri, # session_service_uri + artifact_service_uri, + agent_engine_uri, # memory_service_uri + False, # use_local_storage + ), + trace_to_cloud_option='--trace_to_cloud' if trace_to_cloud else '', + otel_to_cloud_option='--otel_to_cloud' if otel_to_cloud else '', + allow_origins_option='', # Not supported for now. + adk_version=adk_version, + host_option='--host=0.0.0.0', + a2a_option='--a2a', + trigger_sources_option=trigger_sources_option, + gemini_enterprise_option=f'--gemini_enterprise_app_name={app_name}', + ) + with open('Dockerfile', 'w', encoding='utf-8') as f: + f.write(dockerfile_content) - template_content = _AGENT_ENGINE_APP_TEMPLATE.format( - app_name=app_name, - trace_to_cloud_option=trace_to_cloud, - is_config_agent=is_config_agent, - agent_folder=f'./{temp_folder}', - adk_app_object=adk_app_object, - app_instantiation=app_instantiation, - extra_imports=extra_imports, - express_mode=api_key is not None, - ) - with open(adk_app_file, 'w', encoding='utf-8') as f: - f.write(template_content) - click.echo(f'Created {adk_app_file}') - click.echo('Files and dependencies resolved') if absolutize_imports: - click.echo( - 'Agent Engine deployments have switched to source-based deployment, ' - 'so it is no longer necessary to absolutize imports.' + warnings.warn( + 'WARNING: `--absolutize_imports` is deprecated and will be removed' + ' in the future. Please drop it from the list of arguments.', + DeprecationWarning, + stacklevel=2, ) click.echo('Deploying to agent engine...') - agent_config['entrypoint_module'] = f'{temp_folder}.{adk_app}' - agent_config['entrypoint_object'] = 'adk_app' - agent_config['source_packages'] = [temp_folder] + agent_config['source_packages'] = [f'agents/{app_name}', 'Dockerfile'] + agent_config['image_spec'] = {} # Use the Dockerfile agent_config['class_methods'] = _AGENT_ENGINE_CLASS_METHODS agent_config['agent_framework'] = 'google-adk' - if not agent_engine_id: - agent_engine = client.agent_engines.create(config=agent_config) - click.secho( - f'✅ Created agent engine: {agent_engine.api_resource.name}', - fg='green', - ) - _print_agent_engine_url(agent_engine.api_resource.name) - else: - if project and region and not agent_engine_id.startswith('projects/'): - agent_engine_id = f'projects/{project}/locations/{region}/reasoningEngines/{agent_engine_id}' - client.agent_engines.update(name=agent_engine_id, config=agent_config) - click.secho(f'✅ Updated agent engine: {agent_engine_id}', fg='green') - _print_agent_engine_url(agent_engine_id) + resource_name = agent_engine_id + if not resource_name: + agent_engine = client.agent_engines.create() + resource_name = agent_engine.api_resource.name + click.secho(f'Created a new agent engine: {resource_name}', fg='green') + elif project and region and not resource_name.startswith('projects/'): + resource_name = f'projects/{project}/locations/{region}/reasoningEngines/{agent_engine_id}' + click.echo('Creating Dockerfile...') + create_dockerfile_for_agent_engine(resource_name) + click.echo(f'Dockerfile created at {os.getcwd()}/Dockerfile.') + try: + client.agent_engines.update(name=resource_name, config=agent_config) + click.secho(f'Deployed to agent engine: {resource_name}', fg='green') + except Exception as e: + click.secho(f'Failed to deploy to agent engine: {e}', fg='red') + # Only delete the agent engine if it was newly created in this function. + if agent_engine_id is None: + client.agent_engines.delete(name=resource_name) + click.secho(f'Cleaned up the agent engine: {resource_name}', fg='green') + raise e + _print_agent_engine_url(resource_name) finally: - click.echo(f'Cleaning up the temp folder: {temp_folder}') - shutil.rmtree(agent_src_path) - if did_change_cwd: - os.chdir(original_cwd) + temp_folder_path = os.path.join(parent_folder, temp_folder) + click.echo(f'Cleaning up the temp folder: {temp_folder_path}') + os.chdir(original_cwd) + shutil.rmtree(temp_folder_path) def to_gke( @@ -1289,7 +1288,7 @@ def to_gke( gcp_region=region, app_name=app_name, port=port, - command='api_server --with_ui' if with_ui else 'api_server', + command='web' if with_ui else 'api_server', install_agent_deps=install_agent_deps, service_option=_get_service_option_by_adk_version( adk_version, @@ -1307,6 +1306,7 @@ def to_gke( trigger_sources_option=( f'--trigger_sources={trigger_sources}' if trigger_sources else '' ), + gemini_enterprise_option='', ) dockerfile_path = os.path.join(temp_folder, 'Dockerfile') os.makedirs(temp_folder, exist_ok=True) diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 8102f16573..ea33ede245 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -23,23 +23,25 @@ import logging import os from pathlib import Path -import sys import tempfile import textwrap -from typing import Optional import click from click.core import ParameterSource from fastapi import FastAPI import uvicorn +from . import cli_create +from . import cli_deploy from .. import version from ..agents.run_config import StreamingMode from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE from ..features import FeatureName from ..features import override_feature_enabled from .cli import run_cli +from .fast_api import get_fast_api_app from .utils import envs +from .utils import evals from .utils import logs LOG_LEVELS = click.Choice( @@ -338,8 +340,8 @@ def cli_conformance_test( paths: tuple[str, ...], mode: str, generate_report: bool, - report_dir: Optional[str] = None, - streaming_mode: Optional[StreamingMode] = None, + report_dir: str | None = None, + streaming_mode: StreamingMode | None = None, ): """Run conformance tests to verify agent behavior consistency. @@ -472,11 +474,11 @@ def cli_conformance_test( @click.argument("app_name", type=str, required=True) def cli_create_cmd( app_name: str, - model: Optional[str], - api_key: Optional[str], - project: Optional[str], - region: Optional[str], - type: Optional[str], + model: str | None, + api_key: str | None, + project: str | None, + region: str | None, + type: str | None, ): """Creates a new app in the current folder with prepopulated agent template. @@ -486,8 +488,6 @@ def cli_create_cmd( adk create path/to/my_app """ - from . import cli_create - cli_create.run_cmd( app_name, model=model, @@ -646,207 +646,50 @@ def wrapper(*args, **kwargs): ), callback=validate_exclusive, ) -@click.option( - "--state", - type=str, - help="Optional. Initial state for the run as a JSON string.", -) -@click.option( - "--timeout", - type=str, - help="Optional. Timeout for a single turn or query (e.g., 30s, 5m).", -) -@click.option( - "--in_memory", - is_flag=True, - help="Optional. Do not persist session data (use in-memory storage).", -) -@click.option( - "--jsonl", - is_flag=True, - help="Optional. Output structured JSONL instead of human-readable text.", -) -@click.option( - "--default_llm_model", - type=str, - help=( - "Optional. Sets the default LLM model used when the agent does not set" - " a model explicitly." - ), - default=None, -) @click.argument( "agent", type=click.Path( exists=True, dir_okay=True, file_okay=False, resolve_path=True ), ) -@click.argument("query", type=str, required=False) def cli_run( agent: str, - query: Optional[str], save_session: bool, - session_id: Optional[str], - replay: Optional[str], - resume: Optional[str], - state: Optional[str] = None, - timeout: Optional[str] = None, - in_memory: bool = False, - jsonl: bool = False, - session_service_uri: Optional[str] = None, - artifact_service_uri: Optional[str] = None, - memory_service_uri: Optional[str] = None, + session_id: str | None, + replay: str | None, + resume: str | None, + session_service_uri: str | None = None, + artifact_service_uri: str | None = None, + memory_service_uri: str | None = None, use_local_storage: bool = True, - default_llm_model: Optional[str] = None, ): - """Runs an agent. If no query is provided, enters interactive mode. + """Runs an interactive CLI for a certain agent. AGENT: The path to the agent source code folder. - QUERY: Optional. The user message to send to the agent for a single-step run. Example: adk run path/to/my_agent - adk run path/to/my_agent "hello" """ logs.log_to_tmp_folder() agent_parent_folder = os.path.dirname(agent) agent_folder_name = os.path.basename(agent) - # If query is provided, we run in single-step mode (JSONL output) - if query is not None: - from .cli import run_once_cli - - exit_code = asyncio.run( - run_once_cli( - agent_parent_dir=agent_parent_folder, - agent_folder_name=agent_folder_name, - query=query, - state_str=state, - session_id=session_id, - replay=replay, - timeout=timeout, - in_memory=in_memory, - jsonl=jsonl, - session_service_uri=session_service_uri, - artifact_service_uri=artifact_service_uri, - memory_service_uri=memory_service_uri, - use_local_storage=use_local_storage, - default_llm_model=default_llm_model, - ) - ) - sys.exit(exit_code) - else: - # Legacy interactive mode - asyncio.run( - run_cli( - agent_parent_dir=agent_parent_folder, - agent_folder_name=agent_folder_name, - input_file=replay, - saved_session_file=resume, - save_session=save_session, - session_id=session_id, - state_str=state, - timeout=timeout, - in_memory=in_memory, - jsonl=jsonl, - session_service_uri=session_service_uri, - artifact_service_uri=artifact_service_uri, - memory_service_uri=memory_service_uri, - use_local_storage=use_local_storage, - default_llm_model=default_llm_model, - ) - ) - - -@main.command( - "test", - cls=HelpfulCommand, - context_settings={ - "allow_extra_args": True, - "allow_interspersed_args": True, - "ignore_unknown_options": True, - }, -) -@click.argument( - "folder", - type=click.Path( - exists=True, dir_okay=True, file_okay=False, resolve_path=True - ), - default=".", -) -@click.option( - "--rebuild", - is_flag=True, - help="Rebuild test files by running the real agent with user messages.", -) -@click.pass_context -def cli_test(ctx, folder: str, rebuild: bool): - """Runs pytest on agent test JSON files under the specified folder. - - FOLDER: The path to the folder containing agents and tests. - Defaults to the current directory if not specified. - - Example: - adk test path/to/agents - """ - import sys - - if rebuild: - from .agent_test_runner import rebuild_tests - - click.echo(f"Rebuilding tests in {folder}...") - rebuild_tests(folder) - sys.exit(0) - - # Parse arguments to separate pytest args (after --) from regular args - pytest_args = [] - if "--" in ctx.args: - separator_index = ctx.args.index("--") - pytest_args = ctx.args[separator_index + 1 :] - regular_args = ctx.args[:separator_index] - - if regular_args: - click.secho( - "Error: Unexpected arguments after folder and before '--':" - f" {' '.join(regular_args)}. \nOnly arguments after '--' are passed" - " to pytest.", - fg="red", - err=True, + asyncio.run( + run_cli( + agent_parent_dir=agent_parent_folder, + agent_folder_name=agent_folder_name, + input_file=replay, + saved_session_file=resume, + save_session=save_session, + session_id=session_id, + session_service_uri=session_service_uri, + artifact_service_uri=artifact_service_uri, + memory_service_uri=memory_service_uri, + use_local_storage=use_local_storage, ) - ctx.exit(2) - else: - # If no '--', all remaining arguments are passed to pytest - pytest_args = ctx.args - - import subprocess - - os.environ["ADK_TEST_FOLDER"] = folder - - current_dir = Path(__file__).parent - test_runner_path = current_dir / "agent_test_runner.py" - - if not test_runner_path.exists(): - click.secho( - f"Error: Test runner not found at {test_runner_path}", - fg="red", - err=True, - ) - sys.exit(1) - - click.echo(f"Running tests in {folder} using runner {test_runner_path}...") - - result = subprocess.run([ - sys.executable, - "-m", - "pytest", - str(test_runner_path), - "-v", - "-s", - *pytest_args, - ]) - sys.exit(result.returncode) + ) def eval_options(): @@ -900,7 +743,7 @@ def cli_eval( eval_set_file_path_or_id: list[str], config_file_path: str, print_detailed_results: bool, - eval_storage_uri: Optional[str] = None, + eval_storage_uri: str | None = None, log_level: str = "INFO", ): """Evaluates an agent given the eval sets. @@ -997,8 +840,6 @@ def cli_eval( eval_set_results_manager = None if eval_storage_uri: - from .utils import evals - gcs_eval_managers = evals.create_gcs_eval_managers_from_uri( eval_storage_uri ) @@ -1292,7 +1133,7 @@ def eval_set(): def cli_create_eval_set( agent_module_file_path: str, eval_set_id: str, - eval_storage_uri: Optional[str] = None, + eval_storage_uri: str | None = None, log_level: str = "INFO", ): """Creates an empty EvalSet given the agent_module_file_path and eval_set_id.""" @@ -1341,8 +1182,8 @@ def cli_add_eval_case( agent_module_file_path: str, eval_set_id: str, scenarios_file: str, - eval_storage_uri: Optional[str] = None, - session_input_file: Optional[str] = None, + eval_storage_uri: str | None = None, + session_input_file: str | None = None, log_level: str = "INFO", ): """Adds eval cases to the given eval set. @@ -1431,7 +1272,7 @@ def cli_generate_eval_cases( agent_module_file_path: str, eval_set_id: str, user_simulation_config_file: str, - eval_storage_uri: Optional[str] = None, + eval_storage_uri: str | None = None, log_level: str = "INFO", ): """Generates eval cases dynamically and adds them to the given eval set. @@ -1556,7 +1397,7 @@ def wrapper(*args, **kwargs): return decorator -def _deprecate_staging_bucket(ctx, param, value): +def _deprecate_parameter(ctx, param, value): if value: click.echo( click.style( @@ -1569,6 +1410,19 @@ def _deprecate_staging_bucket(ctx, param, value): return value +def _deprecate_trace_to_cloud(ctx, param, value): + if value: + click.echo( + click.style( + f"WARNING: --{param} is deprecated and will be removed. Please" + " use --otel_to_cloud instead.", + fg="yellow", + ), + err=True, + ) + return value + + def fast_api_common_options(): """Decorator to add common fast api options to click commands.""" @@ -1714,34 +1568,11 @@ def wrapper(ctx, *args, **kwargs): return decorator -def _check_windows_reload(reload: bool) -> bool: - """Checks if reload is enabled on Windows and forces it to False if so.""" - if sys.platform == "win32" and reload: - click.secho( - "WARNING: The --reload flag is not supported on Windows because it" - " forces Uvicorn to use SelectorEventLoop, which does not support" - " subprocesses (needed for executing tools). Forcing --no-reload.", - fg="yellow", - err=True, - ) - return False - return reload - - @main.command("web") @feature_options() @fast_api_common_options() @web_options() @adk_services_options(default_use_local_storage=True) -@click.option( - "--default_llm_model", - type=str, - help=( - "Optional. Sets the default LLM model used when the agent does not set" - " a model explicitly." - ), - default=None, -) @click.argument( "agents_dir", type=click.Path( @@ -1751,26 +1582,25 @@ def _check_windows_reload(reload: bool) -> bool: ) def cli_web( agents_dir: str, - default_llm_model: Optional[str] = None, - eval_storage_uri: Optional[str] = None, + eval_storage_uri: str | None = None, log_level: str = "INFO", - allow_origins: Optional[list[str]] = None, + allow_origins: list[str] | None = None, host: str = "127.0.0.1", port: int = 8000, - url_prefix: Optional[str] = None, + url_prefix: str | None = None, trace_to_cloud: bool = False, otel_to_cloud: bool = False, reload: bool = True, - session_service_uri: Optional[str] = None, - artifact_service_uri: Optional[str] = None, - memory_service_uri: Optional[str] = None, + session_service_uri: str | None = None, + artifact_service_uri: str | None = None, + memory_service_uri: str | None = None, use_local_storage: bool = True, a2a: bool = False, reload_agents: bool = False, - extra_plugins: Optional[list[str]] = None, - logo_text: Optional[str] = None, - logo_image_url: Optional[str] = None, - trigger_sources: Optional[list[str]] = None, + extra_plugins: list[str] | None = None, + logo_text: str | None = None, + logo_image_url: str | None = None, + trigger_sources: list[str] | None = None, ): """Starts a FastAPI server with Web UI for agents. @@ -1781,7 +1611,6 @@ def cli_web( adk web --session_service_uri=[uri] --port=[port] path/to/agents_dir """ - reload = _check_windows_reload(reload) logs.setup_adk_logger(getattr(logging, log_level.upper())) @asynccontextmanager @@ -1806,8 +1635,6 @@ async def _lifespan(app: FastAPI): fg="green", ) - from .fast_api import get_fast_api_app - app = get_fast_api_app( agents_dir=agents_dir, session_service_uri=session_service_uri, @@ -1829,7 +1656,6 @@ async def _lifespan(app: FastAPI): logo_text=logo_text, logo_image_url=logo_image_url, trigger_sources=trigger_sources, - default_llm_model=default_llm_model, ) config = uvicorn.Config( app, @@ -1864,32 +1690,35 @@ async def _lifespan(app: FastAPI): ), ) @click.option( - "--with_ui", - is_flag=True, - default=False, - help="Serve ADK Web UI if set.", + "--gemini_enterprise_app_name", + type=str, + default=None, + help=( + "The app_name to register with Gemini Enterprise via" + " https://docs.cloud.google.com/gemini/enterprise/docs/register-and-manage-an-adk-agent" + ), ) def cli_api_server( agents_dir: str, - eval_storage_uri: Optional[str] = None, + eval_storage_uri: str | None = None, log_level: str = "INFO", - allow_origins: Optional[list[str]] = None, + allow_origins: list[str] | None = None, host: str = "127.0.0.1", port: int = 8000, - url_prefix: Optional[str] = None, + url_prefix: str | None = None, trace_to_cloud: bool = False, otel_to_cloud: bool = False, reload: bool = True, - session_service_uri: Optional[str] = None, - artifact_service_uri: Optional[str] = None, - memory_service_uri: Optional[str] = None, + session_service_uri: str | None = None, + artifact_service_uri: str | None = None, + memory_service_uri: str | None = None, use_local_storage: bool = True, a2a: bool = False, reload_agents: bool = False, - extra_plugins: Optional[list[str]] = None, + extra_plugins: list[str] | None = None, auto_create_session: bool = False, - trigger_sources: Optional[list[str]] = None, - with_ui: bool = False, + trigger_sources: list[str] | None = None, + gemini_enterprise_app_name: str | None = None, ): """Starts a FastAPI server for agents. @@ -1900,11 +1729,8 @@ def cli_api_server( adk api_server --session_service_uri=[uri] --port=[port] path/to/agents_dir """ - reload = _check_windows_reload(reload) logs.setup_adk_logger(getattr(logging, log_level.upper())) - from .fast_api import get_fast_api_app - config = uvicorn.Config( get_fast_api_app( agents_dir=agents_dir, @@ -1914,7 +1740,7 @@ def cli_api_server( use_local_storage=use_local_storage, eval_storage_uri=eval_storage_uri, allow_origins=allow_origins, - web=with_ui, + web=False, trace_to_cloud=trace_to_cloud, otel_to_cloud=otel_to_cloud, a2a=a2a, @@ -1925,6 +1751,7 @@ def cli_api_server( extra_plugins=extra_plugins, auto_create_session=auto_create_session, trigger_sources=trigger_sources, + gemini_enterprise_app_name=gemini_enterprise_app_name, ), host=host, port=port, @@ -2080,8 +1907,8 @@ def cli_api_server( def cli_deploy_cloud_run( ctx, agent: str, - project: Optional[str], - region: Optional[str], + project: str | None, + region: str | None, service_name: str, app_name: str, temp_folder: str, @@ -2097,7 +1924,7 @@ def cli_deploy_cloud_run( memory_service_uri: Optional[str] = None, use_local_storage: bool = False, a2a: bool = False, - trigger_sources: Optional[str] = None, + trigger_sources: str | None = None, ): """Deploys an agent to Cloud Run. @@ -2145,8 +1972,6 @@ def cli_deploy_cloud_run( ctx.exit(2) try: - from . import cli_deploy - cli_deploy.to_cloud_run( agent_folder=agent, project=project, @@ -2255,7 +2080,7 @@ def cli_migrate_session( type=str, default=None, help="Deprecated. This argument is no longer required or used.", - callback=_deprecate_staging_bucket, + callback=_deprecate_parameter, ) @click.option( "--agent_engine_id", @@ -2277,7 +2102,8 @@ def cli_migrate_session( is_flag=True, show_default=True, default=None, - help="Optional. Whether to enable Cloud Trace for Agent Engine.", + help=" NOTE: This flag is deprecated and will be removed in the future.", + callback=_deprecate_trace_to_cloud, ) @click.option( "--otel_to_cloud", @@ -2304,11 +2130,9 @@ def cli_migrate_session( @click.option( "--adk_app", type=str, - default="agent_engine_app", - help=( - "Optional. Python file for defining the ADK application" - " (default: a file named agent_engine_app.py)" - ), + default=None, + help=" NOTE: This flag is deprecated and will be removed in the future.", + callback=_deprecate_parameter, ) @click.option( "--temp_folder", @@ -2324,35 +2148,29 @@ def cli_migrate_session( "--adk_app_object", type=str, default=None, - help=( - "Optional. Python object corresponding to the root ADK agent or app." - " It can only be `root_agent` or `app`. (default: `root_agent`)" - ), + help=" NOTE: This flag is deprecated and will be removed in the future.", + callback=_deprecate_parameter, ) @click.option( "--env_file", type=str, default="", - help=( - "Optional. The filepath to the `.env` file for environment variables." - " (default: the `.env` file in the `agent` directory, if any.)" - ), + help=" NOTE: This flag is deprecated and will be removed in the future.", + callback=_deprecate_parameter, ) @click.option( "--requirements_file", type=str, default="", - help=( - "Optional. The filepath to the `requirements.txt` file to use." - " (default: the `requirements.txt` file in the `agent` directory, if" - " any.)" - ), + help=" NOTE: This flag is deprecated and will be removed in the future.", + callback=_deprecate_parameter, ) @click.option( "--absolutize_imports", type=bool, default=False, help=" NOTE: This flag is deprecated and will be removed in the future.", + callback=_deprecate_parameter, ) @click.option( "--agent_engine_config_file", @@ -2368,21 +2186,52 @@ def cli_migrate_session( @click.option( "--validate-agent-import/--no-validate-agent-import", default=False, - help=( - "Optional. Validate that the agent module can be imported before" - " deployment. This requires your local environment to have the same" - " dependencies as the deployment environment. (default: disabled)" - ), + help=" NOTE: This flag is deprecated and will be removed in the future.", + callback=_deprecate_parameter, ) @click.option( "--skip-agent-import-validation", "skip_agent_import_validation_alias", is_flag=True, default=False, + help=" NOTE: This flag is deprecated and will be removed in the future.", + callback=_deprecate_parameter, +) +# Kept as raw str (not parsed to list) — interpolated directly into Dockerfile CMD. +@click.option( + "--trigger_sources", + type=str, help=( - "Optional. Skip pre-deployment import validation of `agent.py`. This is" - " the default; use --validate-agent-import to enable validation." + "Optional. Comma-separated list of trigger sources to enable" + " (e.g., 'pubsub,eventarc'). Registers /trigger/* endpoints" + " for batch and event-driven agent invocations." ), + default=None, +) +@click.option( + "--adk_version", + type=str, + default=version.__version__, + show_default=True, + help=( + "Optional. The ADK version used in Agent Engine deployment. (default: " + " the version in the dev environment)" + ), +) +@click.option( + "--artifact_service_uri", + type=str, + help=textwrap.dedent( + """\ + Optional. The URI of the artifact service. If set, ADK uses this service. + + \b + If unset, ADK chooses a default artifact service. + - Use 'gs://' to connect to the GCS artifact service. + - Use 'memory://' to force the in-memory artifact service. + - Use 'file://' to store artifacts in a custom local directory.""" + ), + default=None, ) @click.argument( "agent", @@ -2392,24 +2241,27 @@ def cli_migrate_session( ) def cli_deploy_agent_engine( agent: str, - project: Optional[str], - region: Optional[str], - staging_bucket: Optional[str], - agent_engine_id: Optional[str], - trace_to_cloud: Optional[bool], - otel_to_cloud: Optional[bool], - api_key: Optional[str], + project: str | None, + region: str | None, + staging_bucket: str | None, + agent_engine_id: str | None, + trace_to_cloud: bool | None, + otel_to_cloud: bool | None, + api_key: str | None, display_name: str, description: str, - adk_app: str, - adk_app_object: Optional[str], - temp_folder: Optional[str], + adk_app: str | None, + adk_app_object: str | None, + temp_folder: str | None, env_file: str, requirements_file: str, absolutize_imports: bool, agent_engine_config_file: str, validate_agent_import: bool = False, skip_agent_import_validation_alias: bool = False, + adk_version: str | None = None, + trigger_sources: str | None = None, + artifact_service_uri: str | None = None, ): """Deploys an agent to Agent Engine. @@ -2431,8 +2283,6 @@ def cli_deploy_agent_engine( "Do not pass both --validate-agent-import and" " --skip-agent-import-validation." ) - from . import cli_deploy - cli_deploy.to_agent_engine( agent_folder=agent, project=project, @@ -2451,6 +2301,9 @@ def cli_deploy_agent_engine( absolutize_imports=absolutize_imports, agent_engine_config_file=agent_engine_config_file, skip_agent_import_validation=not validate_agent_import, + trigger_sources=trigger_sources, + artifact_service_uri=artifact_service_uri, + adk_version=adk_version, ) except Exception as e: click.secho(f"Deploy failed: {e}", fg="red", err=True) @@ -2587,8 +2440,8 @@ def cli_deploy_agent_engine( ) def cli_deploy_gke( agent: str, - project: Optional[str], - region: Optional[str], + project: str | None, + region: str | None, cluster_name: str, service_name: str, app_name: str, @@ -2599,12 +2452,12 @@ def cli_deploy_gke( with_ui: bool, adk_version: str, service_type: str, - log_level: Optional[str] = None, - session_service_uri: Optional[str] = None, - artifact_service_uri: Optional[str] = None, - memory_service_uri: Optional[str] = None, + log_level: str | None = None, + session_service_uri: str | None = None, + artifact_service_uri: str | None = None, + memory_service_uri: str | None = None, use_local_storage: bool = False, - trigger_sources: Optional[str] = None, + trigger_sources: str | None = None, ): """Deploys an agent to GKE. @@ -2617,8 +2470,6 @@ def cli_deploy_gke( """ try: _warn_if_with_ui(with_ui) - from . import cli_deploy - cli_deploy.to_gke( agent_folder=agent, project=project, diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index ed99799ca4..e13f6ee29d 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations from contextlib import asynccontextmanager import importlib @@ -19,37 +20,57 @@ import logging import os from pathlib import Path -import sys +import shutil from typing import Any +from typing import AsyncIterator +from typing import Awaitable +from typing import Callable from typing import Literal from typing import Mapping -from typing import Optional import click from fastapi import FastAPI -from fastapi import File from fastapi import HTTPException +from fastapi import Request from fastapi import UploadFile +from fastapi.encoders import jsonable_encoder from fastapi.responses import FileResponse +from fastapi.responses import JSONResponse from fastapi.responses import PlainTextResponse +from fastapi.responses import StreamingResponse +from opentelemetry import context +from opentelemetry import trace from opentelemetry.sdk.trace import export from opentelemetry.sdk.trace import TracerProvider +from pydantic import BaseModel +from starlette.concurrency import run_in_threadpool from starlette.types import Lifespan from watchdog.observers import Observer from ..auth.credential_service.in_memory_credential_service import InMemoryCredentialService +from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager +from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager from ..runners import Runner -from .api_server import ApiServer -from .dev_server import DevServer +from .adk_web_server import AdkWebServer from .service_registry import load_services_module +from ..telemetry._agent_engine import get_propagated_context +from ..telemetry._agent_engine import TopSpanProcessor from .utils import envs +from .utils import evals from .utils.agent_change_handler import AgentChangeEventHandler +from .utils.agent_loader import AgentLoader from .utils.base_agent_loader import BaseAgentLoader from .utils.service_factory import _create_task_store_from_options from .utils.service_factory import create_artifact_service_from_options from .utils.service_factory import create_memory_service_from_options from .utils.service_factory import create_session_service_from_options + +class _QueryRequest(BaseModel): + input: dict[str, Any] | None = None + class_method: str | None = None + + logger = logging.getLogger("google_adk." + __name__) _LAZY_SERVICE_IMPORTS: dict[str, str] = { @@ -70,309 +91,6 @@ def __getattr__(name: str): return attr -def _register_builder_endpoints(app: FastAPI, web: bool, agents_dir: str): - """Registers builder endpoints if web is enabled and multipart is installed.""" - if not web: - return - try: - import multipart - except ImportError: - logger.warning( - "python-multipart not installed. Builder UI endpoints will not be" - " available." - ) - return - - import shutil - - import yaml - - agents_base_path = (Path.cwd() / agents_dir).resolve() - - def _get_app_root(app_name: str) -> Path: - if app_name in ("", ".", ".."): - raise ValueError(f"Invalid app name: {app_name!r}") - if Path(app_name).name != app_name or "\\" in app_name: - raise ValueError(f"Invalid app name: {app_name!r}") - app_root = (agents_base_path / app_name).resolve() - if not app_root.is_relative_to(agents_base_path): - raise ValueError(f"Invalid app name: {app_name!r}") - return app_root - - def _normalize_relative_path(path: str) -> str: - return path.replace("\\", "/").lstrip("/") - - def _has_parent_reference(path: str) -> bool: - return any(part == ".." for part in path.split("/")) - - _ALLOWED_EXTENSIONS = frozenset({".yaml", ".yml"}) - - _BLOCKED_YAML_KEYS = frozenset({"args"}) - - def _check_yaml_for_blocked_keys(content: bytes, filename: str) -> None: - try: - docs = list(yaml.safe_load_all(content)) - except yaml.YAMLError as exc: - raise ValueError(f"Invalid YAML in {filename!r}: {exc}") from exc - - def _walk(node: Any) -> None: - if isinstance(node, dict): - for key, value in node.items(): - if key in _BLOCKED_YAML_KEYS: - raise ValueError( - f"Blocked key {key!r} found in {filename!r}. " - f"The '{key}' field is not allowed in builder uploads " - "because it can execute arbitrary code." - ) - _walk(value) - elif isinstance(node, list): - for item in node: - _walk(item) - - for doc in docs: - _walk(doc) - - def _parse_upload_filename(filename: Optional[str]) -> tuple[str, str]: - if not filename: - raise ValueError("Upload filename is missing.") - filename = _normalize_relative_path(filename) - if "/" not in filename: - raise ValueError(f"Invalid upload filename: {filename!r}") - app_name, rel_path = filename.split("/", 1) - if not app_name or not rel_path: - raise ValueError(f"Invalid upload filename: {filename!r}") - if rel_path.startswith("/"): - raise ValueError(f"Absolute upload path rejected: {filename!r}") - if _has_parent_reference(rel_path): - raise ValueError(f"Path traversal rejected: {filename!r}") - ext = os.path.splitext(rel_path)[1].lower() - if ext not in _ALLOWED_EXTENSIONS: - raise ValueError( - f"File type not allowed: {rel_path!r}" - f" (allowed: {', '.join(sorted(_ALLOWED_EXTENSIONS))})" - ) - return app_name, rel_path - - def _parse_file_path(file_path: str) -> str: - file_path = _normalize_relative_path(file_path) - if not file_path: - raise ValueError("file_path is missing.") - if file_path.startswith("/"): - raise ValueError(f"Absolute file_path rejected: {file_path!r}") - if _has_parent_reference(file_path): - raise ValueError(f"Path traversal rejected: {file_path!r}") - ext = os.path.splitext(file_path)[1].lower() - if ext not in _ALLOWED_EXTENSIONS: - raise ValueError( - f"File type not allowed: {file_path!r}" - f" (allowed: {', '.join(sorted(_ALLOWED_EXTENSIONS))})" - ) - return file_path - - def _resolve_under_dir(root_dir: Path, rel_path: str) -> Path: - file_path = root_dir / rel_path - resolved_root_dir = root_dir.resolve() - resolved_file_path = file_path.resolve() - if not resolved_file_path.is_relative_to(resolved_root_dir): - raise ValueError(f"Path escapes root_dir: {rel_path!r}") - return file_path - - def _get_tmp_agent_root(app_root: Path, app_name: str) -> Path: - tmp_agent_root = app_root / "tmp" / app_name - resolved_tmp_agent_root = tmp_agent_root.resolve() - if not resolved_tmp_agent_root.is_relative_to(app_root): - raise ValueError(f"Invalid tmp path for app: {app_name!r}") - return tmp_agent_root - - def copy_dir_contents(source_dir: Path, dest_dir: Path) -> None: - dest_dir.mkdir(parents=True, exist_ok=True) - for source_path in source_dir.iterdir(): - if source_path.name == "tmp": - continue - - dest_path = dest_dir / source_path.name - if source_path.is_dir(): - if dest_path.exists() and dest_path.is_file(): - dest_path.unlink() - shutil.copytree(source_path, dest_path, dirs_exist_ok=True) - elif source_path.is_file(): - if dest_path.exists() and dest_path.is_dir(): - shutil.rmtree(dest_path) - shutil.copy2(source_path, dest_path) - - def cleanup_tmp(app_name: str) -> bool: - try: - app_root = _get_app_root(app_name) - except ValueError as exc: - logger.exception("Error in cleanup_tmp: %s", exc) - return False - - try: - tmp_agent_root = _get_tmp_agent_root(app_root, app_name) - except ValueError as exc: - logger.exception("Error in cleanup_tmp: %s", exc) - return False - - try: - shutil.rmtree(tmp_agent_root) - except FileNotFoundError: - pass - except OSError as exc: - logger.exception("Error deleting tmp agent root: %s", exc) - return False - - tmp_dir = app_root / "tmp" - resolved_tmp_dir = tmp_dir.resolve() - if not resolved_tmp_dir.is_relative_to(app_root): - logger.error( - "Refusing to delete tmp outside app_root: %s", resolved_tmp_dir - ) - return False - - try: - tmp_dir.rmdir() - except OSError: - pass - - return True - - def ensure_tmp_exists(app_name: str) -> bool: - try: - app_root = _get_app_root(app_name) - except ValueError as exc: - logger.exception("Error in ensure_tmp_exists: %s", exc) - return False - - if not app_root.is_dir(): - return False - - try: - tmp_agent_root = _get_tmp_agent_root(app_root, app_name) - except ValueError as exc: - logger.exception("Error in ensure_tmp_exists: %s", exc) - return False - - if tmp_agent_root.exists(): - return True - - try: - tmp_agent_root.mkdir(parents=True, exist_ok=True) - copy_dir_contents(app_root, tmp_agent_root) - except OSError as exc: - logger.exception("Error in ensure_tmp_exists: %s", exc) - return False - - return True - - @app.post("/builder/save", response_model_exclude_none=True) - async def builder_build( - files: list[UploadFile] = File(...), tmp: Optional[bool] = False - ) -> bool: - try: - app_names: set[str] = set() - uploads: list[tuple[str, bytes]] = [] - for file in files: - app_name, rel_path = _parse_upload_filename(file.filename) - app_names.add(app_name) - content = await file.read() - uploads.append((rel_path, content)) - - if len(app_names) != 1: - logger.error( - "Exactly one app name is required, found: %s", - sorted(app_names), - ) - return False - - app_name = next(iter(app_names)) - - for rel_path, content in uploads: - _check_yaml_for_blocked_keys(content, f"{app_name}/{rel_path}") - - if tmp: - app_root = _get_app_root(app_name) - tmp_agent_root = _get_tmp_agent_root(app_root, app_name) - tmp_agent_root.mkdir(parents=True, exist_ok=True) - - for rel_path, content in uploads: - destination_path = _resolve_under_dir(tmp_agent_root, rel_path) - destination_path.parent.mkdir(parents=True, exist_ok=True) - destination_path.write_bytes(content) - - return True - - app_root = _get_app_root(app_name) - app_root.mkdir(parents=True, exist_ok=True) - - tmp_agent_root = _get_tmp_agent_root(app_root, app_name) - if tmp_agent_root.is_dir(): - copy_dir_contents(tmp_agent_root, app_root) - - for rel_path, content in uploads: - destination_path = _resolve_under_dir(app_root, rel_path) - destination_path.parent.mkdir(parents=True, exist_ok=True) - destination_path.write_bytes(content) - - return cleanup_tmp(app_name) - except ValueError as exc: - logger.exception("Error in builder_build: %s", exc) - raise HTTPException(status_code=400, detail=str(exc)) - except OSError as exc: - logger.exception("Error in builder_build: %s", exc) - return False - - @app.post("/builder/app/{app_name}/cancel", response_model_exclude_none=True) - async def builder_cancel(app_name: str) -> bool: - return cleanup_tmp(app_name) - - @app.get( - "/builder/app/{app_name}", - response_model_exclude_none=True, - response_class=PlainTextResponse, - ) - async def get_agent_builder( - app_name: str, - file_path: Optional[str] = None, - tmp: Optional[bool] = False, - ): - try: - app_root = _get_app_root(app_name) - except ValueError as exc: - logger.exception("Error in get_agent_builder: %s", exc) - return "" - - agent_dir = app_root - if tmp: - if not ensure_tmp_exists(app_name): - return "" - agent_dir = app_root / "tmp" / app_name - - if not file_path: - rel_path = "root_agent.yaml" - else: - try: - rel_path = _parse_file_path(file_path) - except ValueError as exc: - logger.exception("Error in get_agent_builder: %s", exc) - return "" - - try: - agent_file_path = _resolve_under_dir(agent_dir, rel_path) - except ValueError as exc: - logger.exception("Error in get_agent_builder: %s", exc) - return "" - - if not agent_file_path.is_file(): - return "" - - return FileResponse( - path=agent_file_path, - media_type="application/x-yaml", - filename=file_path or f"{app_name}.yaml", - headers={"Cache-Control": "no-store"}, - ) - - def get_fast_api_app( *, agents_dir: str, @@ -399,7 +117,7 @@ def get_fast_api_app( logo_image_url: str | None = None, auto_create_session: bool = False, trigger_sources: list[Literal["pubsub", "eventarc"]] | None = None, - default_llm_model: str | None = None, + gemini_enterprise_app_name: str | None = None, ) -> FastAPI: """Constructs and returns a FastAPI application for serving ADK agents. @@ -448,6 +166,8 @@ def get_fast_api_app( trigger_sources: List of trigger sources to enable (e.g. ["pubsub", "eventarc"]). When set, registers /trigger/* endpoints for batch and event-driven agent invocations. None disables all trigger endpoints. + gemini_enterprise_app_name: The app_name to register with Gemini Enterprise + via https://docs.cloud.google.com/gemini/enterprise/docs/register-and-manage-an-adk-agent Returns: The configured FastAPI application instance. @@ -461,24 +181,18 @@ def get_fast_api_app( # Set up eval managers. if eval_storage_uri: - from .utils import evals - gcs_eval_managers = evals.create_gcs_eval_managers_from_uri( eval_storage_uri ) eval_sets_manager = gcs_eval_managers.eval_sets_manager eval_set_results_manager = gcs_eval_managers.eval_set_results_manager else: - this_module = sys.modules[__name__] - eval_sets_manager = this_module.LocalEvalSetsManager(agents_dir=agents_dir) - eval_set_results_manager = this_module.LocalEvalSetResultsManager( - agents_dir=agents_dir - ) + eval_sets_manager = LocalEvalSetsManager(agents_dir=agents_dir) + eval_set_results_manager = LocalEvalSetResultsManager(agents_dir=agents_dir) # initialize Agent Loader if not passed as argument if agent_loader is None: - this_module = sys.modules[__name__] - agent_loader = this_module.AgentLoader(agents_dir) + agent_loader = AgentLoader(agents_dir) # Load services.py from agents_dir for custom service registration. load_services_module(agents_dir) @@ -514,12 +228,7 @@ def get_fast_api_app( # Build the Credential service credential_service = InMemoryCredentialService() - # Instantiate the appropriate server class based on web option - # If web=True, use DevServer (includes all endpoints: production + dev) - # If web=False, use ApiServer (production-safe endpoints only) - ServerClass = DevServer if web else ApiServer - - adk_web_server = ServerClass( + adk_web_server = AdkWebServer( agent_loader=agent_loader, session_service=session_service, artifact_service=artifact_service, @@ -534,7 +243,6 @@ def get_fast_api_app( url_prefix=url_prefix, auto_create_session=auto_create_session, trigger_sources=trigger_sources, - default_llm_model=default_llm_model, ) # Callbacks & other optional args for when constructing the FastAPI instance @@ -564,7 +272,7 @@ def register_processors(provider: TracerProvider) -> None: if reload_agents: - def setup_observer(observer: Observer, adk_web_server: ApiServer): + def setup_observer(observer: Observer, adk_web_server: AdkWebServer): agent_change_handler = AgentChangeEventHandler( agent_loader=agent_loader, runners_to_clean=adk_web_server.runners_to_clean, @@ -573,7 +281,7 @@ def setup_observer(observer: Observer, adk_web_server: ApiServer): observer.schedule(agent_change_handler, agents_dir, recursive=True) observer.start() - def tear_down_observer(observer: Observer, _: ApiServer): + def tear_down_observer(observer: Observer, _: AdkWebServer): observer.stop() observer.join() @@ -624,7 +332,309 @@ async def _a2a_lifespan(app_instance: FastAPI): ) # --- Builder endpoints (agent editor UI) --- - _register_builder_endpoints(app, web, agents_dir) + # Only register when the web UI is enabled. In headless / production + # deployments (e.g. `adk deploy cloud_run`) these endpoints are unnecessary + # and expose an attack surface that allows arbitrary file writes under the + # agents directory. + # See https://github.com/google/adk-python/issues/4947 + if web: + agents_base_path = (Path.cwd() / agents_dir).resolve() + + def _get_app_root(app_name: str) -> Path: + if app_name in ("", ".", ".."): + raise ValueError(f"Invalid app name: {app_name!r}") + if Path(app_name).name != app_name or "\\" in app_name: + raise ValueError(f"Invalid app name: {app_name!r}") + app_root = (agents_base_path / app_name).resolve() + if not app_root.is_relative_to(agents_base_path): + raise ValueError(f"Invalid app name: {app_name!r}") + return app_root + + def _normalize_relative_path(path: str) -> str: + return path.replace("\\", "/").lstrip("/") + + def _has_parent_reference(path: str) -> bool: + return any(part == ".." for part in path.split("/")) + + _ALLOWED_EXTENSIONS = frozenset({".yaml", ".yml"}) + + # --- YAML content security --- + # The `args` key in agent YAML configs (CodeConfig.args, ToolConfig.args) + # allows callers to pass arbitrary arguments to Python constructors and + # functions, which is an RCE vector when exposed through the builder UI. + # Block any upload that contains an `args` key anywhere in the document. + _BLOCKED_YAML_KEYS = frozenset({"args"}) + + def _check_yaml_for_blocked_keys(content: bytes, filename: str) -> None: + """Raise if the YAML document contains any blocked keys.""" + import yaml + + try: + docs = list(yaml.safe_load_all(content)) + except yaml.YAMLError as exc: + raise ValueError(f"Invalid YAML in {filename!r}: {exc}") from exc + + def _walk(node: Any) -> None: + if isinstance(node, dict): + for key, value in node.items(): + if key in _BLOCKED_YAML_KEYS: + raise ValueError( + f"Blocked key {key!r} found in {filename!r}. " + f"The '{key}' field is not allowed in builder uploads " + "because it can execute arbitrary code." + ) + _walk(value) + elif isinstance(node, list): + for item in node: + _walk(item) + + for doc in docs: + _walk(doc) + + def _parse_upload_filename(filename: str | None) -> tuple[str, str]: + if not filename: + raise ValueError("Upload filename is missing.") + filename = _normalize_relative_path(filename) + if "/" not in filename: + raise ValueError(f"Invalid upload filename: {filename!r}") + app_name, rel_path = filename.split("/", 1) + if not app_name or not rel_path: + raise ValueError(f"Invalid upload filename: {filename!r}") + if rel_path.startswith("/"): + raise ValueError(f"Absolute upload path rejected: {filename!r}") + if _has_parent_reference(rel_path): + raise ValueError(f"Path traversal rejected: {filename!r}") + ext = os.path.splitext(rel_path)[1].lower() + if ext not in _ALLOWED_EXTENSIONS: + raise ValueError( + f"File type not allowed: {rel_path!r}" + f" (allowed: {', '.join(sorted(_ALLOWED_EXTENSIONS))})" + ) + return app_name, rel_path + + def _parse_file_path(file_path: str) -> str: + file_path = _normalize_relative_path(file_path) + if not file_path: + raise ValueError("file_path is missing.") + if file_path.startswith("/"): + raise ValueError(f"Absolute file_path rejected: {file_path!r}") + if _has_parent_reference(file_path): + raise ValueError(f"Path traversal rejected: {file_path!r}") + ext = os.path.splitext(file_path)[1].lower() + if ext not in _ALLOWED_EXTENSIONS: + raise ValueError( + f"File type not allowed: {file_path!r}" + f" (allowed: {', '.join(sorted(_ALLOWED_EXTENSIONS))})" + ) + return file_path + + def _resolve_under_dir(root_dir: Path, rel_path: str) -> Path: + file_path = root_dir / rel_path + resolved_root_dir = root_dir.resolve() + resolved_file_path = file_path.resolve() + if not resolved_file_path.is_relative_to(resolved_root_dir): + raise ValueError(f"Path escapes root_dir: {rel_path!r}") + return file_path + + def _get_tmp_agent_root(app_root: Path, app_name: str) -> Path: + tmp_agent_root = app_root / "tmp" / app_name + resolved_tmp_agent_root = tmp_agent_root.resolve() + if not resolved_tmp_agent_root.is_relative_to(app_root): + raise ValueError(f"Invalid tmp path for app: {app_name!r}") + return tmp_agent_root + + def copy_dir_contents(source_dir: Path, dest_dir: Path) -> None: + dest_dir.mkdir(parents=True, exist_ok=True) + for source_path in source_dir.iterdir(): + if source_path.name == "tmp": + continue + + dest_path = dest_dir / source_path.name + if source_path.is_dir(): + if dest_path.exists() and dest_path.is_file(): + dest_path.unlink() + shutil.copytree(source_path, dest_path, dirs_exist_ok=True) + elif source_path.is_file(): + if dest_path.exists() and dest_path.is_dir(): + shutil.rmtree(dest_path) + shutil.copy2(source_path, dest_path) + + def cleanup_tmp(app_name: str) -> bool: + try: + app_root = _get_app_root(app_name) + except ValueError as exc: + logger.exception("Error in cleanup_tmp: %s", exc) + return False + + try: + tmp_agent_root = _get_tmp_agent_root(app_root, app_name) + except ValueError as exc: + logger.exception("Error in cleanup_tmp: %s", exc) + return False + + try: + shutil.rmtree(tmp_agent_root) + except FileNotFoundError: + pass + except OSError as exc: + logger.exception("Error deleting tmp agent root: %s", exc) + return False + + tmp_dir = app_root / "tmp" + resolved_tmp_dir = tmp_dir.resolve() + if not resolved_tmp_dir.is_relative_to(app_root): + logger.error( + "Refusing to delete tmp outside app_root: %s", resolved_tmp_dir + ) + return False + + try: + tmp_dir.rmdir() + except OSError: + pass + + return True + + def ensure_tmp_exists(app_name: str) -> bool: + try: + app_root = _get_app_root(app_name) + except ValueError as exc: + logger.exception("Error in ensure_tmp_exists: %s", exc) + return False + + if not app_root.is_dir(): + return False + + try: + tmp_agent_root = _get_tmp_agent_root(app_root, app_name) + except ValueError as exc: + logger.exception("Error in ensure_tmp_exists: %s", exc) + return False + + if tmp_agent_root.exists(): + return True + + try: + tmp_agent_root.mkdir(parents=True, exist_ok=True) + copy_dir_contents(app_root, tmp_agent_root) + except OSError as exc: + logger.exception("Error in ensure_tmp_exists: %s", exc) + return False + + return True + + @app.post("/builder/save", response_model_exclude_none=True) + async def builder_build( + files: list[UploadFile], tmp: bool | None = False + ) -> bool: + try: + # Phase 1: parse filenames and read content into memory. + app_names: set[str] = set() + uploads: list[tuple[str, bytes]] = [] + for file in files: + app_name, rel_path = _parse_upload_filename(file.filename) + app_names.add(app_name) + content = await file.read() + uploads.append((rel_path, content)) + + if len(app_names) != 1: + logger.error( + "Exactly one app name is required, found: %s", + sorted(app_names), + ) + return False + + app_name = next(iter(app_names)) + + # Phase 2: validate every file *before* writing anything to disk. + for rel_path, content in uploads: + _check_yaml_for_blocked_keys(content, f"{app_name}/{rel_path}") + + # Phase 3: write validated files to disk. + if tmp: + app_root = _get_app_root(app_name) + tmp_agent_root = _get_tmp_agent_root(app_root, app_name) + tmp_agent_root.mkdir(parents=True, exist_ok=True) + + for rel_path, content in uploads: + destination_path = _resolve_under_dir(tmp_agent_root, rel_path) + destination_path.parent.mkdir(parents=True, exist_ok=True) + destination_path.write_bytes(content) + + return True + + app_root = _get_app_root(app_name) + app_root.mkdir(parents=True, exist_ok=True) + + tmp_agent_root = _get_tmp_agent_root(app_root, app_name) + if tmp_agent_root.is_dir(): + copy_dir_contents(tmp_agent_root, app_root) + + for rel_path, content in uploads: + destination_path = _resolve_under_dir(app_root, rel_path) + destination_path.parent.mkdir(parents=True, exist_ok=True) + destination_path.write_bytes(content) + + return cleanup_tmp(app_name) + except ValueError as exc: + logger.exception("Error in builder_build: %s", exc) + raise HTTPException(status_code=400, detail=str(exc)) + except OSError as exc: + logger.exception("Error in builder_build: %s", exc) + return False + + @app.post( + "/builder/app/{app_name}/cancel", response_model_exclude_none=True + ) + async def builder_cancel(app_name: str) -> bool: + return cleanup_tmp(app_name) + + @app.get( + "/builder/app/{app_name}", + response_model_exclude_none=True, + response_class=PlainTextResponse, + ) + async def get_agent_builder( + app_name: str, + file_path: str | None = None, + tmp: bool | None = False, + ): + try: + app_root = _get_app_root(app_name) + except ValueError as exc: + logger.exception("Error in get_agent_builder: %s", exc) + return "" + + agent_dir = app_root + if tmp: + if not ensure_tmp_exists(app_name): + return "" + agent_dir = app_root / "tmp" / app_name + + if not file_path: + rel_path = "root_agent.yaml" + else: + try: + rel_path = _parse_file_path(file_path) + except ValueError as exc: + logger.exception("Error in get_agent_builder: %s", exc) + return "" + + try: + agent_file_path = _resolve_under_dir(agent_dir, rel_path) + except ValueError as exc: + logger.exception("Error in get_agent_builder: %s", exc) + return "" + + if not agent_file_path.is_file(): + return "" + + return FileResponse( + path=agent_file_path, + media_type="application/x-yaml", + filename=file_path or f"{app_name}.yaml", + headers={"Cache-Control": "no-store"}, + ) if a2a and a2a_task_store is not None: from a2a.server.apps import A2AStarletteApplication @@ -696,5 +706,161 @@ async def _get_a2a_runner_async() -> Runner: except Exception as e: logger.error("Failed to setup A2A agent %s: %s", app_name, e) # Continue with other agents even if one fails + if gemini_enterprise_app_name: + if gemini_enterprise_app_name not in agent_loader.list_agents(): + raise ValueError( + f"App {gemini_enterprise_app_name} not found in dir: {agents_dir}" + ) + + import inspect + import json + + from google.adk.agents import Agent + import vertexai + from vertexai import agent_engines + + project = os.environ.get("GOOGLE_CLOUD_PROJECT", None) + location = os.environ.get( + "GOOGLE_CLOUD_AGENT_ENGINE_LOCATION", + os.environ.get("GOOGLE_CLOUD_LOCATION", None), + ) + api_key = os.environ.get("GOOGLE_API_KEY", None) + if project: + vertexai.init(project=project, location=location) + elif api_key: + vertexai.init(api_key=api_key) + else: + raise ValueError( + "No GOOGLE_CLOUD_PROJECT or GOOGLE_API_KEY found in environment" + " variables." + ) + # The tmp agent will be replaced by the adk server's runner and services. + # It is specified here because it is a required argument to AdkApp. + adk_app = agent_engines.AdkApp(agent=Agent(name="tmp")) + adk_app._tmpl_attrs["runner"] = None + adk_app._tmpl_attrs["app_name"] = gemini_enterprise_app_name + adk_app._tmpl_attrs["session_service"] = session_service + adk_app._tmpl_attrs["memory_service"] = memory_service + adk_app._tmpl_attrs["artifact_service"] = artifact_service + + def _encode_chunk_to_json(chunk: Any) -> str | None: + """Encodes a chunk to a JSON string with a newline.""" + try: + json_chunk = jsonable_encoder(chunk) + return f"{json.dumps(json_chunk)}\n" + except Exception: + logging.exception("Failed to encode chunk") + return None + + async def json_generator(output: AsyncIterator[Any]) -> AsyncIterator[str]: + async for chunk in output: + encoded_chunk = _encode_chunk_to_json(chunk) + if encoded_chunk is None: + break + yield encoded_chunk + + async def _invoke_callable_or_raise( + invocation_callable: Callable[..., Any], + invocation_payload: dict[str, Any], + ) -> Any: + if inspect.iscoroutinefunction(invocation_callable): + return await invocation_callable(**invocation_payload) + elif inspect.isasyncgenfunction(invocation_callable): + return invocation_callable(**invocation_payload) + else: + return await run_in_threadpool( + invocation_callable, **invocation_payload + ) + + # Implement a FastAPI middleware to extract and attach OpenTelemetry trace + # context from a custom Google-Agent-Engine-Traceparent header in incoming + # requests. This enables distributed tracing. + tracer_provider = trace.get_tracer_provider() + if isinstance(tracer_provider, TracerProvider): + tracer_provider.add_span_processor(TopSpanProcessor()) + else: + logging.warning( + "OpenTelemetry tracing is not enabled. Please set the" + " `OTEL_PYTHON_TRACER_PROVIDER` environment variable to enable" + " tracing." + ) + + @app.middleware("http") + async def context_propagation( + request: Request, call_next: Callable[[Request], Awaitable[Any]] + ) -> Any: + ctx = get_propagated_context(request) + token = context.attach(ctx) + try: + response = await call_next(request) + return response + finally: + context.detach(token) + + @app.post( + "/api/reasoning_engine", + response_model_exclude_none=True, + response_class=JSONResponse, + ) + async def query(request: _QueryRequest): + if not adk_app._tmpl_attrs.get("runner"): + adk_app._tmpl_attrs["runner"] = await adk_web_server.get_runner_async( + app_name=gemini_enterprise_app_name + ) + if request.class_method is None: + raise HTTPException( + status_code=400, detail="class_method cannot be None" + ) + method = getattr(adk_app, request.class_method) + output = await _invoke_callable_or_raise(method, request.input or {}) + + try: + json_serialized_content = jsonable_encoder({"output": output}) + except ValueError as encoding_error: + logging.exception( + "FastAPI could not JSON-encode the response from invocation method" + " %s. Error: %s. Invocation method's original response: %r", + request.class_method, + encoding_error, + output, + ) + raise + return JSONResponse(content=json_serialized_content) + + @app.post( + "/api/stream_reasoning_engine", + response_model_exclude_none=True, + response_class=StreamingResponse, + ) + async def stream_query(request: _QueryRequest): + if not adk_app._tmpl_attrs.get("runner"): + adk_app._tmpl_attrs["runner"] = await adk_web_server.get_runner_async( + app_name=gemini_enterprise_app_name + ) + if request.class_method is None: + raise HTTPException( + status_code=400, detail="class_method cannot be None" + ) + method = getattr(adk_app, request.class_method) + output = await _invoke_callable_or_raise(method, request.input or {}) + + if inspect.isgenerator(output): + + async def _aiter_from_iter(iterator): + while True: + try: + chunk = await run_in_threadpool(next, iterator) + yield chunk + except StopIteration: + break + + content_iter = _aiter_from_iter(output) + else: + content_iter = output + + return StreamingResponse( + content=json_generator(content_iter), + media_type="application/json", + ) return app diff --git a/src/google/adk/cli/utils/_telemetry.py b/src/google/adk/cli/utils/_telemetry.py new file mode 100644 index 0000000000..070cb45526 --- /dev/null +++ b/src/google/adk/cli/utils/_telemetry.py @@ -0,0 +1,106 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Mapping +from typing import Optional + +import fastapi +from opentelemetry import baggage +from opentelemetry import context +from opentelemetry.sdk import trace +from opentelemetry.trace.propagation import tracecontext + +_GOOGLE_AE_TRACEPARENT_HEADER = "Google-Agent-Engine-Traceparent" +_TRACEPARENT_BAGGAGE_KEY = "traceparent" +_GOOGLE_TRACEPARENT_HEADER = "traceparent" +_GOOGLE_TRACEPARENT_BAGGAGE_KEY = "google_traceparent" +_GOOGLE_TRACEPARENT_SUPPORT_ATTRIBUTE_KEY = "supportID" + + +def get_propagated_context(request: fastapi.Request) -> context.Context: + """Propagates context from the request headers.""" + ctx = context.get_current() + + if _GOOGLE_TRACEPARENT_HEADER in request.headers: + original_traceparent = request.headers[_GOOGLE_TRACEPARENT_HEADER] + ctx = baggage.set_baggage( + _GOOGLE_TRACEPARENT_BAGGAGE_KEY, + original_traceparent, + context=ctx, + ) + + if _GOOGLE_AE_TRACEPARENT_HEADER in request.headers: + carrier = {"traceparent": request.headers[_GOOGLE_AE_TRACEPARENT_HEADER]} + ctx = baggage.set_baggage( + _TRACEPARENT_BAGGAGE_KEY, + request.headers[_GOOGLE_AE_TRACEPARENT_HEADER], + context=ctx, + ) + ctx = tracecontext.TraceContextTextMapPropagator().extract( + carrier=carrier, context=ctx + ) + + return ctx + + +class TopSpanProcessor(trace.SpanProcessor): + """Top span processor.""" + + def on_start( + self, span: trace.Span, parent_context: Optional[context.Context] = None + ): + """Adds support ID to the top span.""" + baggage_items = baggage.get_all(context=parent_context) + if self._is_top_span(span, baggage_items) and ( + baggage_trace_header := baggage_items.get( + _GOOGLE_TRACEPARENT_BAGGAGE_KEY + ) + ): + span.set_attribute( + _GOOGLE_TRACEPARENT_SUPPORT_ATTRIBUTE_KEY, baggage_trace_header + ) + + def on_end(self, span: trace.ReadableSpan) -> None: + pass + + def shutdown(self) -> None: + pass + + def force_flush(self, timeout_millis: int = 30000) -> bool: + return True + + def _is_top_span( + self, span: trace.Span, baggage_items: Mapping[str, object] + ) -> bool: + """Returns true if the span is a top span. + + Args: + span: The span to check. + baggage_items: The baggage items that carry the context. + + Top span (e.g. "Invocation" span) is defined as the first span generated in + trace generation. + Top span could have an empty parent or the parent could be the span + provided by traceparent propagation. + """ + if span.parent is None or span.parent.span_id == 0: + return True + if _TRACEPARENT_BAGGAGE_KEY in baggage_items: + parent_id_hex = str(baggage_items[_TRACEPARENT_BAGGAGE_KEY]).split("-")[2] + parent_id_int = int(parent_id_hex, 16) + if span.parent.span_id == parent_id_int: + return True + return False diff --git a/src/google/adk/telemetry/_agent_engine.py b/src/google/adk/telemetry/_agent_engine.py new file mode 100644 index 0000000000..070cb45526 --- /dev/null +++ b/src/google/adk/telemetry/_agent_engine.py @@ -0,0 +1,106 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Mapping +from typing import Optional + +import fastapi +from opentelemetry import baggage +from opentelemetry import context +from opentelemetry.sdk import trace +from opentelemetry.trace.propagation import tracecontext + +_GOOGLE_AE_TRACEPARENT_HEADER = "Google-Agent-Engine-Traceparent" +_TRACEPARENT_BAGGAGE_KEY = "traceparent" +_GOOGLE_TRACEPARENT_HEADER = "traceparent" +_GOOGLE_TRACEPARENT_BAGGAGE_KEY = "google_traceparent" +_GOOGLE_TRACEPARENT_SUPPORT_ATTRIBUTE_KEY = "supportID" + + +def get_propagated_context(request: fastapi.Request) -> context.Context: + """Propagates context from the request headers.""" + ctx = context.get_current() + + if _GOOGLE_TRACEPARENT_HEADER in request.headers: + original_traceparent = request.headers[_GOOGLE_TRACEPARENT_HEADER] + ctx = baggage.set_baggage( + _GOOGLE_TRACEPARENT_BAGGAGE_KEY, + original_traceparent, + context=ctx, + ) + + if _GOOGLE_AE_TRACEPARENT_HEADER in request.headers: + carrier = {"traceparent": request.headers[_GOOGLE_AE_TRACEPARENT_HEADER]} + ctx = baggage.set_baggage( + _TRACEPARENT_BAGGAGE_KEY, + request.headers[_GOOGLE_AE_TRACEPARENT_HEADER], + context=ctx, + ) + ctx = tracecontext.TraceContextTextMapPropagator().extract( + carrier=carrier, context=ctx + ) + + return ctx + + +class TopSpanProcessor(trace.SpanProcessor): + """Top span processor.""" + + def on_start( + self, span: trace.Span, parent_context: Optional[context.Context] = None + ): + """Adds support ID to the top span.""" + baggage_items = baggage.get_all(context=parent_context) + if self._is_top_span(span, baggage_items) and ( + baggage_trace_header := baggage_items.get( + _GOOGLE_TRACEPARENT_BAGGAGE_KEY + ) + ): + span.set_attribute( + _GOOGLE_TRACEPARENT_SUPPORT_ATTRIBUTE_KEY, baggage_trace_header + ) + + def on_end(self, span: trace.ReadableSpan) -> None: + pass + + def shutdown(self) -> None: + pass + + def force_flush(self, timeout_millis: int = 30000) -> bool: + return True + + def _is_top_span( + self, span: trace.Span, baggage_items: Mapping[str, object] + ) -> bool: + """Returns true if the span is a top span. + + Args: + span: The span to check. + baggage_items: The baggage items that carry the context. + + Top span (e.g. "Invocation" span) is defined as the first span generated in + trace generation. + Top span could have an empty parent or the parent could be the span + provided by traceparent propagation. + """ + if span.parent is None or span.parent.span_id == 0: + return True + if _TRACEPARENT_BAGGAGE_KEY in baggage_items: + parent_id_hex = str(baggage_items[_TRACEPARENT_BAGGAGE_KEY]).split("-")[2] + parent_id_int = int(parent_id_hex, 16) + if span.parent.span_id == parent_id_int: + return True + return False diff --git a/src/google/adk/telemetry/google_cloud.py b/src/google/adk/telemetry/google_cloud.py index c34cdba90c..265ef67173 100644 --- a/src/google/adk/telemetry/google_cloud.py +++ b/src/google/adk/telemetry/google_cloud.py @@ -17,16 +17,18 @@ import enum import logging import os -from typing import Any +import sys from typing import Callable from typing import cast from typing import Optional from typing import TYPE_CHECKING +import uuid import google.auth from google.auth.transport import mtls from opentelemetry.sdk._logs import LogRecordProcessor from opentelemetry.sdk._logs.export import BatchLogRecordProcessor +from opentelemetry.sdk._logs.export import SimpleLogRecordProcessor from opentelemetry.sdk.metrics.export import MetricReader from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader from opentelemetry.sdk.resources import OTELResourceDetector @@ -39,7 +41,7 @@ if TYPE_CHECKING: from google.auth.credentials import Credentials -logger = logging.getLogger('google_adk.' + __name__) +logger = logging.getLogger("google_adk." + __name__) _GCP_LOG_NAME_ENV_VARIABLE_NAME = 'GOOGLE_CLOUD_DEFAULT_LOG_NAME' _DEFAULT_LOG_NAME = 'adk-otel' @@ -82,8 +84,8 @@ def get_gcp_exporters( if not project_id: logger.warning( - 'Cannot determine GCP Project. OTel GCP Exporters cannot be set up.' - ' Please make sure to log into correct GCP Project.' + "Cannot determine GCP Project. OTel GCP Exporters cannot be set up." + " Please make sure to log into correct GCP Project." ) return OTelHooks() @@ -100,7 +102,10 @@ def get_gcp_exporters( log_record_processors: list[LogRecordProcessor] = [] if enable_cloud_logging: - exporter = _get_gcp_logs_exporter(project_id) + exporter = _get_gcp_logs_exporter( + project_id=project_id, + credentials=credentials, + ) if exporter: log_record_processors.append(exporter) @@ -131,10 +136,24 @@ def _get_gcp_span_exporter(credentials: Credentials) -> SpanProcessor: else: endpoint = _DEFAULT_TELEMETRY_TRACES_ENPOINT + headers = None + if os.getenv("GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY"): + from google.cloud.aiplatform import version as aip_version + try: + from opentelemetry.exporter.otlp.proto.http import version as otlp_http_version + except (ImportError, AttributeError): + otlp_http_version = None + + user_agent = f"Vertex-Agent-Engine/{aip_version.__version__}" + if otlp_http_version: + user_agent += f" OTel-OTLP-Exporter-Python/{otlp_http_version.__version__}" + headers = {"User-Agent": user_agent} + return BatchSpanProcessor( OTLPSpanExporter( session=session, endpoint=endpoint, + headers=headers, ) ) @@ -148,7 +167,16 @@ def _get_gcp_metrics_exporter(project_id: str) -> MetricReader: ) -def _get_gcp_logs_exporter(project_id: str) -> LogRecordProcessor: +def _get_gcp_logs_exporter( + project_id: str, + credentials: Optional["Credentials"] = None, +) -> LogRecordProcessor: + if os.getenv("GOOGLE_CLOUD_AGENT_ENGINE_ID"): + return _get_agent_engine_logs_exporter( + credentials=credentials, + project_id=project_id, + ) + from opentelemetry.exporter.cloud_logging import CloudLoggingExporter default_log_name = os.environ.get( @@ -171,8 +199,29 @@ def get_gcp_resource(project_id: Optional[str] = None) -> Resource: project_id: project id to fill out as `gcp.project_id` on the OTEL resource. This may be overwritten by OTELResourceDetector, if `gcp.project_id` is present in `OTEL_RESOURCE_ATTRIBUTES` env var. """ + agent_engine_id = os.getenv("GOOGLE_CLOUD_AGENT_ENGINE_ID", "") + if agent_engine_id: + resource = Resource.create( + attributes={ + "gcp.project_id": project_id, + "cloud.account.id": project_id, + "cloud.provider": "gcp", + "cloud.platform": "gcp.agent_engine", + "service.name": agent_engine_id, + "service.version": os.getenv( + "GOOGLE_CLOUD_AGENT_ENGINE_RUNTIME_REVISION_ID", "" + ), + "service.instance.id": f"{uuid.uuid4().hex}-{os.getpid()}", + "cloud.region": ( + os.getenv("GOOGLE_CLOUD_AGENT_ENGINE_LOCATION", "") + or os.getenv("GOOGLE_CLOUD_LOCATION", "") + ), + } + ).merge(OTELResourceDetector().detect()) + return resource + resource = Resource( - attributes={'gcp.project_id': project_id} + attributes={"gcp.project_id": project_id} if project_id is not None else {} ) @@ -185,8 +234,8 @@ def get_gcp_resource(project_id: Optional[str] = None) -> Resource: ) except ImportError: logger.warning( - 'Cloud not import opentelemetry.resourcedetector.gcp_resource_detector' - ' GCE, GKE or CloudRun related resource attributes may be missing' + "Cloud not import opentelemetry.resourcedetector.gcp_resource_detector" + " GCE, GKE or CloudRun related resource attributes may be missing" ) return resource @@ -248,3 +297,75 @@ def _use_client_cert_effective() -> bool: ' either `true` or `false`' ) return use_client_cert_str == 'true' + + +def _get_agent_engine_logs_exporter( + *, + credentials: "Credentials", + project_id: str, +): + """Configures logging for Agent Engine. + + Args: + credentials: Credentials to use for export calls. + project_id: Project to which to write logs. + """ + try: + from google.cloud.logging_v2.services import logging_service_v2 + from google.cloud.logging_v2.services.logging_service_v2.transports import grpc + from opentelemetry.exporter import cloud_logging + except (ImportError, AttributeError): + logging.warning( + "%s is not installed. Please call 'pip install %s'.", + "opentelemetry-exporter-gcp-logging", + "opentelemetry-exporter-gcp-logging", + ) + logging.warning( + "proceeding with logging disabled because not all packages for" + " logging have been installed" + ) + return + + if "gen_ai_latest_experimental" in os.getenv( + "OTEL_SEMCONV_STABILITY_OPT_IN", "" + ).split(","): + # Specify credentials to avoid expensive call to `google.auth.default()` + channel = grpc.LoggingServiceV2GrpcTransport.create_channel( + credentials=credentials, + # pylint: disable-next=protected-access + options=cloud_logging._OPTIONS, + ) + return BatchLogRecordProcessor( + cloud_logging.CloudLoggingExporter( + client=logging_service_v2.LoggingServiceV2Client( + transport=grpc.LoggingServiceV2GrpcTransport( + credentials=credentials, + channel=channel, + ), + ), + project_id=project_id, + default_log_name=os.getenv( + "GCP_DEFAULT_LOG_NAME", "adk-on-agent-engine" + ), + ), + ) + else: + + class _SimpleLogRecordProcessor(SimpleLogRecordProcessor): + + def force_flush( + self, timeout_millis: int = 30000 + ) -> bool: # pylint: disable=no-self-use + _ = sys.stdout.flush() + _ = sys.stderr.flush() + return super().force_flush() + + return _SimpleLogRecordProcessor( + cloud_logging.CloudLoggingExporter( + project_id=project_id, + default_log_name=os.getenv( + "GCP_DEFAULT_LOG_NAME", "adk-on-agent-engine" + ), + structured_json_file=sys.stdout, + ), + ) diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index a80644a9b6..023d3c78d0 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -22,6 +22,7 @@ from typing import Any from typing import Optional from unittest.mock import AsyncMock +from unittest.mock import call from unittest.mock import MagicMock from unittest.mock import patch @@ -112,7 +113,7 @@ def _event_state_delta(state_delta: dict[str, Any]): # Define mocked async generator functions for the Runner -async def dummy_run_live(self, session, live_request_queue, **kwargs): +async def dummy_run_live(self, session, live_request_queue): yield _event_1() await asyncio.sleep(0) @@ -891,6 +892,67 @@ def test_app_with_a2a( yield client +@pytest.fixture +def test_app_with_gemini_enterprise( + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_agent_loader, + mock_eval_sets_manager, + mock_eval_set_results_manager, + monkeypatch, +): + """Create a TestClient with gemini_enterprise_app_name set.""" + monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "test-project") + mock_agent_loader.list_agents = MagicMock( + return_value=["test_app", "gemini_app"] + ) + + mock_adk_app_instance = MagicMock() + mock_adk_app_instance._tmpl_attrs = {} + + async def my_method_impl(**kwargs): + return {"result": "success", "kwargs": kwargs} + + mock_adk_app_instance.my_method = my_method_impl + + async def my_stream_method_impl(**kwargs): + yield {"chunk": 1, "kwargs": kwargs} + await asyncio.sleep(0) + yield {"chunk": 2, "kwargs": kwargs} + + mock_adk_app_instance.my_stream_method = my_stream_method_impl + + with ( + patch("vertexai.init", new_callable=MagicMock) as mock_vertexai_init, + patch( + "vertexai.agent_engines.AdkApp", return_value=mock_adk_app_instance + ) as mock_adk_app_cls, + patch("google.adk.agents.Agent", new_callable=MagicMock), + patch( + "google.adk.cli.utils._telemetry.TopSpanProcessor", + new_callable=MagicMock, + ), + patch( + "google.adk.cli.utils._telemetry.get_propagated_context", + new_callable=MagicMock, + ), + ): + client = _create_test_client( + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_agent_loader, + mock_eval_sets_manager, + mock_eval_set_results_manager, + gemini_enterprise_app_name="gemini_app", + ) + client.mock_vertexai_init = mock_vertexai_init + client.mock_adk_app_cls = mock_adk_app_cls + client.mock_adk_app_instance = mock_adk_app_instance + yield client + + ################################################# # Test Cases ################################################# @@ -1683,10 +1745,10 @@ def test_get_eval_set_result_not_found(test_app): assert response.status_code == 404 -def test_list_metrics_info(builder_test_client): +def test_list_metrics_info(test_app): """Test listing metrics info.""" - url = "/dev/apps/test_app/metrics-info" - response = builder_test_client.get(url) + url = "/apps/test_app/metrics-info" + response = test_app.get(url) # Verify the response assert response.status_code == 200 @@ -1706,7 +1768,7 @@ def test_debug_trace(test_app): """Test the debug trace endpoint.""" # This test will likely return 404 since we haven't set up trace data, # but it tests that the endpoint exists and handles missing traces correctly. - url = "/dev/apps/test_app/debug/trace/nonexistent-event" + url = "/debug/trace/nonexistent-event" response = test_app.get(url) # Verify we get a 404 for a nonexistent trace @@ -1721,6 +1783,56 @@ def test_openapi_json_schema_accessible(test_app): logger.info("OpenAPI /openapi.json endpoint is accessible") +def test_get_event_graph_returns_dot_src_for_app_agent(): + """Ensure graph endpoint unwraps App instances before building the graph.""" + from google.adk.cli.adk_web_server import AdkWebServer + + root_agent = DummyAgent(name="dummy_agent") + app_agent = App(name="test_app", root_agent=root_agent) + + class Loader: + + def load_agent(self, app_name): + return app_agent + + def list_agents(self): + return [app_agent.name] + + session_service = AsyncMock() + session = Session( + id="session_id", + app_name="test_app", + user_id="user", + state={}, + events=[Event(author="dummy_agent")], + ) + event_id = session.events[0].id + session_service.get_session.return_value = session + + adk_web_server = AdkWebServer( + agent_loader=Loader(), + session_service=session_service, + memory_service=MagicMock(), + artifact_service=MagicMock(), + credential_service=MagicMock(), + eval_sets_manager=MagicMock(), + eval_set_results_manager=MagicMock(), + agents_dir=".", + ) + + fast_api_app = adk_web_server.get_fast_api_app( + setup_observer=lambda _observer, _server: None, + tear_down_observer=lambda _observer, _server: None, + ) + + client = TestClient(fast_api_app) + response = client.get( + f"/apps/test_app/users/user/sessions/session_id/events/{event_id}/graph" + ) + assert response.status_code == 200 + assert "dotSrc" in response.json() + + def test_a2a_agent_discovery(test_app_with_a2a): """Test that A2A agents are properly discovered and configured.""" # This test mainly verifies that the A2A setup doesn't break the app @@ -2082,14 +2194,12 @@ def test_builder_final_save_preserves_files_and_cleans_tmp( ("app/sub_agent.yaml", b"name: sub\n", "application/x-yaml"), ), ] - response = builder_test_client.post( - "/dev/apps/app/builder/save?tmp=true", files=files - ) + response = builder_test_client.post("/builder/save?tmp=true", files=files) assert response.status_code == 200 assert response.json() is True response = builder_test_client.post( - "/dev/apps/app/builder/save", + "/builder/save", files=[( "files", ( @@ -2110,7 +2220,7 @@ def test_builder_final_save_preserves_files_and_cleans_tmp( def test_builder_save_rejects_cross_origin_post(builder_test_client, tmp_path): response = builder_test_client.post( - "/dev/apps/app/builder/save?tmp=true", + "/builder/save?tmp=true", headers={"origin": "https://evil.com"}, files=[( "files", @@ -2125,7 +2235,7 @@ def test_builder_save_rejects_cross_origin_post(builder_test_client, tmp_path): def test_builder_save_allows_same_origin_post(builder_test_client, tmp_path): response = builder_test_client.post( - "/dev/apps/app/builder/save?tmp=true", + "/builder/save?tmp=true", headers={"origin": "http://testserver"}, files=[( "files", @@ -2140,7 +2250,7 @@ def test_builder_save_allows_same_origin_post(builder_test_client, tmp_path): def test_builder_get_allows_cross_origin_get(builder_test_client): response = builder_test_client.get( - "/dev/apps/missing/builder?tmp=true", + "/builder/app/missing?tmp=true", headers={"origin": "https://evil.com"}, ) @@ -2153,12 +2263,12 @@ def test_builder_cancel_deletes_tmp_idempotent(builder_test_client, tmp_path): tmp_agent_root.mkdir(parents=True, exist_ok=True) (tmp_agent_root / "root_agent.yaml").write_text("name: app\n") - response = builder_test_client.post("/dev/apps/app/builder/cancel") + response = builder_test_client.post("/builder/app/app/cancel") assert response.status_code == 200 assert response.json() is True assert not (tmp_path / "app" / "tmp").exists() - response = builder_test_client.post("/dev/apps/app/builder/cancel") + response = builder_test_client.post("/builder/app/app/cancel") assert response.status_code == 200 assert response.json() is True assert not (tmp_path / "app" / "tmp").exists() @@ -2173,7 +2283,7 @@ def test_builder_get_tmp_true_recreates_tmp(builder_test_client, tmp_path): (nested_dir / "nested.yaml").write_text("nested: true\n") assert not (app_root / "tmp").exists() - response = builder_test_client.get("/dev/apps/app/builder?tmp=true") + response = builder_test_client.get("/builder/app/app?tmp=true") assert response.status_code == 200 assert response.text == "name: app\n" @@ -2182,7 +2292,7 @@ def test_builder_get_tmp_true_recreates_tmp(builder_test_client, tmp_path): assert (tmp_agent_root / "nested" / "nested.yaml").is_file() response = builder_test_client.get( - "/dev/apps/app/builder?tmp=true&file_path=nested/nested.yaml" + "/builder/app/app?tmp=true&file_path=nested/nested.yaml" ) assert response.status_code == 200 assert response.text == "nested: true\n" @@ -2191,7 +2301,7 @@ def test_builder_get_tmp_true_recreates_tmp(builder_test_client, tmp_path): def test_builder_get_tmp_true_missing_app_returns_empty( builder_test_client, tmp_path ): - response = builder_test_client.get("/dev/apps/missing/builder?tmp=true") + response = builder_test_client.get("/builder/app/missing?tmp=true") assert response.status_code == 200 assert response.text == "" assert not (tmp_path / "missing").exists() @@ -2199,7 +2309,7 @@ def test_builder_get_tmp_true_missing_app_returns_empty( def test_builder_save_rejects_traversal(builder_test_client, tmp_path): response = builder_test_client.post( - "/dev/apps/app/builder/save?tmp=true", + "/builder/save?tmp=true", files=[( "files", ("app/../escape.yaml", b"nope\n", "application/x-yaml"), @@ -2213,7 +2323,7 @@ def test_builder_save_rejects_traversal(builder_test_client, tmp_path): def test_builder_save_rejects_py_files(builder_test_client, tmp_path): """Uploading .py files via /builder/save is rejected.""" response = builder_test_client.post( - "/dev/apps/app/builder/save?tmp=true", + "/builder/save?tmp=true", files=[( "files", ("app/agent.py", b"import os\nos.system('id')\n", "text/plain"), @@ -2235,7 +2345,7 @@ def test_builder_save_rejects_non_yaml_extensions( (".pth", b"import os"), ]: response = builder_test_client.post( - "/dev/apps/app/builder/save?tmp=true", + "/builder/save?tmp=true", files=[( "files", (f"app/file{ext}", content, "application/octet-stream"), @@ -2247,7 +2357,7 @@ def test_builder_save_rejects_non_yaml_extensions( def test_builder_save_allows_yaml_files(builder_test_client, tmp_path): """Uploading .yaml and .yml files is allowed.""" response = builder_test_client.post( - "/dev/apps/app/builder/save?tmp=true", + "/builder/save?tmp=true", files=[( "files", ("app/root_agent.yaml", b"name: app\n", "application/x-yaml"), @@ -2257,7 +2367,7 @@ def test_builder_save_allows_yaml_files(builder_test_client, tmp_path): assert response.json() is True response = builder_test_client.post( - "/dev/apps/app/builder/save?tmp=true", + "/builder/save?tmp=true", files=[( "files", ("app/sub_agent.yml", b"name: sub\n", "application/x-yaml"), @@ -2275,7 +2385,7 @@ def test_builder_save_rejects_args_key(builder_test_client, tmp_path): key: value """ response = builder_test_client.post( - "/dev/apps/app/builder/save?tmp=true", + "/builder/save?tmp=true", files=[( "files", ("app/root_agent.yaml", yaml_with_args, "application/x-yaml"), @@ -2295,7 +2405,7 @@ def test_builder_save_rejects_nested_args_key(builder_test_client, tmp_path): param: value """ response = builder_test_client.post( - "/dev/apps/app/builder/save?tmp=true", + "/builder/save?tmp=true", files=[( "files", ("app/root_agent.yaml", yaml_with_nested_args, "application/x-yaml"), @@ -2306,7 +2416,7 @@ def test_builder_save_rejects_nested_args_key(builder_test_client, tmp_path): def test_builder_get_rejects_non_yaml_file_paths(builder_test_client, tmp_path): - """GET /dev/apps/{app_name}/builder?file_path=... rejects non-YAML extensions.""" + """GET /builder/app/{app_name}?file_path=... rejects non-YAML extensions.""" app_root = tmp_path / "app" app_root.mkdir(parents=True, exist_ok=True) (app_root / ".env").write_text("SECRET=supersecret\n") @@ -2315,26 +2425,26 @@ def test_builder_get_rejects_non_yaml_file_paths(builder_test_client, tmp_path): for file_path in [".env", "agent.py", "config.json"]: response = builder_test_client.get( - f"/dev/apps/app/builder?file_path={file_path}" + f"/builder/app/app?file_path={file_path}" ) assert response.status_code == 200, f"Expected 200 for {file_path}" assert response.text == "", f"Expected empty response for {file_path}" def test_builder_get_allows_yaml_file_paths(builder_test_client, tmp_path): - """GET /dev/apps/{app_name}/builder?file_path=... allows YAML extensions.""" + """GET /builder/app/{app_name}?file_path=... allows YAML extensions.""" app_root = tmp_path / "app" app_root.mkdir(parents=True, exist_ok=True) (app_root / "sub_agent.yaml").write_text("name: sub\n") (app_root / "tool.yml").write_text("name: tool\n") response = builder_test_client.get( - "/dev/apps/app/builder?file_path=sub_agent.yaml" + "/builder/app/app?file_path=sub_agent.yaml" ) assert response.status_code == 200 assert response.text == "name: sub\n" - response = builder_test_client.get("/dev/apps/app/builder?file_path=tool.yml") + response = builder_test_client.get("/builder/app/app?file_path=tool.yml") assert response.status_code == 200 assert response.text == "name: tool\n" @@ -2357,28 +2467,28 @@ def test_builder_endpoints_not_registered_without_web( mock_eval_set_results_manager, web=False, ) - # /dev/apps/app/builder/save should return 404/405, not 200 + # /builder/save should return 404/405, not 200 response = client.post( - "/dev/apps/app/builder/save", + "/builder/save", files=[ ("files", ("app/agent.yaml", b"name: test\n", "application/x-yaml")) ], ) assert response.status_code in (404, 405) - # /dev/apps/{name}/builder/cancel should also be absent - response = client.post("/dev/apps/app/builder/cancel") + # /builder/app/{name}/cancel should also be absent + response = client.post("/builder/app/app/cancel") assert response.status_code in (404, 405) - # /dev/apps/{name}/builder GET should also be absent - response = client.get("/dev/apps/app/builder") + # /builder/app/{name} GET should also be absent + response = client.get("/builder/app/app") assert response.status_code in (404, 405) def test_builder_endpoints_registered_with_web(builder_test_client): """Builder endpoints are available when web=True.""" response = builder_test_client.post( - "/dev/apps/app/builder/save?tmp=true", + "/builder/save?tmp=true", files=[ ("files", ("app/agent.yaml", b"name: test\n", "application/x-yaml")) ], @@ -2630,7 +2740,12 @@ async def run_async_capture( assert captured_visual_builder_values.get("yaml_app_after_sleep") == True -def test_default_app_name_middleware_and_resolution( +################################################# +# Gemini Enterprise Tests +################################################# + + +def test_gemini_app_not_found_raises( mock_session_service, mock_artifact_service, mock_memory_service, @@ -2639,67 +2754,62 @@ def test_default_app_name_middleware_and_resolution( mock_eval_set_results_manager, monkeypatch, ): - """Test that when ADK_DEFAULT_APP_NAME is set, path rewriting works for get_session and run.""" - # Set environment variable - monkeypatch.setenv("ADK_DEFAULT_APP_NAME", "test_app") - - test_app = _create_test_client( - mock_session_service, - mock_artifact_service, - mock_memory_service, - mock_agent_loader, - mock_eval_sets_manager, - mock_eval_set_results_manager, - ) - - # Create session for test_app - async def setup_session(): - await mock_session_service.create_session( - app_name="test_app", - user_id="test_user", - session_id="test_session", - state={}, + """Test get_fast_api_app raises ValueError if gemini_enterprise_app_name not found.""" + monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "test-project") + mock_agent_loader.list_agents = MagicMock(return_value=["test_app"]) + with pytest.raises(ValueError, match="not found in dir"): + _create_test_client( + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_agent_loader, + mock_eval_sets_manager, + mock_eval_set_results_manager, + gemini_enterprise_app_name="nonexistent_app", ) - asyncio.run(setup_session()) - # 1. Test path rewriting for GET /users/{user_id}/sessions/{session_id} - response = test_app.get("/users/test_user/sessions/test_session") - assert response.status_code == 200 - assert response.json()["id"] == "test_session" - - # 2. Test app_name omission in /run request payload - payload = { - "user_id": "test_user", - "session_id": "test_session", - "new_message": {"role": "user", "parts": [{"text": "Hello"}]}, - } - response = test_app.post("/run", json=payload) - assert response.status_code == 200 - assert isinstance(response.json(), list) - - -def test_default_app_name_not_set_raises_error(test_app, monkeypatch): - """Test that omitting app_name when ADK_DEFAULT_APP_NAME is not set raises 400/404.""" - # Make sure environment variable is NOT set - monkeypatch.delenv("ADK_DEFAULT_APP_NAME", raising=False) - - # 1. Accessing /users/{user_id}/sessions/{session_id} should return 404 because no rewrite happened - response = test_app.get("/users/test_user/sessions/test_session") - assert response.status_code == 404 - - # 2. Accessing /run with omitted app_name should return 400 - payload = { - "user_id": "test_user", - "session_id": "test_session", - "new_message": {"role": "user", "parts": [{"text": "Hello"}]}, - } - response = test_app.post("/run", json=payload) - assert response.status_code == 400 - assert "app_name is required" in response.json()["detail"] +def test_gemini_missing_credentials_raises( + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_agent_loader, + mock_eval_sets_manager, + mock_eval_set_results_manager, + monkeypatch, +): + """Test get_fast_api_app raises ValueError if no credentials are provided.""" + monkeypatch.delenv("GOOGLE_CLOUD_PROJECT", raising=False) + monkeypatch.delenv("GOOGLE_API_KEY", raising=False) + mock_agent_loader.list_agents = MagicMock(return_value=["gemini_app"]) + with pytest.raises( + ValueError, match="No GOOGLE_CLOUD_PROJECT or GOOGLE_API_KEY" + ): + with ( + patch("vertexai.init"), + patch("vertexai.agent_engines.AdkApp"), + patch("google.adk.agents.Agent"), + patch( + "google.adk.cli.utils._telemetry.TopSpanProcessor", + new_callable=MagicMock, + ), + patch( + "google.adk.cli.utils._telemetry.get_propagated_context", + new_callable=MagicMock, + ), + ): + _create_test_client( + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_agent_loader, + mock_eval_sets_manager, + mock_eval_set_results_manager, + gemini_enterprise_app_name="gemini_app", + ) -def test_run_live_websocket_default_app_name( +def test_gemini_init_with_project_id( mock_session_service, mock_artifact_service, mock_memory_service, @@ -2708,49 +2818,124 @@ def test_run_live_websocket_default_app_name( mock_eval_set_results_manager, monkeypatch, ): - """Test that /run_live websocket endpoint resolves app_name using ADK_DEFAULT_APP_NAME.""" - monkeypatch.setenv("ADK_DEFAULT_APP_NAME", "test_app") + """Test vertexai.init is called with project_id.""" + monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "test-project") + monkeypatch.setenv("GOOGLE_CLOUD_LOCATION", "test-location") + monkeypatch.delenv("GOOGLE_API_KEY", raising=False) + mock_agent_loader.list_agents = MagicMock(return_value=["gemini_app"]) + with ( + patch("vertexai.init") as mock_init, + patch("vertexai.agent_engines.AdkApp"), + patch("google.adk.agents.Agent"), + patch( + "google.adk.cli.utils._telemetry.TopSpanProcessor", + new_callable=MagicMock, + ), + patch( + "google.adk.cli.utils._telemetry.get_propagated_context", + new_callable=MagicMock, + ), + ): + _create_test_client( + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_agent_loader, + mock_eval_sets_manager, + mock_eval_set_results_manager, + gemini_enterprise_app_name="gemini_app", + ) + mock_init.assert_called_once_with( + project="test-project", + location="test-location", + ) - test_app = _create_test_client( - mock_session_service, - mock_artifact_service, - mock_memory_service, - mock_agent_loader, - mock_eval_sets_manager, - mock_eval_set_results_manager, - ) - async def setup_session(): - await mock_session_service.create_session( - app_name="test_app", - user_id="user", - session_id="session", - state={}, +def test_gemini_init_with_api_key( + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_agent_loader, + mock_eval_sets_manager, + mock_eval_set_results_manager, + monkeypatch, +): + """Test vertexai.init is called with api_key.""" + monkeypatch.delenv("GOOGLE_CLOUD_PROJECT", raising=False) + monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") + mock_agent_loader.list_agents = MagicMock(return_value=["gemini_app"]) + with ( + patch("vertexai.init") as mock_init, + patch("vertexai.agent_engines.AdkApp"), + patch("google.adk.agents.Agent"), + patch( + "google.adk.cli.utils._telemetry.TopSpanProcessor", + new_callable=MagicMock, + ), + patch( + "google.adk.cli.utils._telemetry.get_propagated_context", + new_callable=MagicMock, + ), + ): + _create_test_client( + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_agent_loader, + mock_eval_sets_manager, + mock_eval_set_results_manager, + gemini_enterprise_app_name="gemini_app", ) + mock_init.assert_called_once_with(api_key="test-api-key") - asyncio.run(setup_session()) - - url = "/run_live?user_id=user&session_id=session&modalities=AUDIO" - with test_app.websocket_connect(url) as ws: - data = ws.receive_json() - assert data["author"] == "dummy agent" +def test_gemini_reasoning_engine_success(test_app_with_gemini_enterprise): + """Test POST /api/reasoning_engine success case.""" + response = test_app_with_gemini_enterprise.post( + "/api/reasoning_engine", + json={"class_method": "my_method", "input": {"arg1": 1}}, + ) + assert response.status_code == 200 + assert response.json() == { + "output": {"result": "success", "kwargs": {"arg1": 1}} + } -def test_run_live_websocket_missing_app_name_raises_error( - test_app, monkeypatch +def test_gemini_reasoning_engine_missing_class_method( + test_app_with_gemini_enterprise, ): - """Test that /run_live websocket connection fails when app_name and ADK_DEFAULT_APP_NAME are both missing.""" - from fastapi.websockets import WebSocketDisconnect + """Test POST /api/reasoning_engine with missing class_method.""" + response = test_app_with_gemini_enterprise.post( + "/api/reasoning_engine", + json={"input": {"arg1": 1}}, + ) + assert response.status_code == 400 + - monkeypatch.delenv("ADK_DEFAULT_APP_NAME", raising=False) +def test_gemini_stream_reasoning_engine_success( + test_app_with_gemini_enterprise, +): + """Test POST /api/stream_reasoning_engine success case.""" + response = test_app_with_gemini_enterprise.post( + "/api/stream_reasoning_engine", + json={"class_method": "my_stream_method", "input": {"arg1": 1}}, + ) + assert response.status_code == 200 + lines = response.text.strip().split("\n") + assert len(lines) == 2 + assert json.loads(lines[0]) == {"chunk": 1, "kwargs": {"arg1": 1}} + assert json.loads(lines[1]) == {"chunk": 2, "kwargs": {"arg1": 1}} - url = "/run_live?user_id=user&session_id=session&modalities=AUDIO" - with pytest.raises(WebSocketDisconnect) as exc_info: - with test_app.websocket_connect(url) as ws: - ws.receive_json() - assert exc_info.value.code == 1008 +def test_gemini_stream_reasoning_engine_missing_class_method( + test_app_with_gemini_enterprise, +): + """Test POST /api/stream_reasoning_engine with missing class_method.""" + response = test_app_with_gemini_enterprise.post( + "/api/stream_reasoning_engine", + json={"input": {"arg1": 1}}, + ) + assert response.status_code == 400 if __name__ == "__main__": diff --git a/tests/unittests/cli/utils/test_cli_deploy.py b/tests/unittests/cli/utils/test_cli_deploy.py index f2b32773f4..b70c59dcaa 100644 --- a/tests/unittests/cli/utils/test_cli_deploy.py +++ b/tests/unittests/cli/utils/test_cli_deploy.py @@ -30,10 +30,11 @@ from unittest import mock import click +from click.testing import CliRunner +from google.adk.cli import cli_deploy +from google.adk.cli import cli_tools_click import pytest -import src.google.adk.cli.cli_deploy as cli_deploy - # Helpers class _Recorder: @@ -233,21 +234,6 @@ def test_get_service_option_by_adk_version( assert actual.rstrip() == expected.rstrip() -def test_agent_engine_app_template_compiles_with_windows_paths() -> None: - """It should not emit invalid Python when paths contain `\\u` segments.""" - rendered = cli_deploy._AGENT_ENGINE_APP_TEMPLATE.format( - is_config_agent=True, - agent_folder=r".\user_agent_tmp20260101_000000", - adk_app_object="root_agent", - adk_app_type="agent", - trace_to_cloud_option=False, - express_mode=False, - extra_imports="", - app_instantiation="agent=root_agent", - ) - compile(rendered, "", "exec") - - def test_print_agent_engine_url() -> None: """It should print the correct URL for a fully-qualified resource name.""" with mock.patch("click.secho") as mocked_secho: @@ -277,8 +263,8 @@ def test_to_agent_engine_happy_path( class _FakeAgentEngines: - def create(self, *, config: Dict[str, Any]) -> Any: - create_recorder(config=config) + def create(self, **kwargs: Any) -> Any: + create_recorder(**kwargs) return types.SimpleNamespace( api_resource=types.SimpleNamespace( name="projects/p/locations/l/reasoningEngines/e" @@ -303,29 +289,15 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: cli_deploy.to_agent_engine( agent_folder=str(src_dir), temp_folder="tmp", - adk_app="my_adk_app", trace_to_cloud=True, project="my-gcp-project", region="us-central1", display_name="My Test Agent", description="A test agent.", + adk_version="1.2.0", ) - agent_file = tmp_dir / "agent.py" + agent_file = tmp_dir / "Dockerfile" assert agent_file.is_file() - init_file = tmp_dir / "__init__.py" - assert init_file.is_file() - adk_app_file = tmp_dir / "my_adk_app.py" - assert adk_app_file.is_file() - content = adk_app_file.read_text() - assert "from .agent import root_agent" in content - assert "adk_app = AdkApp(" in content - assert "agent=root_agent" in content - assert "enable_tracing=True" in content - reqs_path = tmp_dir / "requirements.txt" - assert reqs_path.is_file() - reqs_content = reqs_path.read_text() - assert "google-cloud-aiplatform[agent_engines]" in reqs_content - assert f"google-adk=={cli_deploy.__version__}" in reqs_content assert len(create_recorder.calls) == 1 assert str(rmtree_recorder.get_last_call_args()[0]) == str(tmp_dir) @@ -352,116 +324,13 @@ def test_to_agent_engine_raises_when_explicit_config_file_missing( display_name="My Test Agent", description="A test agent.", agent_engine_config_file=str(missing_config), + adk_version="1.2.0", ) assert "Agent engine config file not found" in str(exc_info.value) assert expected_abs in str(exc_info.value) -def test_to_agent_engine_skips_agent_import_validation_by_default( - monkeypatch: pytest.MonkeyPatch, - agent_dir: Callable[[bool, bool], Path], -) -> None: - """It should skip agent.py import validation by default.""" - validate_recorder = _Recorder() - - def _validate_agent_import(*args: Any, **kwargs: Any) -> None: - validate_recorder(*args, **kwargs) - raise AssertionError("_validate_agent_import should not be called") - - monkeypatch.setattr( - cli_deploy, "_validate_agent_import", _validate_agent_import - ) - - fake_vertexai = types.ModuleType("vertexai") - - class _FakeAgentEngines: - - def create(self, *, config: Dict[str, Any]) -> Any: - del config - return types.SimpleNamespace( - api_resource=types.SimpleNamespace( - name="projects/p/locations/l/reasoningEngines/e" - ) - ) - - class _FakeVertexClient: - - def __init__(self, *args: Any, **kwargs: Any) -> None: - del args - del kwargs - self.agent_engines = _FakeAgentEngines() - - fake_vertexai.Client = _FakeVertexClient - monkeypatch.setitem(sys.modules, "vertexai", fake_vertexai) - - src_dir = agent_dir(False, False) - cli_deploy.to_agent_engine( - agent_folder=str(src_dir), - temp_folder="tmp", - adk_app="my_adk_app", - trace_to_cloud=True, - project="my-gcp-project", - region="us-central1", - display_name="My Test Agent", - description="A test agent.", - ) - - assert validate_recorder.calls == [] - - -def test_to_agent_engine_validates_agent_import_when_enabled( - monkeypatch: pytest.MonkeyPatch, - agent_dir: Callable[[bool, bool], Path], -) -> None: - """It should run agent.py import validation when enabled.""" - validate_recorder = _Recorder() - - def _validate_agent_import(*args: Any, **kwargs: Any) -> None: - validate_recorder(*args, **kwargs) - - monkeypatch.setattr( - cli_deploy, "_validate_agent_import", _validate_agent_import - ) - - fake_vertexai = types.ModuleType("vertexai") - - class _FakeAgentEngines: - - def create(self, *, config: Dict[str, Any]) -> Any: - del config - return types.SimpleNamespace( - api_resource=types.SimpleNamespace( - name="projects/p/locations/l/reasoningEngines/e" - ) - ) - - class _FakeVertexClient: - - def __init__(self, *args: Any, **kwargs: Any) -> None: - del args - del kwargs - self.agent_engines = _FakeAgentEngines() - - fake_vertexai.Client = _FakeVertexClient - monkeypatch.setitem(sys.modules, "vertexai", fake_vertexai) - - src_dir = agent_dir(False, False) - cli_deploy.to_agent_engine( - agent_folder=str(src_dir), - temp_folder="tmp", - adk_app="my_adk_app", - trace_to_cloud=True, - project="my-gcp-project", - region="us-central1", - display_name="My Test Agent", - description="A test agent.", - skip_agent_import_validation=False, - ) - - assert len(validate_recorder.calls) == 1 - - @pytest.mark.parametrize("include_requirements", [True, False]) def test_to_gke_happy_path( monkeypatch: pytest.MonkeyPatch, @@ -509,7 +378,7 @@ def mock_subprocess_run(*args, **kwargs): dockerfile_path = tmp_path / "Dockerfile" assert dockerfile_path.is_file() dockerfile_content = dockerfile_path.read_text() - assert "CMD adk api_server --with_ui --port=9090" in dockerfile_content + assert "CMD adk web --port=9090" in dockerfile_content assert "RUN pip install google-adk==1.2.0" in dockerfile_content assert len(run_recorder.calls) == 3, "Expected 3 subprocess calls" @@ -743,6 +612,8 @@ def test_to_agent_engine_triggers_onboarding( name="projects/p/locations/l/reasoningEngines/e" ) ) + mock_agent_engines.delete.return_value = None + mock_agent_engines.update.return_value = None monkeypatch.setitem(sys.modules, "vertexai", fake_vertexai) @@ -762,3 +633,51 @@ def test_to_agent_engine_triggers_onboarding( assert kwargs.get("project") == "fake_project" assert kwargs.get("location") == "fake_region" assert "api_key" not in kwargs or kwargs.get("api_key") is None + + +def test_cli_deploy_agent_engine_trigger_sources(tmp_path: Path): + """Tests that --trigger_sources is passed to to_agent_engine.""" + agent_dir = tmp_path / "my_agent" + agent_dir.mkdir() + runner = CliRunner() + with mock.patch( + "google.adk.cli.cli_deploy.to_agent_engine" + ) as mock_to_agent_engine: + result = runner.invoke( + cli_tools_click.main, + [ + "deploy", + "agent_engine", + "--trigger_sources=pubsub,eventarc", + str(agent_dir), + ], + catch_exceptions=False, + ) + assert result.exit_code == 0 + mock_to_agent_engine.assert_called_once() + _, kwargs = mock_to_agent_engine.call_args + assert kwargs["trigger_sources"] == "pubsub,eventarc" + + +def test_cli_deploy_agent_engine_artifact_service_uri(tmp_path: Path): + """Tests that --artifact_service_uri is passed to to_agent_engine.""" + agent_dir = tmp_path / "my_agent" + agent_dir.mkdir() + runner = CliRunner() + with mock.patch( + "google.adk.cli.cli_deploy.to_agent_engine" + ) as mock_to_agent_engine: + result = runner.invoke( + cli_tools_click.main, + [ + "deploy", + "agent_engine", + "--artifact_service_uri=gs://my-bucket", + str(agent_dir), + ], + catch_exceptions=False, + ) + assert result.exit_code == 0 + mock_to_agent_engine.assert_called_once() + _, kwargs = mock_to_agent_engine.call_args + assert kwargs["artifact_service_uri"] == "gs://my-bucket" diff --git a/tests/unittests/telemetry/test_google_cloud.py b/tests/unittests/telemetry/test_google_cloud.py index 496f093d26..693365de39 100644 --- a/tests/unittests/telemetry/test_google_cloud.py +++ b/tests/unittests/telemetry/test_google_cloud.py @@ -52,18 +52,6 @@ def test_get_gcp_exporters( "google.auth.default", auth_mock, ) - monkeypatch.setattr( - "google.adk.telemetry.google_cloud._get_gcp_span_exporter", - lambda credentials: mock.MagicMock(), - ) - monkeypatch.setattr( - "google.adk.telemetry.google_cloud._get_gcp_metrics_exporter", - lambda project_id: mock.MagicMock(), - ) - monkeypatch.setattr( - "google.adk.telemetry.google_cloud._get_gcp_logs_exporter", - lambda project_id: mock.MagicMock(), - ) # Act. otel_hooks = get_gcp_exporters( @@ -215,4 +203,5 @@ def test_get_gcp_span_exporter_mtls( mock_exporter.assert_called_once_with( session=mock_session.return_value, endpoint=_DEFAULT_MTLS_TELEMETRY_TRACES_ENPOINT, + header=None, ) From dc42e4874df5e2629870da0dd7ecd6979e488b6f Mon Sep 17 00:00:00 2001 From: Yee Sian Ng Date: Fri, 22 May 2026 16:52:38 +0000 Subject: [PATCH 02/18] chore: attempted fix to correlate traces with logs --- src/google/adk/cli/cli_deploy.py | 14 +++++------ src/google/adk/cli/cli_tools_click.py | 12 ++++++++++ src/google/adk/cli/fast_api.py | 34 +++++++++++++++++---------- 3 files changed, 40 insertions(+), 20 deletions(-) diff --git a/src/google/adk/cli/cli_deploy.py b/src/google/adk/cli/cli_deploy.py index 8ee0e16b92..db7eae66ec 100644 --- a/src/google/adk/cli/cli_deploy.py +++ b/src/google/adk/cli/cli_deploy.py @@ -69,11 +69,6 @@ def _ensure_agent_engine_dependency(requirements_txt_path: str) -> None: FROM python:3.11-slim WORKDIR /app -RUN apt-get update && \ - apt-get upgrade -y && \ - apt-get install -y git && \ - apt -y autoremove - # Create a non-root user RUN adduser --disabled-password --gecos "" myuser @@ -90,7 +85,7 @@ def _ensure_agent_engine_dependency(requirements_txt_path: str) -> None: # Set up environment variables - End # Install ADK - Start -# RUN pip install google-adk=={adk_version} +RUN pip install google-adk=={adk_version} # Install ADK - End # Copy agent - Start @@ -106,7 +101,7 @@ def _ensure_agent_engine_dependency(requirements_txt_path: str) -> None: EXPOSE {port} -CMD adk {command} --port={port} {host_option} {service_option} {trace_to_cloud_option} {otel_to_cloud_option} {allow_origins_option} {a2a_option} {trigger_sources_option} {gemini_enterprise_option} "/app/agents" +CMD adk {command} --port={port} {host_option} {service_option} {trace_to_cloud_option} {otel_to_cloud_option} {allow_origins_option} {a2a_option} {trigger_sources_option} {gemini_enterprise_option}{express_mode_option} "/app/agents" """ _AGENT_ENGINE_CLASS_METHODS = [ @@ -723,6 +718,7 @@ def to_cloud_run( a2a_option=a2a_option, trigger_sources_option=trigger_sources_option, gemini_enterprise_option='', + express_mode_option='', ) dockerfile_path = os.path.join(temp_folder, 'Dockerfile') os.makedirs(temp_folder, exist_ok=True) @@ -1139,6 +1135,9 @@ def create_dockerfile_for_agent_engine(resource_name: str): a2a_option='--a2a', trigger_sources_option=trigger_sources_option, gemini_enterprise_option=f'--gemini_enterprise_app_name={app_name}', + express_mode_option=( + ' --express_mode' if api_key and not project else '' + ), ) with open('Dockerfile', 'w', encoding='utf-8') as f: f.write(dockerfile_content) @@ -1307,6 +1306,7 @@ def to_gke( f'--trigger_sources={trigger_sources}' if trigger_sources else '' ), gemini_enterprise_option='', + express_mode_option='', ) dockerfile_path = os.path.join(temp_folder, 'Dockerfile') os.makedirs(temp_folder, exist_ok=True) diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index ea33ede245..4009cf6b02 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -1698,6 +1698,16 @@ async def _lifespan(app: FastAPI): " https://docs.cloud.google.com/gemini/enterprise/docs/register-and-manage-an-adk-agent" ), ) +@click.option( + "--express_mode", + is_flag=True, + default=False, + help=( + "Whether or not to initialize the server in express mode. This is only" + " supported when gemini_enterprise_app_name is set. Defaults to" + " False." + ), +) def cli_api_server( agents_dir: str, eval_storage_uri: str | None = None, @@ -1719,6 +1729,7 @@ def cli_api_server( auto_create_session: bool = False, trigger_sources: list[str] | None = None, gemini_enterprise_app_name: str | None = None, + express_mode: bool = False, ): """Starts a FastAPI server for agents. @@ -1752,6 +1763,7 @@ def cli_api_server( auto_create_session=auto_create_session, trigger_sources=trigger_sources, gemini_enterprise_app_name=gemini_enterprise_app_name, + express_mode=express_mode, ), host=host, port=port, diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index e13f6ee29d..f85705d342 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -118,6 +118,7 @@ def get_fast_api_app( auto_create_session: bool = False, trigger_sources: list[Literal["pubsub", "eventarc"]] | None = None, gemini_enterprise_app_name: str | None = None, + express_mode: bool = False, ) -> FastAPI: """Constructs and returns a FastAPI application for serving ADK agents. @@ -168,6 +169,9 @@ def get_fast_api_app( event-driven agent invocations. None disables all trigger endpoints. gemini_enterprise_app_name: The app_name to register with Gemini Enterprise via https://docs.cloud.google.com/gemini/enterprise/docs/register-and-manage-an-adk-agent + express_mode: Whether or not to intialize the server in express mode. + This is only supported when gemini_enterprise_app_name is set. Defaults to + False. Returns: The configured FastAPI application instance. @@ -714,26 +718,30 @@ async def _get_a2a_runner_async() -> Runner: import inspect import json - + import google.auth from google.adk.agents import Agent import vertexai from vertexai import agent_engines - project = os.environ.get("GOOGLE_CLOUD_PROJECT", None) - location = os.environ.get( - "GOOGLE_CLOUD_AGENT_ENGINE_LOCATION", - os.environ.get("GOOGLE_CLOUD_LOCATION", None), - ) - api_key = os.environ.get("GOOGLE_API_KEY", None) - if project: - vertexai.init(project=project, location=location) - elif api_key: + if express_mode: + api_key = os.environ.get("GOOGLE_API_KEY", None) + if not api_key: + raise ValueError( + "No GOOGLE_API_KEY found in environment variables for express mode." + ) vertexai.init(api_key=api_key) else: - raise ValueError( - "No GOOGLE_CLOUD_PROJECT or GOOGLE_API_KEY found in environment" - " variables." + _, project_id = google.auth.default() + location = os.environ.get( + "GOOGLE_CLOUD_AGENT_ENGINE_LOCATION", + os.environ.get("GOOGLE_CLOUD_LOCATION", None), ) + if not project_id or not location: + raise ValueError("No GOOGLE_CLOUD_PROJECT or GOOGLE_CLOUD_LOCATION found in" + " environment variables." + ) + vertexai.init(project=project_id, location=location) + # The tmp agent will be replaced by the adk server's runner and services. # It is specified here because it is a required argument to AdkApp. adk_app = agent_engines.AdkApp(agent=Agent(name="tmp")) From fa86d2f9eefbd7b57c1726bc03e5c984ffefa1d0 Mon Sep 17 00:00:00 2001 From: Yee Sian Ng Date: Fri, 22 May 2026 16:54:48 +0000 Subject: [PATCH 03/18] chore: apply patch for installing git --- src/google/adk/cli/cli_deploy.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/google/adk/cli/cli_deploy.py b/src/google/adk/cli/cli_deploy.py index db7eae66ec..1e51bd0201 100644 --- a/src/google/adk/cli/cli_deploy.py +++ b/src/google/adk/cli/cli_deploy.py @@ -69,6 +69,11 @@ def _ensure_agent_engine_dependency(requirements_txt_path: str) -> None: FROM python:3.11-slim WORKDIR /app +RUN apt-get update && \ + apt-get upgrade -y && \ + apt-get install -y git \ + apt -y autoremove + # Create a non-root user RUN adduser --disabled-password --gecos "" myuser @@ -85,7 +90,7 @@ def _ensure_agent_engine_dependency(requirements_txt_path: str) -> None: # Set up environment variables - End # Install ADK - Start -RUN pip install google-adk=={adk_version} +# RUN pip install google-adk=={adk_version} # Install ADK - End # Copy agent - Start From c5ba395f644caf2593963e4f62a2742d2e964ce1 Mon Sep 17 00:00:00 2001 From: Yee Sian Ng Date: Fri, 22 May 2026 17:06:55 +0000 Subject: [PATCH 04/18] chore: fix typo --- src/google/adk/cli/cli_deploy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/google/adk/cli/cli_deploy.py b/src/google/adk/cli/cli_deploy.py index 1e51bd0201..c92969ed8a 100644 --- a/src/google/adk/cli/cli_deploy.py +++ b/src/google/adk/cli/cli_deploy.py @@ -71,7 +71,7 @@ def _ensure_agent_engine_dependency(requirements_txt_path: str) -> None: RUN apt-get update && \ apt-get upgrade -y && \ - apt-get install -y git \ + apt-get install -y git && \ apt -y autoremove # Create a non-root user From 4ad2cf819196ac26af3e0a28256d619422af1b6b Mon Sep 17 00:00:00 2001 From: Yee Sian Ng Date: Fri, 22 May 2026 17:36:23 +0000 Subject: [PATCH 05/18] chore: remove vertexai init --- src/google/adk/cli/fast_api.py | 29 +++++++++++++---------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index f85705d342..9274ee4765 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -720,31 +720,28 @@ async def _get_a2a_runner_async() -> Runner: import json import google.auth from google.adk.agents import Agent - import vertexai from vertexai import agent_engines + # The tmp agent will be replaced by the adk server's runner and services. + # It is specified here because it is a required argument to AdkApp. + adk_app = agent_engines.AdkApp(agent=Agent(name="tmp")) if express_mode: api_key = os.environ.get("GOOGLE_API_KEY", None) - if not api_key: - raise ValueError( - "No GOOGLE_API_KEY found in environment variables for express mode." - ) - vertexai.init(api_key=api_key) + adk_app._tmpl_attrs["project"] = None + adk_app._tmpl_attrs["location"] = None + adk_app._tmpl_attrs["api_key"] = api_key else: - _, project_id = google.auth.default() + project_id = google.auth.default() location = os.environ.get( "GOOGLE_CLOUD_AGENT_ENGINE_LOCATION", os.environ.get("GOOGLE_CLOUD_LOCATION", None), ) - if not project_id or not location: - raise ValueError("No GOOGLE_CLOUD_PROJECT or GOOGLE_CLOUD_LOCATION found in" - " environment variables." - ) - vertexai.init(project=project_id, location=location) - - # The tmp agent will be replaced by the adk server's runner and services. - # It is specified here because it is a required argument to AdkApp. - adk_app = agent_engines.AdkApp(agent=Agent(name="tmp")) + logging.warning( + "[fast_api] project_id: %s, location: %s", project_id, location + ) + adk_app._tmpl_attrs["project"] = project_id + adk_app._tmpl_attrs["location"] = location + adk_app._tmpl_attrs["api_key"] = None adk_app._tmpl_attrs["runner"] = None adk_app._tmpl_attrs["app_name"] = gemini_enterprise_app_name adk_app._tmpl_attrs["session_service"] = session_service From b72c0f430134473fdb867c97f989ca61889cb916 Mon Sep 17 00:00:00 2001 From: Yee Sian Ng Date: Fri, 22 May 2026 18:21:31 +0000 Subject: [PATCH 06/18] chore: try converting project_number to project_id for associating logs to traces --- src/google/adk/cli/fast_api.py | 5 +---- src/google/adk/telemetry/google_cloud.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 9274ee4765..4648e0538b 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -731,14 +731,11 @@ async def _get_a2a_runner_async() -> Runner: adk_app._tmpl_attrs["location"] = None adk_app._tmpl_attrs["api_key"] = api_key else: - project_id = google.auth.default() + _, project_id = google.auth.default() location = os.environ.get( "GOOGLE_CLOUD_AGENT_ENGINE_LOCATION", os.environ.get("GOOGLE_CLOUD_LOCATION", None), ) - logging.warning( - "[fast_api] project_id: %s, location: %s", project_id, location - ) adk_app._tmpl_attrs["project"] = project_id adk_app._tmpl_attrs["location"] = location adk_app._tmpl_attrs["api_key"] = None diff --git a/src/google/adk/telemetry/google_cloud.py b/src/google/adk/telemetry/google_cloud.py index 265ef67173..edfc63b367 100644 --- a/src/google/adk/telemetry/google_cloud.py +++ b/src/google/adk/telemetry/google_cloud.py @@ -78,6 +78,18 @@ def get_gcp_exporters( credentials, project_id = ( google_auth if google_auth is not None else google.auth.default() ) + if os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"): + # Try to convert project number to project ID to associate logs with traces. + try: + from google.cloud import resourcemanager + + projects_client = resourcemanager.ProjectsClient(credentials=credentials) + project = projects_client.get_project(name=f"projects/{project_id}") + project_id = project.project_id + except Exception: + logging.warning( + "Failed to convert project number to project ID.", exc_info=True + ) if TYPE_CHECKING: credentials = cast(Credentials, credentials) project_id = cast(str, project_id) From d8f9eca289cd32382e1cc6044b0e41ae7b8908e6 Mon Sep 17 00:00:00 2001 From: Yee Sian Ng Date: Mon, 1 Jun 2026 20:56:51 +0000 Subject: [PATCH 07/18] chore: Make cli_deploy.py work in e2e setting --- pyproject.toml | 203 ++++++++++++------ src/google/adk/cli/adk_web_server.py | 14 +- src/google/adk/cli/cli_deploy.py | 12 +- src/google/adk/cli/cli_tools_click.py | 31 ++- src/google/adk/cli/fast_api.py | 57 ++++- src/google/adk/telemetry/google_cloud.py | 89 +++++--- tests/unittests/cli/test_fast_api.py | 136 +----------- .../unittests/telemetry/test_google_cloud.py | 2 +- 8 files changed, 297 insertions(+), 247 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5ba4338903..670d98f7f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,70 +33,85 @@ classifiers = [ dynamic = [ "version" ] dependencies = [ - "aiosqlite>=0.21", # For SQLite database - "anyio>=4.9,<5", # For MCP Session Manager - "authlib>=1.6.6,<2", # For RestAPI Tool - "click>=8.1.8,<9", # For CLI tools - "fastapi>=0.124.1,<1", # FastAPI framework - "google-api-python-client>=2.157,<3", # Google API client discovery - "google-auth[pyopenssl]>=2.47", # Google Auth library - "google-cloud-aiplatform[agent-engines]>=1.148.1,<2", # For VertexAI integrations, e.g. example store. - "google-cloud-bigquery>=2.2", - "google-cloud-bigquery-storage>=2", - "google-cloud-bigtable>=2.32", # For Bigtable database - "google-cloud-dataplex>=1.7,<3", # For Dataplex Catalog Search tool - "google-cloud-discoveryengine>=0.13.12,<0.14", # For Discovery Engine Search Tool - "google-cloud-pubsub>=2,<3", # For Pub/Sub Tool - "google-cloud-secret-manager>=2.22,<3", # Fetching secrets in RestAPI Tool - "google-cloud-spanner>=3.56,<4", # For Spanner database - "google-cloud-speech>=2.30,<3", # For Audio Transcription - "google-cloud-storage>=2.18,<4", # For GCS Artifact service - "google-genai>=1.72,<2", # Google GenAI SDK - "graphviz>=0.20.2,<1", # Graphviz for graph rendering - "httpx>=0.27,<1", # HTTP client library - "jsonschema>=4.23,<5", # Agent Builder config validation - "mcp>=1.24,<2", # For MCP Toolset - "opentelemetry-api>=1.36,<=1.41.1", # OpenTelemetry - allow 1.39+ up to latest published at time of resolution. - "opentelemetry-exporter-gcp-logging>=1.9.0a0,<=1.12.0a0", - "opentelemetry-exporter-gcp-monitoring>=1.9.0a0,<2", - "opentelemetry-exporter-gcp-trace>=1.9,<2", - "opentelemetry-exporter-otlp-proto-http>=1.36", - "opentelemetry-resourcedetector-gcp>=1.9.0a0,<2", + "aiosqlite>=0.21", + "authlib>=1.6.6,<2", + "click>=8.1.8,<9", + "fastapi>=0.124.1,<1", + "google-auth[pyopenssl]>=2.47", + "google-genai>=1.72,<2", + "graphviz>=0.20.2,<1", + "httpx>=0.27,<1", + "jsonschema>=4.23,<5", + "opentelemetry-api>=1.36,<=1.41.1", "opentelemetry-sdk>=1.36,<=1.41.1", - "pyarrow>=14", - "pydantic>=2.12,<3", # For data validation/models - "python-dateutil>=2.9.0.post0,<3", # For Vertext AI Session Service - "python-dotenv>=1,<2", # To manage environment variables - "pyyaml>=6.0.2,<7", # For APIHubToolset. + "packaging>=21", + "pydantic>=2.12,<3", + "python-dotenv>=1,<2", + "python-multipart>=0.0.9,<1", + # go/keep-sorted start + "pyyaml>=6.0.2,<7", "requests>=2.32.4,<3", - "sqlalchemy>=2,<3", # SQL database ORM - "sqlalchemy-spanner>=1.14", # Spanner database session service - "starlette>=0.49.1,<1", # For FastAPI CLI - "tenacity>=9,<10", # For Retry management + "starlette>=0.49.1,<1", + "tenacity>=9,<10", "typing-extensions>=4.5,<5", - "tzlocal>=5.3,<6", # Time zone utilities - "uvicorn>=0.34,<1", # ASGI server for FastAPI - "watchdog>=6,<7", # For file change detection and hot reload - "websockets>=15.0.1,<16", # For BaseLlmFlow - "yarl<1.24", + "tzlocal>=5.3,<6", + "uvicorn>=0.34,<1", + "watchdog>=6,<7", + "websockets>=15.0.1,<16", + # go/keep-sorted end ] + optional-dependencies.a2a = [ "a2a-sdk>=0.3.4,<0.4", ] optional-dependencies.agent-identity = [ "google-cloud-iamconnectorcredentials>=0.1,<0.2", ] +optional-dependencies.all = [ + "anyio>=4.9,<5", + "google-api-python-client>=2.157,<3", + "google-cloud-aiplatform[agent-engines]>=1.148.1,<2", + "google-cloud-bigquery>=2.2", + "google-cloud-bigquery-storage>=2", + "google-cloud-bigtable>=2.32", + "google-cloud-dataplex>=1.7,<3", + "google-cloud-discoveryengine>=0.13.12,<0.14", + "google-cloud-pubsub>=2,<3", + "google-cloud-resource-manager>=1.12,<2", + "google-cloud-secret-manager>=2.22,<3", + "google-cloud-spanner>=3.56,<4", + "google-cloud-speech>=2.30,<3", + "google-cloud-storage>=2.18,<4", + "mcp>=1.24,<2", + "opentelemetry-exporter-gcp-logging>=1.9.0a0,<=1.12.0a0", + "opentelemetry-exporter-gcp-monitoring>=1.9.0a0,<2", + "opentelemetry-exporter-gcp-trace>=1.9,<2", + "opentelemetry-exporter-otlp-proto-http>=1.36", + "opentelemetry-resourcedetector-gcp>=1.9.0a0,<2", + "pyarrow>=14", + "python-dateutil>=2.9.0.post0,<3", + "sqlalchemy>=2,<3", + "sqlalchemy-spanner>=1.14", +] + optional-dependencies.community = [ "google-adk-community", ] +optional-dependencies.db = [ + "sqlalchemy>=2,<3", + "sqlalchemy-spanner>=1.14", +] + optional-dependencies.dev = [ "flit>=3.10", - "isort>=6", "mypy>=1.15", "pre-commit>=4", "pyink>=25.12", "pylint>=2.6", + "tox>=4.23.2", + "tox-uv>=1.33.2", ] + optional-dependencies.docs = [ "autodoc-pydantic", "furo", @@ -106,6 +121,7 @@ optional-dependencies.docs = [ "sphinx-click", "sphinx-rtd-theme", ] + optional-dependencies.eval = [ "gepa>=0.1", "google-cloud-aiplatform[evaluation]>=1.148", @@ -114,7 +130,7 @@ optional-dependencies.eval = [ "rouge-score>=0.1.2", "tabulate>=0.9", ] -# Optional extensions + optional-dependencies.extensions = [ "anthropic>=0.78", # For anthropic model support; 0.78 introduced ThinkingConfigAdaptiveParam (required for Claude Opus 4.7). "beautifulsoup4>=3.2.2", # For load_web_page tool. @@ -122,16 +138,44 @@ optional-dependencies.extensions = [ "docker>=7", # For ContainerCodeExecutor "google-cloud-firestore>=2.11,<3", # For Firestore services "google-cloud-parametermanager>=0.4,<1", - "k8s-agent-sandbox>=0.1.1.post3", # For GkeCodeExecutor sandbox mode - "kubernetes>=29", # For GkeCodeExecutor - "langgraph>=0.2.60,<0.4.8", # For LangGraphAgent - "litellm>=1.83.7,<=1.83.14", # For LiteLlm class. Lower bound: 5 CVE patches (2026-04). Upper bound pinned to current latest; bump deliberately. See #5488. - "llama-index-embeddings-google-genai>=0.3", # For files retrieval using LlamaIndex. - "llama-index-readers-file>=0.4", # For retrieval using LlamaIndex. - "lxml>=5.3", # For load_web_page tool. - "pypika>=0.50", # For crewai->chromadb dependency - "toolbox-adk>=1,<2", # For tools.toolbox_toolset.ToolboxToolset + "k8s-agent-sandbox>=0.1.1.post3", + "kubernetes>=29", + "langgraph>=0.2.60,<0.4.8", + "litellm>=1.83.7,<=1.83.14", + "llama-index-embeddings-google-genai>=0.3", + "llama-index-readers-file>=0.4", + "lxml>=5.3", + "pypika>=0.50", + "toolbox-adk>=1,<2", ] + +optional-dependencies.gcp = [ + "google-cloud-aiplatform[agent-engines]>=1.148.1,<2", + "google-cloud-bigquery>=2.2", + "google-cloud-bigquery-storage>=2", + "google-cloud-bigtable>=2.32", + "google-cloud-dataplex>=1.7,<3", + "google-cloud-discoveryengine>=0.13.12,<0.14", + "google-cloud-pubsub>=2,<3", + "google-cloud-resource-manager>=1.12,<2", + "google-cloud-secret-manager>=2.22,<3", + "google-cloud-spanner>=3.56,<4", + "google-cloud-speech>=2.30,<3", + "google-cloud-storage>=2.18,<4", + "opentelemetry-exporter-gcp-logging>=1.9.0a0,<=1.12.0a0", + "opentelemetry-exporter-gcp-monitoring>=1.9.0a0,<2", + "opentelemetry-exporter-gcp-trace>=1.9,<2", + "opentelemetry-exporter-otlp-proto-http>=1.36", + "opentelemetry-resourcedetector-gcp>=1.9.0a0,<2", + "pyarrow>=14", + "python-dateutil>=2.9.0.post0,<3", +] + +optional-dependencies.mcp = [ + "anyio>=4.9,<5", + "mcp>=1.24,<2", +] + optional-dependencies.otel-gcp = [ "opentelemetry-instrumentation-google-genai>=0.6b0,<1", "opentelemetry-instrumentation-grpc>=0.43b0,<1", @@ -141,28 +185,62 @@ optional-dependencies.slack = [ "slack-bolt>=1.22" ] optional-dependencies.test = [ "a2a-sdk>=0.3,<0.4", "anthropic>=0.78", # For anthropic model tests; 0.78 introduced ThinkingConfigAdaptiveParam (required for Claude Opus 4.7). + "anyio>=4.9,<5", "crewai[tools]; python_version>='3.11' and python_version<'3.12'", # For CrewaiTool tests; chromadb/pypika fail on 3.12+ + "gepa>=0.1", + "google-api-python-client>=2.157,<3", + "google-cloud-aiplatform[agent-engines,evaluation]>=1.148.1,<2", + "google-cloud-bigquery>=2.2", + "google-cloud-bigquery-storage>=2", + "google-cloud-bigtable>=2.32", + "google-cloud-dataplex>=1.7,<3", + "google-cloud-discoveryengine>=0.13.12,<0.14", "google-cloud-firestore>=2.11,<3", "google-cloud-iamconnectorcredentials>=0.1,<0.2", "google-cloud-parametermanager>=0.4,<1", - "kubernetes>=29", # For GkeCodeExecutor + "google-cloud-pubsub>=2,<3", + "google-cloud-resource-manager>=1.12,<2", + "google-cloud-secret-manager>=2.22,<3", + "google-cloud-spanner>=3.56,<4", + "google-cloud-speech>=2.30,<3", + "google-cloud-storage>=2.18,<4", + "jinja2>=3.1.4,<4", + "kubernetes>=29", "langchain-community>=0.3.17", - "langgraph>=0.2.60,<0.4.8", # For LangGraphAgent - "litellm>=1.83.7,<=1.83.14", # For LiteLLM tests. Lower bound: 5 CVE patches (2026-04). Upper bound pinned to current latest; bump deliberately. See #5488. - "llama-index-readers-file>=0.4", # For retrieval tests - "openai>=1.100.2", # For LiteLLM + "langgraph>=0.2.60,<0.4.8", + "litellm>=1.83.7,<=1.83.14", + "llama-index-readers-file>=0.4", + "mcp>=1.24,<2", + "openai>=1.100.2", + "opentelemetry-exporter-gcp-logging>=1.9.0a0,<=1.12.0a0", + "opentelemetry-exporter-gcp-monitoring>=1.9.0a0,<2", + "opentelemetry-exporter-gcp-trace>=1.9,<2", + "opentelemetry-exporter-otlp-proto-http>=1.36", "opentelemetry-instrumentation-google-genai>=0.3b0,<1", - "pypika>=0.50", # For crewai->chromadb dependency + "opentelemetry-resourcedetector-gcp>=1.9.0a0,<2", + "pandas>=2.2.3", + "pyarrow>=14", + "pypika>=0.50", "pytest>=9,<10", "pytest-asyncio>=0.25", "pytest-mock>=3.14", "pytest-xdist>=3.6.1", + "python-dateutil>=2.9.0.post0,<3", "python-multipart>=0.0.9", "rouge-score>=0.1.2", "slack-bolt>=1.22", + "sqlalchemy>=2,<3", + "sqlalchemy-spanner>=1.14", "tabulate>=0.9", + "tomli>=2,<3; python_version<'3.11'", ] + optional-dependencies.toolbox = [ "toolbox-adk>=1,<2" ] + +optional-dependencies.tools = [ + "google-api-python-client>=2.157,<3", +] + urls.changelog = "https://github.com/google/adk-python/blob/main/CHANGELOG.md" urls.documentation = "https://google.github.io/adk-docs/" urls.homepage = "https://google.github.io/adk-docs/" @@ -180,7 +258,7 @@ include = [ "py.typed" ] [tool.isort] profile = "google" single_line_exclusions = [ ] -line_length = 200 # Prevent line wrap flickering. +line_length = 200 known_third_party = [ "google.adk", "a2a" ] [tool.pytest.ini_options] @@ -189,7 +267,7 @@ asyncio_default_fixture_loop_scope = "function" asyncio_mode = "auto" [tool.mypy] -python_version = "3.10" +python_version = "3.11" exclude = [ "tests/", "contributing/samples/" ] plugins = [ "pydantic.mypy" ] strict = true @@ -197,7 +275,6 @@ disable_error_code = [ "import-not-found", "import-untyped", "unused-ignore" ] follow_imports = "skip" [tool.pyink] -# Format py files following Google style-guide line-length = 80 unstable = true pyink-indentation = 2 diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py index bced893407..5b4cdc2233 100644 --- a/src/google/adk/cli/adk_web_server.py +++ b/src/google/adk/cli/adk_web_server.py @@ -369,6 +369,7 @@ class RunAgentRequest(common.BaseModel): function_call_event_id: Optional[str] = None # for resume long-running functions invocation_id: Optional[str] = None + custom_metadata: Optional[dict[str, Any]] = None class CreateSessionRequest(common.BaseModel): @@ -862,6 +863,7 @@ def _setup_runtime_config(self, web_assets_dir: str): os.makedirs(os.path.dirname(runtime_config_path), exist_ok=True) with open(runtime_config_path, "w") as f: json.dump(runtime_config, f, indent=2) + f.write("\n") except IOError as e: logger.error( "Failed to write runtime config file %s: %s", runtime_config_path, e @@ -1910,6 +1912,11 @@ async def run_agent(req: RunAgentRequest) -> list[Event]: self.current_app_name_ref.value = req.app_name runner = await self.get_runner_async(req.app_name) _set_telemetry_context_if_needed(runner) + run_config = ( + RunConfig(custom_metadata=req.custom_metadata) + if req.custom_metadata + else None + ) try: async with Aclosing( runner.run_async( @@ -1918,6 +1925,7 @@ async def run_agent(req: RunAgentRequest) -> list[Event]: new_message=req.new_message, state_delta=req.state_delta, invocation_id=req.invocation_id, + run_config=run_config, ) ) as agen: events = [event async for event in agen] @@ -1954,13 +1962,17 @@ async def run_agent_sse(req: RunAgentRequest) -> StreamingResponse: # Convert the events to properly formatted SSE async def event_generator(): + run_config = RunConfig( + streaming_mode=stream_mode, + custom_metadata=req.custom_metadata, + ) async with Aclosing( runner.run_async( user_id=req.user_id, session_id=req.session_id, new_message=req.new_message, state_delta=req.state_delta, - run_config=RunConfig(streaming_mode=stream_mode), + run_config=run_config, invocation_id=req.invocation_id, ) ) as agen: diff --git a/src/google/adk/cli/cli_deploy.py b/src/google/adk/cli/cli_deploy.py index c92969ed8a..5f3ccb8148 100644 --- a/src/google/adk/cli/cli_deploy.py +++ b/src/google/adk/cli/cli_deploy.py @@ -831,6 +831,8 @@ def to_agent_engine( agent_engine_config_file: Optional[str] = None, skip_agent_import_validation: bool = True, trigger_sources: Optional[str] = None, + memory_service_uri: Optional[str] = None, + session_service_uri: Optional[str] = None, artifact_service_uri: Optional[str] = None, adk_version: Optional[str] = None, ): @@ -887,6 +889,12 @@ def to_agent_engine( trigger_sources (str): Optional. Comma-separated list of trigger sources to enable (e.g., 'pubsub,eventarc'). Registers /trigger/* endpoints for batch and event-driven agent invocations. + memory_service_uri (str): Optional. The URI of the memory service. If not + specified, the memory service will be deployed to the same parent resource + as the runtime. + session_service_uri (str): Optional. The URI of the session service. If not + specified, the session service will be deployed to the same parent + resource as the runtime. artifact_service_uri (str): Optional. The URI of the artifact service. adk_version (str): Optional. The ADK version to use in Agent Engine deployment. If not specified, the version in the dev environment will be @@ -1127,9 +1135,9 @@ def create_dockerfile_for_agent_engine(resource_name: str): install_agent_deps=install_agent_deps, service_option=_get_service_option_by_adk_version( adk_version, - agent_engine_uri, # session_service_uri + session_service_uri or agent_engine_uri, artifact_service_uri, - agent_engine_uri, # memory_service_uri + memory_service_uri or agent_engine_uri, False, # use_local_storage ), trace_to_cloud_option='--trace_to_cloud' if trace_to_cloud else '', diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 4009cf6b02..5dd5c2a5bc 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -571,8 +571,11 @@ def decorator(func): "--memory_service_uri", type=str, help=textwrap.dedent("""\ - \b Optional. The URI of the memory service. + If set, ADK uses this service. + + \b + If unset, ADK chooses a default memory service. - Use 'rag://' to connect to Vertex AI Rag Memory Service. - Use 'agentengine://' to connect to Agent Engine sessions. can either be the full qualified resource @@ -1740,6 +1743,12 @@ def cli_api_server( adk api_server --session_service_uri=[uri] --port=[port] path/to/agents_dir """ + if express_mode and not gemini_enterprise_app_name: + raise click.UsageError( + "--express_mode is only supported when --gemini_enterprise_app_name is" + " set." + ) + logs.setup_adk_logger(getattr(logging, log_level.upper())) config = uvicorn.Config( @@ -2230,21 +2239,7 @@ def cli_migrate_session( " the version in the dev environment)" ), ) -@click.option( - "--artifact_service_uri", - type=str, - help=textwrap.dedent( - """\ - Optional. The URI of the artifact service. If set, ADK uses this service. - - \b - If unset, ADK chooses a default artifact service. - - Use 'gs://' to connect to the GCS artifact service. - - Use 'memory://' to force the in-memory artifact service. - - Use 'file://' to store artifacts in a custom local directory.""" - ), - default=None, -) +@adk_services_options(default_use_local_storage=False) @click.argument( "agent", type=click.Path( @@ -2274,6 +2269,8 @@ def cli_deploy_agent_engine( adk_version: str | None = None, trigger_sources: str | None = None, artifact_service_uri: str | None = None, + memory_service_uri: str | None = None, + session_service_uri: str | None = None, ): """Deploys an agent to Agent Engine. @@ -2315,6 +2312,8 @@ def cli_deploy_agent_engine( skip_agent_import_validation=not validate_agent_import, trigger_sources=trigger_sources, artifact_service_uri=artifact_service_uri, + memory_service_uri=memory_service_uri, + session_service_uri=session_service_uri, adk_version=adk_version, ) except Exception as e: diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 4648e0538b..37aafbee1e 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -51,10 +51,11 @@ from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager from ..runners import Runner -from .adk_web_server import AdkWebServer -from .service_registry import load_services_module from ..telemetry._agent_engine import get_propagated_context from ..telemetry._agent_engine import TopSpanProcessor +from .adk_web_server import AdkWebServer +from .cli_deploy import _AGENT_ENGINE_CLASS_METHODS +from .service_registry import load_services_module from .utils import envs from .utils import evals from .utils.agent_change_handler import AgentChangeEventHandler @@ -65,6 +66,10 @@ from .utils.service_factory import create_memory_service_from_options from .utils.service_factory import create_session_service_from_options +_ALLOWED_AGENT_ENGINE_CLASS_METHODS = frozenset( + method["name"] for method in _AGENT_ENGINE_CLASS_METHODS +) + class _QueryRequest(BaseModel): input: dict[str, Any] | None = None @@ -169,7 +174,7 @@ def get_fast_api_app( event-driven agent invocations. None disables all trigger endpoints. gemini_enterprise_app_name: The app_name to register with Gemini Enterprise via https://docs.cloud.google.com/gemini/enterprise/docs/register-and-manage-an-adk-agent - express_mode: Whether or not to intialize the server in express mode. + express_mode: Whether or not to initialize the server in express mode. This is only supported when gemini_enterprise_app_name is set. Defaults to False. @@ -177,6 +182,11 @@ def get_fast_api_app( The configured FastAPI application instance. """ + if express_mode and not gemini_enterprise_app_name: + raise ValueError( + "express_mode is only supported when gemini_enterprise_app_name is set." + ) + # Enable denylist enforcement for config loads if web UI is enabled. if web: from ..agents import config_agent_utils @@ -252,6 +262,25 @@ def get_fast_api_app( # Callbacks & other optional args for when constructing the FastAPI instance extra_fast_api_args = {} + # Synchronize otel_to_cloud and GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY. + # This is to support toggling telemetry in the Agent Platform Console, which + # sets the environment variable GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY. + if otel_to_cloud: + os.environ["GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY"] = "true" + logger.info( + "Setting GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY to true based on" + " otel_to_cloud flag." + ) + elif gemini_enterprise_app_name and os.environ.get( + "GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY" + ): + logger.info( + "Setting otel_to_cloud to True for Gemini Enterprise app %s based on" + " GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY environment variable.", + gemini_enterprise_app_name, + ) + otel_to_cloud = True + # TODO - Remove separate trace_to_cloud logic once otel_to_cloud stops being # EXPERIMENTAL. if trace_to_cloud and not otel_to_cloud: @@ -718,8 +747,9 @@ async def _get_a2a_runner_async() -> Runner: import inspect import json - import google.auth + from google.adk.agents import Agent + import google.auth from vertexai import agent_engines # The tmp agent will be replaced by the adk server's runner and services. @@ -727,6 +757,10 @@ async def _get_a2a_runner_async() -> Runner: adk_app = agent_engines.AdkApp(agent=Agent(name="tmp")) if express_mode: api_key = os.environ.get("GOOGLE_API_KEY", None) + if not api_key: + raise ValueError( + "No GOOGLE_API_KEY found in environment variables for express mode." + ) adk_app._tmpl_attrs["project"] = None adk_app._tmpl_attrs["location"] = None adk_app._tmpl_attrs["api_key"] = api_key @@ -736,6 +770,11 @@ async def _get_a2a_runner_async() -> Runner: "GOOGLE_CLOUD_AGENT_ENGINE_LOCATION", os.environ.get("GOOGLE_CLOUD_LOCATION", None), ) + if not project_id or not location: + raise ValueError( + "No GOOGLE_CLOUD_PROJECT or GOOGLE_CLOUD_LOCATION found in" + " environment variables." + ) adk_app._tmpl_attrs["project"] = project_id adk_app._tmpl_attrs["location"] = location adk_app._tmpl_attrs["api_key"] = None @@ -813,6 +852,11 @@ async def query(request: _QueryRequest): raise HTTPException( status_code=400, detail="class_method cannot be None" ) + if request.class_method not in _ALLOWED_AGENT_ENGINE_CLASS_METHODS: + raise HTTPException( + status_code=400, + detail=f"class_method {request.class_method} is not allowed", + ) method = getattr(adk_app, request.class_method) output = await _invoke_callable_or_raise(method, request.input or {}) @@ -843,6 +887,11 @@ async def stream_query(request: _QueryRequest): raise HTTPException( status_code=400, detail="class_method cannot be None" ) + if request.class_method not in _ALLOWED_AGENT_ENGINE_CLASS_METHODS: + raise HTTPException( + status_code=400, + detail=f"class_method {request.class_method} is not allowed", + ) method = getattr(adk_app, request.class_method) output = await _invoke_callable_or_raise(method, request.input or {}) diff --git a/src/google/adk/telemetry/google_cloud.py b/src/google/adk/telemetry/google_cloud.py index edfc63b367..c936105cc7 100644 --- a/src/google/adk/telemetry/google_cloud.py +++ b/src/google/adk/telemetry/google_cloud.py @@ -43,21 +43,21 @@ logger = logging.getLogger("google_adk." + __name__) -_GCP_LOG_NAME_ENV_VARIABLE_NAME = 'GOOGLE_CLOUD_DEFAULT_LOG_NAME' -_DEFAULT_LOG_NAME = 'adk-otel' +_GCP_LOG_NAME_ENV_VARIABLE_NAME = "GOOGLE_CLOUD_DEFAULT_LOG_NAME" +_DEFAULT_LOG_NAME = "adk-otel" -_DEFAULT_TELEMETRY_TRACES_ENPOINT = 'https://telemetry.googleapis.com/v1/traces' +_DEFAULT_TELEMETRY_TRACES_ENPOINT = "https://telemetry.googleapis.com/v1/traces" _DEFAULT_MTLS_TELEMETRY_TRACES_ENPOINT = ( - 'https://telemetry.mtls.googleapis.com/v1/traces' + "https://telemetry.mtls.googleapis.com/v1/traces" ) class _MtlsEndpoint(enum.Enum): """The mTLS endpoint setting.""" - AUTO = 'auto' - ALWAYS = 'always' - NEVER = 'never' + AUTO = "auto" + ALWAYS = "always" + NEVER = "never" def get_gcp_exporters( @@ -88,7 +88,10 @@ def get_gcp_exporters( project_id = project.project_id except Exception: logging.warning( - "Failed to convert project number to project ID.", exc_info=True + "Failed to convert project number to project ID. Your traces and logs" + " may not be associated. To fix this, consider enabling the resource" + " manager API and redeploying your agent.", + exc_info=True, ) if TYPE_CHECKING: credentials = cast(Credentials, credentials) @@ -151,6 +154,7 @@ def _get_gcp_span_exporter(credentials: Credentials) -> SpanProcessor: headers = None if os.getenv("GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY"): from google.cloud.aiplatform import version as aip_version + try: from opentelemetry.exporter.otlp.proto.http import version as otlp_http_version except (ImportError, AttributeError): @@ -158,7 +162,9 @@ def _get_gcp_span_exporter(credentials: Credentials) -> SpanProcessor: user_agent = f"Vertex-Agent-Engine/{aip_version.__version__}" if otlp_http_version: - user_agent += f" OTel-OTLP-Exporter-Python/{otlp_http_version.__version__}" + user_agent += ( + f" OTel-OTLP-Exporter-Python/{otlp_http_version.__version__}" + ) headers = {"User-Agent": user_agent} return BatchSpanProcessor( @@ -201,6 +207,20 @@ def _get_gcp_logs_exporter( ) +def _detect_cloud_resource_id(project_id: str) -> Optional[str]: + """Detects the cloud resource ID.""" + location = os.getenv("GOOGLE_CLOUD_AGENT_ENGINE_LOCATION") or os.getenv( + "GOOGLE_CLOUD_LOCATION" + ) + agent_engine_id = os.getenv("GOOGLE_CLOUD_AGENT_ENGINE_ID") + if project_id and location and agent_engine_id: + return ( + f"//aiplatform.googleapis.com/projects/{project_id}" + f"/locations/{location}/reasoningEngines/{agent_engine_id}" + ) + return None + + def get_gcp_resource(project_id: Optional[str] = None) -> Resource: """Returns OTEL with attributes specified in the following order (attributes specified later, overwrite those specified earlier): 1. Populates gcp.project_id attribute from the project_id argument if present. @@ -212,23 +232,28 @@ def get_gcp_resource(project_id: Optional[str] = None) -> Resource: This may be overwritten by OTELResourceDetector, if `gcp.project_id` is present in `OTEL_RESOURCE_ATTRIBUTES` env var. """ agent_engine_id = os.getenv("GOOGLE_CLOUD_AGENT_ENGINE_ID", "") + cloud_resource_id = _detect_cloud_resource_id(project_id=project_id) + resource_attributes = { + "gcp.project_id": project_id, + "cloud.account.id": project_id, + "cloud.provider": "gcp", + "cloud.platform": "gcp.agent_engine", + "service.name": agent_engine_id, + "service.version": os.getenv( + "GOOGLE_CLOUD_AGENT_ENGINE_RUNTIME_REVISION_ID", "" + ), + "service.instance.id": f"{uuid.uuid4().hex}-{os.getpid()}", + "cloud.region": ( + os.getenv("GOOGLE_CLOUD_AGENT_ENGINE_LOCATION", "") + or os.getenv("GOOGLE_CLOUD_LOCATION", "") + ), + } + if cloud_resource_id is not None: + resource_attributes["cloud.resource.id"] = cloud_resource_id + if agent_engine_id: resource = Resource.create( - attributes={ - "gcp.project_id": project_id, - "cloud.account.id": project_id, - "cloud.provider": "gcp", - "cloud.platform": "gcp.agent_engine", - "service.name": agent_engine_id, - "service.version": os.getenv( - "GOOGLE_CLOUD_AGENT_ENGINE_RUNTIME_REVISION_ID", "" - ), - "service.instance.id": f"{uuid.uuid4().hex}-{os.getpid()}", - "cloud.region": ( - os.getenv("GOOGLE_CLOUD_AGENT_ENGINE_LOCATION", "") - or os.getenv("GOOGLE_CLOUD_LOCATION", "") - ), - } + attributes=resource_attributes ).merge(OTELResourceDetector().detect()) return resource @@ -265,15 +290,15 @@ def _get_api_endpoint( str: The API endpoint to be used. """ use_mtls_endpoint_str = os.getenv( - 'GOOGLE_API_USE_MTLS_ENDPOINT', _MtlsEndpoint.AUTO.value + "GOOGLE_API_USE_MTLS_ENDPOINT", _MtlsEndpoint.AUTO.value ).lower() try: use_mtls_endpoint = _MtlsEndpoint(use_mtls_endpoint_str) except ValueError: logger.warning( - 'Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be one of ' - '%s. Defaulting to %s.', + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be one of " + "%s. Defaulting to %s.", [e.value for e in _MtlsEndpoint], _MtlsEndpoint.AUTO.value, ) @@ -301,14 +326,14 @@ def _use_client_cert_effective() -> bool: return bool(mtls.should_use_client_cert()) except (ImportError, AttributeError): use_client_cert_str = os.getenv( - 'GOOGLE_API_USE_CLIENT_CERTIFICATE', 'false' + "GOOGLE_API_USE_CLIENT_CERTIFICATE", "false" ).lower() - if use_client_cert_str not in ('true', 'false'): + if use_client_cert_str not in ("true", "false"): logger.warning( - 'Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be' - ' either `true` or `false`' + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be" + " either `true` or `false`" ) - return use_client_cert_str == 'true' + return use_client_cert_str == "true" def _get_agent_engine_logs_exporter( diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index 023d3c78d0..b7911dc1f1 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -911,17 +911,17 @@ def test_app_with_gemini_enterprise( mock_adk_app_instance = MagicMock() mock_adk_app_instance._tmpl_attrs = {} - async def my_method_impl(**kwargs): + async def get_session_impl(**kwargs): return {"result": "success", "kwargs": kwargs} - mock_adk_app_instance.my_method = my_method_impl + mock_adk_app_instance.get_session = get_session_impl - async def my_stream_method_impl(**kwargs): + async def stream_query_impl(**kwargs): yield {"chunk": 1, "kwargs": kwargs} await asyncio.sleep(0) yield {"chunk": 2, "kwargs": kwargs} - mock_adk_app_instance.my_stream_method = my_stream_method_impl + mock_adk_app_instance.stream_query = stream_query_impl with ( patch("vertexai.init", new_callable=MagicMock) as mock_vertexai_init, @@ -930,11 +930,11 @@ async def my_stream_method_impl(**kwargs): ) as mock_adk_app_cls, patch("google.adk.agents.Agent", new_callable=MagicMock), patch( - "google.adk.cli.utils._telemetry.TopSpanProcessor", + "google.adk.telemetry._agent_engine.TopSpanProcessor", new_callable=MagicMock, ), patch( - "google.adk.cli.utils._telemetry.get_propagated_context", + "google.adk.telemetry._agent_engine.get_propagated_context", new_callable=MagicMock, ), ): @@ -2769,131 +2769,11 @@ def test_gemini_app_not_found_raises( ) -def test_gemini_missing_credentials_raises( - mock_session_service, - mock_artifact_service, - mock_memory_service, - mock_agent_loader, - mock_eval_sets_manager, - mock_eval_set_results_manager, - monkeypatch, -): - """Test get_fast_api_app raises ValueError if no credentials are provided.""" - monkeypatch.delenv("GOOGLE_CLOUD_PROJECT", raising=False) - monkeypatch.delenv("GOOGLE_API_KEY", raising=False) - mock_agent_loader.list_agents = MagicMock(return_value=["gemini_app"]) - with pytest.raises( - ValueError, match="No GOOGLE_CLOUD_PROJECT or GOOGLE_API_KEY" - ): - with ( - patch("vertexai.init"), - patch("vertexai.agent_engines.AdkApp"), - patch("google.adk.agents.Agent"), - patch( - "google.adk.cli.utils._telemetry.TopSpanProcessor", - new_callable=MagicMock, - ), - patch( - "google.adk.cli.utils._telemetry.get_propagated_context", - new_callable=MagicMock, - ), - ): - _create_test_client( - mock_session_service, - mock_artifact_service, - mock_memory_service, - mock_agent_loader, - mock_eval_sets_manager, - mock_eval_set_results_manager, - gemini_enterprise_app_name="gemini_app", - ) - - -def test_gemini_init_with_project_id( - mock_session_service, - mock_artifact_service, - mock_memory_service, - mock_agent_loader, - mock_eval_sets_manager, - mock_eval_set_results_manager, - monkeypatch, -): - """Test vertexai.init is called with project_id.""" - monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "test-project") - monkeypatch.setenv("GOOGLE_CLOUD_LOCATION", "test-location") - monkeypatch.delenv("GOOGLE_API_KEY", raising=False) - mock_agent_loader.list_agents = MagicMock(return_value=["gemini_app"]) - with ( - patch("vertexai.init") as mock_init, - patch("vertexai.agent_engines.AdkApp"), - patch("google.adk.agents.Agent"), - patch( - "google.adk.cli.utils._telemetry.TopSpanProcessor", - new_callable=MagicMock, - ), - patch( - "google.adk.cli.utils._telemetry.get_propagated_context", - new_callable=MagicMock, - ), - ): - _create_test_client( - mock_session_service, - mock_artifact_service, - mock_memory_service, - mock_agent_loader, - mock_eval_sets_manager, - mock_eval_set_results_manager, - gemini_enterprise_app_name="gemini_app", - ) - mock_init.assert_called_once_with( - project="test-project", - location="test-location", - ) - - -def test_gemini_init_with_api_key( - mock_session_service, - mock_artifact_service, - mock_memory_service, - mock_agent_loader, - mock_eval_sets_manager, - mock_eval_set_results_manager, - monkeypatch, -): - """Test vertexai.init is called with api_key.""" - monkeypatch.delenv("GOOGLE_CLOUD_PROJECT", raising=False) - monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") - mock_agent_loader.list_agents = MagicMock(return_value=["gemini_app"]) - with ( - patch("vertexai.init") as mock_init, - patch("vertexai.agent_engines.AdkApp"), - patch("google.adk.agents.Agent"), - patch( - "google.adk.cli.utils._telemetry.TopSpanProcessor", - new_callable=MagicMock, - ), - patch( - "google.adk.cli.utils._telemetry.get_propagated_context", - new_callable=MagicMock, - ), - ): - _create_test_client( - mock_session_service, - mock_artifact_service, - mock_memory_service, - mock_agent_loader, - mock_eval_sets_manager, - mock_eval_set_results_manager, - gemini_enterprise_app_name="gemini_app", - ) - mock_init.assert_called_once_with(api_key="test-api-key") - - def test_gemini_reasoning_engine_success(test_app_with_gemini_enterprise): """Test POST /api/reasoning_engine success case.""" response = test_app_with_gemini_enterprise.post( "/api/reasoning_engine", - json={"class_method": "my_method", "input": {"arg1": 1}}, + json={"class_method": "get_session", "input": {"arg1": 1}}, ) assert response.status_code == 200 assert response.json() == { @@ -2918,7 +2798,7 @@ def test_gemini_stream_reasoning_engine_success( """Test POST /api/stream_reasoning_engine success case.""" response = test_app_with_gemini_enterprise.post( "/api/stream_reasoning_engine", - json={"class_method": "my_stream_method", "input": {"arg1": 1}}, + json={"class_method": "stream_query", "input": {"arg1": 1}}, ) assert response.status_code == 200 lines = response.text.strip().split("\n") diff --git a/tests/unittests/telemetry/test_google_cloud.py b/tests/unittests/telemetry/test_google_cloud.py index 693365de39..c75397449d 100644 --- a/tests/unittests/telemetry/test_google_cloud.py +++ b/tests/unittests/telemetry/test_google_cloud.py @@ -203,5 +203,5 @@ def test_get_gcp_span_exporter_mtls( mock_exporter.assert_called_once_with( session=mock_session.return_value, endpoint=_DEFAULT_MTLS_TELEMETRY_TRACES_ENPOINT, - header=None, + headers=None, ) From 9b1a4bb838ee85d7fc4ad96ddbfb993dadd1dc24 Mon Sep 17 00:00:00 2001 From: Yee Sian Ng Date: Mon, 1 Jun 2026 21:15:30 +0000 Subject: [PATCH 08/18] chore: add required keyword argument --- src/google/adk/cli/cli_tools_click.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 5dd5c2a5bc..4307866279 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -2271,6 +2271,7 @@ def cli_deploy_agent_engine( artifact_service_uri: str | None = None, memory_service_uri: str | None = None, session_service_uri: str | None = None, + use_local_storage: bool = False, ): """Deploys an agent to Agent Engine. From 1432950275bba0ed8fa85fa49ae56765fba6d74c Mon Sep 17 00:00:00 2001 From: Yee Sian Ng Date: Tue, 2 Jun 2026 21:17:00 +0000 Subject: [PATCH 09/18] chore: Update Dockerfile for e2e testing --- pyproject.toml | 6 +- src/google/adk/cli/api_server.py | 100 ++- src/google/adk/cli/cli_deploy.py | 17 +- src/google/adk/cli/cli_tools_click.py | 264 ++++++- src/google/adk/cli/fast_api.py | 695 +++++++++--------- tests/unittests/cli/test_fast_api.py | 554 +++++++++++--- tests/unittests/cli/utils/test_cli_deploy.py | 4 +- .../unittests/telemetry/test_google_cloud.py | 12 + 8 files changed, 1164 insertions(+), 488 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 670d98f7f1..fe33e65229 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,9 +36,9 @@ dependencies = [ "aiosqlite>=0.21", "authlib>=1.6.6,<2", "click>=8.1.8,<9", - "fastapi>=0.124.1,<1", + "fastapi>=0.133,<1", "google-auth[pyopenssl]>=2.47", - "google-genai>=1.72,<2", + "google-genai>=2.4,<3", "graphviz>=0.20.2,<1", "httpx>=0.27,<1", "jsonschema>=4.23,<5", @@ -51,7 +51,7 @@ dependencies = [ # go/keep-sorted start "pyyaml>=6.0.2,<7", "requests>=2.32.4,<3", - "starlette>=0.49.1,<1", + "starlette>=1.0.1,<2", "tenacity>=9,<10", "typing-extensions>=4.5,<5", "tzlocal>=5.3,<6", diff --git a/src/google/adk/cli/api_server.py b/src/google/adk/cli/api_server.py index 3392ecb64f..faf67e47db 100644 --- a/src/google/adk/cli/api_server.py +++ b/src/google/adk/cli/api_server.py @@ -26,7 +26,6 @@ import os import re import sys -import time import traceback import typing from typing import Any @@ -38,6 +37,7 @@ from fastapi import FastAPI from fastapi import HTTPException from fastapi import Query +from fastapi import Request from fastapi import Response from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import RedirectResponse @@ -384,6 +384,7 @@ class RunAgentRequest(common.BaseModel): function_call_event_id: Optional[str] = None # for resume long-running functions invocation_id: Optional[str] = None + custom_metadata: Optional[dict[str, Any]] = None class CreateSessionRequest(common.BaseModel): @@ -546,6 +547,26 @@ def _setup_instrumentation_lib_if_installed(): "Unable to import GoogleGenAiSdkInstrumentor - some" " telemetry will be disabled. Make sure to install google-adk[otel-gcp]" ) + if os.getenv("GOOGLE_CLOUD_AGENT_ENGINE_ID"): + # Set up HTTPX and gRPC instrumentation for A2A multi-agent observability. + try: + from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor + + HTTPXClientInstrumentor().instrument() + except (ImportError, AttributeError): + logger.warning( + "telemetry enabled but proceeding without HTTPX instrumentation," + " because google-adk[otel-gcp] has not been installed" + ) + try: + from opentelemetry.instrumentation.grpc import GrpcInstrumentorClient + + GrpcInstrumentorClient().instrument() + except (ImportError, AttributeError): + logger.warning( + "telemetry enabled but proceeding without gRPC instrumentation," + " because google-adk[otel-gcp] has not been installed" + ) class ApiServer: @@ -803,6 +824,7 @@ def _setup_runtime_config(self, web_assets_dir: str): os.makedirs(os.path.dirname(runtime_config_path), exist_ok=True) with open(runtime_config_path, "w") as f: json.dump(runtime_config, f, indent=2) + f.write("\n") except IOError as e: logger.error( "Failed to write runtime config file %s: %s", runtime_config_path, e @@ -1043,7 +1065,7 @@ async def get_adk_app_info(app_name: str) -> AppInfo: root_agent_name=root_agent.name, description=root_agent.description, language="python", - agents=get_agents_dict(root_agent), + agents=await get_agents_dict(root_agent), ) else: raise HTTPException( @@ -1404,7 +1426,7 @@ def _set_telemetry_context_if_needed(runner: Runner): _is_visual_builder.set(False) @app.post("/run", response_model_exclude_none=True) - async def run_agent(req: RunAgentRequest) -> list[Event]: + async def run_agent(req: RunAgentRequest, request: Request) -> list[Event]: app_name = req.app_name or self.default_app_name if not app_name: raise HTTPException( @@ -1415,22 +1437,59 @@ async def run_agent(req: RunAgentRequest) -> list[Event]: self.current_app_name_ref.value = req.app_name runner = await self.get_runner_async(req.app_name) _set_telemetry_context_if_needed(runner) + run_config = ( + RunConfig(custom_metadata=req.custom_metadata) + if req.custom_metadata + else None + ) + + async def worker(): + try: + async with Aclosing( + runner.run_async( + user_id=req.user_id, + session_id=req.session_id, + new_message=req.new_message, + state_delta=req.state_delta, + invocation_id=req.invocation_id, + run_config=run_config, + ) + ) as agen: + return [event async for event in agen] + except SessionNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + + worker_task = asyncio.create_task(worker()) + + async def monitor(): + try: + while True: + message = await request.receive() + if message.get("type") == "http.disconnect": + logger.warning( + "Client disconnected. Aborting agent run for session %s.", + req.session_id, + ) + worker_task.cancel() + break + except asyncio.CancelledError: + pass + except Exception as e: # pylint: disable=broad-exception-caught + logger.error("Exception in disconnect monitor: %s", e, exc_info=True) + + monitor_task = asyncio.create_task(monitor()) + try: - async with Aclosing( - runner.run_async( - user_id=req.user_id, - session_id=req.session_id, - new_message=req.new_message, - state_delta=req.state_delta, - invocation_id=req.invocation_id, - ) - ) as agen: - events = [event async for event in agen] - except SessionNotFoundError as e: - raise HTTPException(status_code=404, detail=str(e)) from e - logger.info("Generated %s events in agent run", len(events)) - logger.debug("Events generated: %s", events) - return events + events = await worker_task + logger.info("Generated %s events in agent run", len(events)) + logger.debug("Events generated: %s", events) + return events + except asyncio.CancelledError: + if await request.is_disconnected(): + return Response(status_code=499) + raise + finally: + monitor_task.cancel() @app.post("/run_sse") async def run_agent_sse(req: RunAgentRequest) -> StreamingResponse: @@ -1472,7 +1531,10 @@ async def event_generator(): session_id=req.session_id, new_message=req.new_message, state_delta=req.state_delta, - run_config=RunConfig(streaming_mode=stream_mode), + run_config=RunConfig( + streaming_mode=stream_mode, + custom_metadata=req.custom_metadata, + ), invocation_id=req.invocation_id, ) ) as agen: diff --git a/src/google/adk/cli/cli_deploy.py b/src/google/adk/cli/cli_deploy.py index 5f3ccb8148..4927be385d 100644 --- a/src/google/adk/cli/cli_deploy.py +++ b/src/google/adk/cli/cli_deploy.py @@ -29,6 +29,7 @@ import click from packaging.version import parse +from ..version import __version__ from .utils import _onboarding _IS_WINDOWS = os.name == 'nt' @@ -62,7 +63,8 @@ def _ensure_agent_engine_dependency(requirements_txt_path: str) -> None: with open(requirements_txt_path, 'a', encoding='utf-8') as f: if requirements and not requirements.endswith('\n'): f.write('\n') - f.write(_AGENT_ENGINE_REQUIREMENT + '\n') + f.write('google-cloud-aiplatform[agent_engines]\n') + f.write(f'google-adk=={__version__}\n') _DOCKERFILE_TEMPLATE: Final[str] = """ @@ -706,7 +708,7 @@ def to_cloud_run( gcp_region=region, app_name=app_name, port=port, - command='web' if with_ui else 'api_server', + command='api_server --with_ui' if with_ui else 'api_server', install_agent_deps=install_agent_deps, service_option=_get_service_option_by_adk_version( adk_version, @@ -999,6 +1001,7 @@ def to_agent_engine( ) agent_config['description'] = description + requirements_txt_path = os.path.join(agent_src_path, 'requirements.txt') if requirements_file: warnings.warn( 'WARNING: `--requirements_file` is deprecated and will be removed in' @@ -1013,6 +1016,14 @@ def to_agent_engine( DeprecationWarning, stacklevel=2, ) + if not os.path.exists(requirements_txt_path): + click.echo(f'Creating {requirements_txt_path}...') + with open(requirements_txt_path, 'w', encoding='utf-8') as f: + f.write('google-cloud-aiplatform[agent_engines]\n') + f.write(f'google-adk=={__version__}\n') + click.echo(f'Using google-adk=={__version__} in requirements') + click.echo(f'Created {requirements_txt_path}') + _ensure_agent_engine_dependency(requirements_txt_path) env_vars = {} if not env_file: # Attempt to read the env variables from .env in the dir (if any). @@ -1300,7 +1311,7 @@ def to_gke( gcp_region=region, app_name=app_name, port=port, - command='web' if with_ui else 'api_server', + command='api_server --with_ui' if with_ui else 'api_server', install_agent_deps=install_agent_deps, service_option=_get_service_option_by_adk_version( adk_version, diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 4307866279..95b445b851 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -23,6 +23,7 @@ import logging import os from pathlib import Path +import sys import tempfile import textwrap @@ -31,17 +32,13 @@ from fastapi import FastAPI import uvicorn -from . import cli_create -from . import cli_deploy from .. import version from ..agents.run_config import StreamingMode from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE from ..features import FeatureName from ..features import override_feature_enabled from .cli import run_cli -from .fast_api import get_fast_api_app from .utils import envs -from .utils import evals from .utils import logs LOG_LEVELS = click.Choice( @@ -488,6 +485,8 @@ def cli_create_cmd( adk create path/to/my_app """ + from . import cli_create + cli_create.run_cmd( app_name, model=model, @@ -649,50 +648,207 @@ def wrapper(*args, **kwargs): ), callback=validate_exclusive, ) +@click.option( + "--state", + type=str, + help="Optional. Initial state for the run as a JSON string.", +) +@click.option( + "--timeout", + type=str, + help="Optional. Timeout for a single turn or query (e.g., 30s, 5m).", +) +@click.option( + "--in_memory", + is_flag=True, + help="Optional. Do not persist session data (use in-memory storage).", +) +@click.option( + "--jsonl", + is_flag=True, + help="Optional. Output structured JSONL instead of human-readable text.", +) +@click.option( + "--default_llm_model", + type=str, + help=( + "Optional. Sets the default LLM model used when the agent does not set" + " a model explicitly." + ), + default=None, +) @click.argument( "agent", type=click.Path( exists=True, dir_okay=True, file_okay=False, resolve_path=True ), ) +@click.argument("query", type=str, required=False) def cli_run( agent: str, + query: Optional[str], save_session: bool, - session_id: str | None, - replay: str | None, - resume: str | None, - session_service_uri: str | None = None, - artifact_service_uri: str | None = None, - memory_service_uri: str | None = None, + session_id: Optional[str], + replay: Optional[str], + resume: Optional[str], + state: Optional[str] = None, + timeout: Optional[str] = None, + in_memory: bool = False, + jsonl: bool = False, + session_service_uri: Optional[str] = None, + artifact_service_uri: Optional[str] = None, + memory_service_uri: Optional[str] = None, use_local_storage: bool = True, + default_llm_model: Optional[str] = None, ): - """Runs an interactive CLI for a certain agent. + """Runs an agent. If no query is provided, enters interactive mode. AGENT: The path to the agent source code folder. + QUERY: Optional. The user message to send to the agent for a single-step run. Example: adk run path/to/my_agent + adk run path/to/my_agent "hello" """ logs.log_to_tmp_folder() agent_parent_folder = os.path.dirname(agent) agent_folder_name = os.path.basename(agent) - asyncio.run( - run_cli( - agent_parent_dir=agent_parent_folder, - agent_folder_name=agent_folder_name, - input_file=replay, - saved_session_file=resume, - save_session=save_session, - session_id=session_id, - session_service_uri=session_service_uri, - artifact_service_uri=artifact_service_uri, - memory_service_uri=memory_service_uri, - use_local_storage=use_local_storage, + # If query is provided, we run in single-step mode (JSONL output) + if query is not None: + from .cli import run_once_cli + + exit_code = asyncio.run( + run_once_cli( + agent_parent_dir=agent_parent_folder, + agent_folder_name=agent_folder_name, + query=query, + state_str=state, + session_id=session_id, + replay=replay, + timeout=timeout, + in_memory=in_memory, + jsonl=jsonl, + session_service_uri=session_service_uri, + artifact_service_uri=artifact_service_uri, + memory_service_uri=memory_service_uri, + use_local_storage=use_local_storage, + default_llm_model=default_llm_model, + ) + ) + sys.exit(exit_code) + else: + # Legacy interactive mode + asyncio.run( + run_cli( + agent_parent_dir=agent_parent_folder, + agent_folder_name=agent_folder_name, + input_file=replay, + saved_session_file=resume, + save_session=save_session, + session_id=session_id, + state_str=state, + timeout=timeout, + in_memory=in_memory, + jsonl=jsonl, + session_service_uri=session_service_uri, + artifact_service_uri=artifact_service_uri, + memory_service_uri=memory_service_uri, + use_local_storage=use_local_storage, + default_llm_model=default_llm_model, + ) + ) + + +@main.command( + "test", + cls=HelpfulCommand, + context_settings={ + "allow_extra_args": True, + "allow_interspersed_args": True, + "ignore_unknown_options": True, + }, +) +@click.argument( + "folder", + type=click.Path( + exists=True, dir_okay=True, file_okay=False, resolve_path=True + ), + default=".", +) +@click.option( + "--rebuild", + is_flag=True, + help="Rebuild test files by running the real agent with user messages.", +) +@click.pass_context +def cli_test(ctx, folder: str, rebuild: bool): + """Runs pytest on agent test JSON files under the specified folder. + + FOLDER: The path to the folder containing agents and tests. + Defaults to the current directory if not specified. + + Example: + adk test path/to/agents + """ + import sys + + if rebuild: + from .agent_test_runner import rebuild_tests + + click.echo(f"Rebuilding tests in {folder}...") + rebuild_tests(folder) + sys.exit(0) + + # Parse arguments to separate pytest args (after --) from regular args + pytest_args = [] + if "--" in ctx.args: + separator_index = ctx.args.index("--") + pytest_args = ctx.args[separator_index + 1 :] + regular_args = ctx.args[:separator_index] + + if regular_args: + click.secho( + "Error: Unexpected arguments after folder and before '--':" + f" {' '.join(regular_args)}. \nOnly arguments after '--' are passed" + " to pytest.", + fg="red", + err=True, ) - ) + ctx.exit(2) + else: + # If no '--', all remaining arguments are passed to pytest + pytest_args = ctx.args + + import subprocess + + os.environ["ADK_TEST_FOLDER"] = folder + + current_dir = Path(__file__).parent + test_runner_path = current_dir / "agent_test_runner.py" + + if not test_runner_path.exists(): + click.secho( + f"Error: Test runner not found at {test_runner_path}", + fg="red", + err=True, + ) + sys.exit(1) + + click.echo(f"Running tests in {folder} using runner {test_runner_path}...") + + result = subprocess.run([ + sys.executable, + "-m", + "pytest", + str(test_runner_path), + "-v", + "-s", + *pytest_args, + ]) + sys.exit(result.returncode) def eval_options(): @@ -843,6 +999,8 @@ def cli_eval( eval_set_results_manager = None if eval_storage_uri: + from .utils import evals + gcs_eval_managers = evals.create_gcs_eval_managers_from_uri( eval_storage_uri ) @@ -1571,11 +1729,34 @@ def wrapper(ctx, *args, **kwargs): return decorator +def _check_windows_reload(reload: bool) -> bool: + """Checks if reload is enabled on Windows and forces it to False if so.""" + if sys.platform == "win32" and reload: + click.secho( + "WARNING: The --reload flag is not supported on Windows because it" + " forces Uvicorn to use SelectorEventLoop, which does not support" + " subprocesses (needed for executing tools). Forcing --no-reload.", + fg="yellow", + err=True, + ) + return False + return reload + + @main.command("web") @feature_options() @fast_api_common_options() @web_options() @adk_services_options(default_use_local_storage=True) +@click.option( + "--default_llm_model", + type=str, + help=( + "Optional. Sets the default LLM model used when the agent does not set" + " a model explicitly." + ), + default=None, +) @click.argument( "agents_dir", type=click.Path( @@ -1585,7 +1766,8 @@ def wrapper(ctx, *args, **kwargs): ) def cli_web( agents_dir: str, - eval_storage_uri: str | None = None, + default_llm_model: Optional[str] = None, + eval_storage_uri: Optional[str] = None, log_level: str = "INFO", allow_origins: list[str] | None = None, host: str = "127.0.0.1", @@ -1607,13 +1789,15 @@ def cli_web( ): """Starts a FastAPI server with Web UI for agents. - AGENTS_DIR: The directory of agents, where each subdirectory is a single - agent, containing at least `__init__.py` and `agent.py` files. + AGENTS_DIR: The directory of agents (where each subdirectory is a single + agent containing `agent.py` or `root_agent.yaml` files) or a path pointing + directly to a single agent folder. Example: adk web --session_service_uri=[uri] --port=[port] path/to/agents_dir """ + reload = _check_windows_reload(reload) logs.setup_adk_logger(getattr(logging, log_level.upper())) @asynccontextmanager @@ -1638,6 +1822,8 @@ async def _lifespan(app: FastAPI): fg="green", ) + from .fast_api import get_fast_api_app + app = get_fast_api_app( agents_dir=agents_dir, session_service_uri=session_service_uri, @@ -1659,6 +1845,7 @@ async def _lifespan(app: FastAPI): logo_text=logo_text, logo_image_url=logo_image_url, trigger_sources=trigger_sources, + default_llm_model=default_llm_model, ) config = uvicorn.Config( app, @@ -1692,6 +1879,12 @@ async def _lifespan(app: FastAPI): "Automatically create a session if it doesn't exist when calling /run." ), ) +@click.option( + "--with_ui", + is_flag=True, + default=False, + help="Serve ADK Web UI if set.", +) @click.option( "--gemini_enterprise_app_name", type=str, @@ -1731,18 +1924,21 @@ def cli_api_server( extra_plugins: list[str] | None = None, auto_create_session: bool = False, trigger_sources: list[str] | None = None, + with_ui: bool = False, gemini_enterprise_app_name: str | None = None, express_mode: bool = False, ): """Starts a FastAPI server for agents. - AGENTS_DIR: The directory of agents, where each subdirectory is a single - agent, containing at least `__init__.py` and `agent.py` files. + AGENTS_DIR: The directory of agents (where each subdirectory is a single + agent containing `agent.py` or `root_agent.yaml` files) or a path pointing + directly to a single agent folder. Example: adk api_server --session_service_uri=[uri] --port=[port] path/to/agents_dir """ + reload = _check_windows_reload(reload) if express_mode and not gemini_enterprise_app_name: raise click.UsageError( "--express_mode is only supported when --gemini_enterprise_app_name is" @@ -1751,6 +1947,8 @@ def cli_api_server( logs.setup_adk_logger(getattr(logging, log_level.upper())) + from .fast_api import get_fast_api_app + config = uvicorn.Config( get_fast_api_app( agents_dir=agents_dir, @@ -1760,7 +1958,7 @@ def cli_api_server( use_local_storage=use_local_storage, eval_storage_uri=eval_storage_uri, allow_origins=allow_origins, - web=False, + web=with_ui, trace_to_cloud=trace_to_cloud, otel_to_cloud=otel_to_cloud, a2a=a2a, @@ -1993,6 +2191,8 @@ def cli_deploy_cloud_run( ctx.exit(2) try: + from . import cli_deploy + cli_deploy.to_cloud_run( agent_folder=agent, project=project, @@ -2293,6 +2493,8 @@ def cli_deploy_agent_engine( "Do not pass both --validate-agent-import and" " --skip-agent-import-validation." ) + from . import cli_deploy + cli_deploy.to_agent_engine( agent_folder=agent, project=project, @@ -2482,6 +2684,8 @@ def cli_deploy_gke( """ try: _warn_if_with_ui(with_ui) + from . import cli_deploy + cli_deploy.to_gke( agent_folder=agent, project=project, diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 37aafbee1e..5186a8dd73 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -20,16 +20,18 @@ import logging import os from pathlib import Path -import shutil +import sys from typing import Any from typing import AsyncIterator from typing import Awaitable from typing import Callable from typing import Literal from typing import Mapping +from typing import Optional import click from fastapi import FastAPI +from fastapi import File from fastapi import HTTPException from fastapi import Request from fastapi import UploadFile @@ -48,18 +50,16 @@ from watchdog.observers import Observer from ..auth.credential_service.in_memory_credential_service import InMemoryCredentialService -from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager -from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager from ..runners import Runner from ..telemetry._agent_engine import get_propagated_context from ..telemetry._agent_engine import TopSpanProcessor -from .adk_web_server import AdkWebServer +from .api_server import ApiServer +from .dev_server import DevServer from .cli_deploy import _AGENT_ENGINE_CLASS_METHODS from .service_registry import load_services_module from .utils import envs -from .utils import evals from .utils.agent_change_handler import AgentChangeEventHandler -from .utils.agent_loader import AgentLoader +from .utils.agent_loader import is_single_agent_directory from .utils.base_agent_loader import BaseAgentLoader from .utils.service_factory import _create_task_store_from_options from .utils.service_factory import create_artifact_service_from_options @@ -96,6 +96,309 @@ def __getattr__(name: str): return attr +def _register_builder_endpoints(app: FastAPI, web: bool, agents_dir: str): + """Registers builder endpoints if web is enabled and multipart is installed.""" + if not web: + return + try: + import multipart + except ImportError: + logger.warning( + "python-multipart not installed. Builder UI endpoints will not be" + " available." + ) + return + + import shutil + + import yaml + + agents_base_path = (Path.cwd() / agents_dir).resolve() + + def _get_app_root(app_name: str) -> Path: + if app_name in ("", ".", ".."): + raise ValueError(f"Invalid app name: {app_name!r}") + if Path(app_name).name != app_name or "\\" in app_name: + raise ValueError(f"Invalid app name: {app_name!r}") + app_root = (agents_base_path / app_name).resolve() + if not app_root.is_relative_to(agents_base_path): + raise ValueError(f"Invalid app name: {app_name!r}") + return app_root + + def _normalize_relative_path(path: str) -> str: + return path.replace("\\", "/").lstrip("/") + + def _has_parent_reference(path: str) -> bool: + return any(part == ".." for part in path.split("/")) + + _ALLOWED_EXTENSIONS = frozenset({".yaml", ".yml"}) + + _BLOCKED_YAML_KEYS = frozenset({"args"}) + + def _check_yaml_for_blocked_keys(content: bytes, filename: str) -> None: + try: + docs = list(yaml.safe_load_all(content)) + except yaml.YAMLError as exc: + raise ValueError(f"Invalid YAML in {filename!r}: {exc}") from exc + + def _walk(node: Any) -> None: + if isinstance(node, dict): + for key, value in node.items(): + if key in _BLOCKED_YAML_KEYS: + raise ValueError( + f"Blocked key {key!r} found in {filename!r}. " + f"The '{key}' field is not allowed in builder uploads " + "because it can execute arbitrary code." + ) + _walk(value) + elif isinstance(node, list): + for item in node: + _walk(item) + + for doc in docs: + _walk(doc) + + def _parse_upload_filename(filename: Optional[str]) -> tuple[str, str]: + if not filename: + raise ValueError("Upload filename is missing.") + filename = _normalize_relative_path(filename) + if "/" not in filename: + raise ValueError(f"Invalid upload filename: {filename!r}") + app_name, rel_path = filename.split("/", 1) + if not app_name or not rel_path: + raise ValueError(f"Invalid upload filename: {filename!r}") + if rel_path.startswith("/"): + raise ValueError(f"Absolute upload path rejected: {filename!r}") + if _has_parent_reference(rel_path): + raise ValueError(f"Path traversal rejected: {filename!r}") + ext = os.path.splitext(rel_path)[1].lower() + if ext not in _ALLOWED_EXTENSIONS: + raise ValueError( + f"File type not allowed: {rel_path!r}" + f" (allowed: {', '.join(sorted(_ALLOWED_EXTENSIONS))})" + ) + return app_name, rel_path + + def _parse_file_path(file_path: str) -> str: + file_path = _normalize_relative_path(file_path) + if not file_path: + raise ValueError("file_path is missing.") + if file_path.startswith("/"): + raise ValueError(f"Absolute file_path rejected: {file_path!r}") + if _has_parent_reference(file_path): + raise ValueError(f"Path traversal rejected: {file_path!r}") + ext = os.path.splitext(file_path)[1].lower() + if ext not in _ALLOWED_EXTENSIONS: + raise ValueError( + f"File type not allowed: {file_path!r}" + f" (allowed: {', '.join(sorted(_ALLOWED_EXTENSIONS))})" + ) + return file_path + + def _resolve_under_dir(root_dir: Path, rel_path: str) -> Path: + file_path = root_dir / rel_path + resolved_root_dir = root_dir.resolve() + resolved_file_path = file_path.resolve() + if not resolved_file_path.is_relative_to(resolved_root_dir): + raise ValueError(f"Path escapes root_dir: {rel_path!r}") + return file_path + + def _get_tmp_agent_root(app_root: Path, app_name: str) -> Path: + tmp_agent_root = app_root / "tmp" / app_name + resolved_tmp_agent_root = tmp_agent_root.resolve() + if not resolved_tmp_agent_root.is_relative_to(app_root): + raise ValueError(f"Invalid tmp path for app: {app_name!r}") + return tmp_agent_root + + def copy_dir_contents(source_dir: Path, dest_dir: Path) -> None: + dest_dir.mkdir(parents=True, exist_ok=True) + for source_path in source_dir.iterdir(): + if source_path.name == "tmp": + continue + + dest_path = dest_dir / source_path.name + if source_path.is_dir(): + if dest_path.exists() and dest_path.is_file(): + dest_path.unlink() + shutil.copytree(source_path, dest_path, dirs_exist_ok=True) + elif source_path.is_file(): + if dest_path.exists() and dest_path.is_dir(): + shutil.rmtree(dest_path) + shutil.copy2(source_path, dest_path) + + def cleanup_tmp(app_name: str) -> bool: + try: + app_root = _get_app_root(app_name) + except ValueError as exc: + logger.exception("Error in cleanup_tmp: %s", exc) + return False + + try: + tmp_agent_root = _get_tmp_agent_root(app_root, app_name) + except ValueError as exc: + logger.exception("Error in cleanup_tmp: %s", exc) + return False + + try: + shutil.rmtree(tmp_agent_root) + except FileNotFoundError: + pass + except OSError as exc: + logger.exception("Error deleting tmp agent root: %s", exc) + return False + + tmp_dir = app_root / "tmp" + resolved_tmp_dir = tmp_dir.resolve() + if not resolved_tmp_dir.is_relative_to(app_root): + logger.error( + "Refusing to delete tmp outside app_root: %s", resolved_tmp_dir + ) + return False + + try: + tmp_dir.rmdir() + except OSError: + pass + + return True + + def ensure_tmp_exists(app_name: str) -> bool: + try: + app_root = _get_app_root(app_name) + except ValueError as exc: + logger.exception("Error in ensure_tmp_exists: %s", exc) + return False + + if not app_root.is_dir(): + return False + + try: + tmp_agent_root = _get_tmp_agent_root(app_root, app_name) + except ValueError as exc: + logger.exception("Error in ensure_tmp_exists: %s", exc) + return False + + if tmp_agent_root.exists(): + return True + + try: + tmp_agent_root.mkdir(parents=True, exist_ok=True) + copy_dir_contents(app_root, tmp_agent_root) + except OSError as exc: + logger.exception("Error in ensure_tmp_exists: %s", exc) + return False + + return True + + @app.post("/builder/save", response_model_exclude_none=True) + async def builder_build( + files: list[UploadFile] = File(...), tmp: Optional[bool] = False + ) -> bool: + try: + app_names: set[str] = set() + uploads: list[tuple[str, bytes]] = [] + for file in files: + app_name, rel_path = _parse_upload_filename(file.filename) + app_names.add(app_name) + content = await file.read() + uploads.append((rel_path, content)) + + if len(app_names) != 1: + logger.error( + "Exactly one app name is required, found: %s", + sorted(app_names), + ) + return False + + app_name = next(iter(app_names)) + + for rel_path, content in uploads: + _check_yaml_for_blocked_keys(content, f"{app_name}/{rel_path}") + + if tmp: + app_root = _get_app_root(app_name) + tmp_agent_root = _get_tmp_agent_root(app_root, app_name) + tmp_agent_root.mkdir(parents=True, exist_ok=True) + + for rel_path, content in uploads: + destination_path = _resolve_under_dir(tmp_agent_root, rel_path) + destination_path.parent.mkdir(parents=True, exist_ok=True) + destination_path.write_bytes(content) + + return True + + app_root = _get_app_root(app_name) + app_root.mkdir(parents=True, exist_ok=True) + + tmp_agent_root = _get_tmp_agent_root(app_root, app_name) + if tmp_agent_root.is_dir(): + copy_dir_contents(tmp_agent_root, app_root) + + for rel_path, content in uploads: + destination_path = _resolve_under_dir(app_root, rel_path) + destination_path.parent.mkdir(parents=True, exist_ok=True) + destination_path.write_bytes(content) + + return cleanup_tmp(app_name) + except ValueError as exc: + logger.exception("Error in builder_build: %s", exc) + raise HTTPException(status_code=400, detail=str(exc)) + except OSError as exc: + logger.exception("Error in builder_build: %s", exc) + return False + + @app.post("/builder/app/{app_name}/cancel", response_model_exclude_none=True) + async def builder_cancel(app_name: str) -> bool: + return cleanup_tmp(app_name) + + @app.get( + "/builder/app/{app_name}", + response_model_exclude_none=True, + response_class=PlainTextResponse, + ) + async def get_agent_builder( + app_name: str, + file_path: Optional[str] = None, + tmp: Optional[bool] = False, + ): + try: + app_root = _get_app_root(app_name) + except ValueError as exc: + logger.exception("Error in get_agent_builder: %s", exc) + return "" + + agent_dir = app_root + if tmp: + if not ensure_tmp_exists(app_name): + return "" + agent_dir = app_root / "tmp" / app_name + + if not file_path: + rel_path = "root_agent.yaml" + else: + try: + rel_path = _parse_file_path(file_path) + except ValueError as exc: + logger.exception("Error in get_agent_builder: %s", exc) + return "" + + try: + agent_file_path = _resolve_under_dir(agent_dir, rel_path) + except ValueError as exc: + logger.exception("Error in get_agent_builder: %s", exc) + return "" + + if not agent_file_path.is_file(): + return "" + + return FileResponse( + path=agent_file_path, + media_type="application/x-yaml", + filename=file_path or f"{app_name}.yaml", + headers={"Cache-Control": "no-store"}, + ) + + def get_fast_api_app( *, agents_dir: str, @@ -122,6 +425,7 @@ def get_fast_api_app( logo_image_url: str | None = None, auto_create_session: bool = False, trigger_sources: list[Literal["pubsub", "eventarc"]] | None = None, + default_llm_model: str | None = None, gemini_enterprise_app_name: str | None = None, express_mode: bool = False, ) -> FastAPI: @@ -172,41 +476,47 @@ def get_fast_api_app( trigger_sources: List of trigger sources to enable (e.g. ["pubsub", "eventarc"]). When set, registers /trigger/* endpoints for batch and event-driven agent invocations. None disables all trigger endpoints. - gemini_enterprise_app_name: The app_name to register with Gemini Enterprise - via https://docs.cloud.google.com/gemini/enterprise/docs/register-and-manage-an-adk-agent - express_mode: Whether or not to initialize the server in express mode. - This is only supported when gemini_enterprise_app_name is set. Defaults to - False. + default_llm_model: Default LLM model to use for the agent. + gemini_enterprise_app_name: The Gemini Enterprise app name to use for the + agent. + express_mode: Whether to enable express mode. Returns: The configured FastAPI application instance. """ - if express_mode and not gemini_enterprise_app_name: - raise ValueError( - "express_mode is only supported when gemini_enterprise_app_name is set." - ) + # Detect single agent mode + agents_path = Path(agents_dir).resolve() + is_single_agent = is_single_agent_directory(agents_path) - # Enable denylist enforcement for config loads if web UI is enabled. - if web: - from ..agents import config_agent_utils - - config_agent_utils._set_enforce_denylist(True) + original_agents_dir = agents_dir + single_agent_name = None + if is_single_agent: + single_agent_name = agents_path.name + agents_dir = str(agents_path.parent) # Set up eval managers. if eval_storage_uri: + from .utils import evals + gcs_eval_managers = evals.create_gcs_eval_managers_from_uri( eval_storage_uri ) eval_sets_manager = gcs_eval_managers.eval_sets_manager eval_set_results_manager = gcs_eval_managers.eval_set_results_manager else: - eval_sets_manager = LocalEvalSetsManager(agents_dir=agents_dir) - eval_set_results_manager = LocalEvalSetResultsManager(agents_dir=agents_dir) + this_module = sys.modules[__name__] + eval_sets_manager = this_module.LocalEvalSetsManager(agents_dir=agents_dir) + eval_set_results_manager = this_module.LocalEvalSetResultsManager( + agents_dir=agents_dir + ) # initialize Agent Loader if not passed as argument + this_module = sys.modules[__name__] if agent_loader is None: - agent_loader = AgentLoader(agents_dir) + agent_loader = this_module.AgentLoader(original_agents_dir) + elif is_single_agent and isinstance(agent_loader, this_module.AgentLoader): + agent_loader._set_single_agent_mode(single_agent_name, agents_dir) # Load services.py from agents_dir for custom service registration. load_services_module(agents_dir) @@ -242,7 +552,12 @@ def get_fast_api_app( # Build the Credential service credential_service = InMemoryCredentialService() - adk_web_server = AdkWebServer( + # Instantiate the appropriate server class based on web option + # If web=True, use DevServer (includes all endpoints: production + dev) + # If web=False, use ApiServer (production-safe endpoints only) + ServerClass = DevServer if web else ApiServer + + adk_web_server = ServerClass( agent_loader=agent_loader, session_service=session_service, artifact_service=artifact_service, @@ -257,30 +572,16 @@ def get_fast_api_app( url_prefix=url_prefix, auto_create_session=auto_create_session, trigger_sources=trigger_sources, + default_llm_model=default_llm_model, ) + # In single agent mode, use that agent as the default app. + if is_single_agent: + adk_web_server.default_app_name = single_agent_name + # Callbacks & other optional args for when constructing the FastAPI instance extra_fast_api_args = {} - # Synchronize otel_to_cloud and GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY. - # This is to support toggling telemetry in the Agent Platform Console, which - # sets the environment variable GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY. - if otel_to_cloud: - os.environ["GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY"] = "true" - logger.info( - "Setting GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY to true based on" - " otel_to_cloud flag." - ) - elif gemini_enterprise_app_name and os.environ.get( - "GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY" - ): - logger.info( - "Setting otel_to_cloud to True for Gemini Enterprise app %s based on" - " GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY environment variable.", - gemini_enterprise_app_name, - ) - otel_to_cloud = True - # TODO - Remove separate trace_to_cloud logic once otel_to_cloud stops being # EXPERIMENTAL. if trace_to_cloud and not otel_to_cloud: @@ -305,7 +606,7 @@ def register_processors(provider: TracerProvider) -> None: if reload_agents: - def setup_observer(observer: Observer, adk_web_server: AdkWebServer): + def setup_observer(observer: Observer, adk_web_server: ApiServer): agent_change_handler = AgentChangeEventHandler( agent_loader=agent_loader, runners_to_clean=adk_web_server.runners_to_clean, @@ -314,7 +615,7 @@ def setup_observer(observer: Observer, adk_web_server: AdkWebServer): observer.schedule(agent_change_handler, agents_dir, recursive=True) observer.start() - def tear_down_observer(observer: Observer, _: AdkWebServer): + def tear_down_observer(observer: Observer, _: ApiServer): observer.stop() observer.join() @@ -365,309 +666,7 @@ async def _a2a_lifespan(app_instance: FastAPI): ) # --- Builder endpoints (agent editor UI) --- - # Only register when the web UI is enabled. In headless / production - # deployments (e.g. `adk deploy cloud_run`) these endpoints are unnecessary - # and expose an attack surface that allows arbitrary file writes under the - # agents directory. - # See https://github.com/google/adk-python/issues/4947 - if web: - agents_base_path = (Path.cwd() / agents_dir).resolve() - - def _get_app_root(app_name: str) -> Path: - if app_name in ("", ".", ".."): - raise ValueError(f"Invalid app name: {app_name!r}") - if Path(app_name).name != app_name or "\\" in app_name: - raise ValueError(f"Invalid app name: {app_name!r}") - app_root = (agents_base_path / app_name).resolve() - if not app_root.is_relative_to(agents_base_path): - raise ValueError(f"Invalid app name: {app_name!r}") - return app_root - - def _normalize_relative_path(path: str) -> str: - return path.replace("\\", "/").lstrip("/") - - def _has_parent_reference(path: str) -> bool: - return any(part == ".." for part in path.split("/")) - - _ALLOWED_EXTENSIONS = frozenset({".yaml", ".yml"}) - - # --- YAML content security --- - # The `args` key in agent YAML configs (CodeConfig.args, ToolConfig.args) - # allows callers to pass arbitrary arguments to Python constructors and - # functions, which is an RCE vector when exposed through the builder UI. - # Block any upload that contains an `args` key anywhere in the document. - _BLOCKED_YAML_KEYS = frozenset({"args"}) - - def _check_yaml_for_blocked_keys(content: bytes, filename: str) -> None: - """Raise if the YAML document contains any blocked keys.""" - import yaml - - try: - docs = list(yaml.safe_load_all(content)) - except yaml.YAMLError as exc: - raise ValueError(f"Invalid YAML in {filename!r}: {exc}") from exc - - def _walk(node: Any) -> None: - if isinstance(node, dict): - for key, value in node.items(): - if key in _BLOCKED_YAML_KEYS: - raise ValueError( - f"Blocked key {key!r} found in {filename!r}. " - f"The '{key}' field is not allowed in builder uploads " - "because it can execute arbitrary code." - ) - _walk(value) - elif isinstance(node, list): - for item in node: - _walk(item) - - for doc in docs: - _walk(doc) - - def _parse_upload_filename(filename: str | None) -> tuple[str, str]: - if not filename: - raise ValueError("Upload filename is missing.") - filename = _normalize_relative_path(filename) - if "/" not in filename: - raise ValueError(f"Invalid upload filename: {filename!r}") - app_name, rel_path = filename.split("/", 1) - if not app_name or not rel_path: - raise ValueError(f"Invalid upload filename: {filename!r}") - if rel_path.startswith("/"): - raise ValueError(f"Absolute upload path rejected: {filename!r}") - if _has_parent_reference(rel_path): - raise ValueError(f"Path traversal rejected: {filename!r}") - ext = os.path.splitext(rel_path)[1].lower() - if ext not in _ALLOWED_EXTENSIONS: - raise ValueError( - f"File type not allowed: {rel_path!r}" - f" (allowed: {', '.join(sorted(_ALLOWED_EXTENSIONS))})" - ) - return app_name, rel_path - - def _parse_file_path(file_path: str) -> str: - file_path = _normalize_relative_path(file_path) - if not file_path: - raise ValueError("file_path is missing.") - if file_path.startswith("/"): - raise ValueError(f"Absolute file_path rejected: {file_path!r}") - if _has_parent_reference(file_path): - raise ValueError(f"Path traversal rejected: {file_path!r}") - ext = os.path.splitext(file_path)[1].lower() - if ext not in _ALLOWED_EXTENSIONS: - raise ValueError( - f"File type not allowed: {file_path!r}" - f" (allowed: {', '.join(sorted(_ALLOWED_EXTENSIONS))})" - ) - return file_path - - def _resolve_under_dir(root_dir: Path, rel_path: str) -> Path: - file_path = root_dir / rel_path - resolved_root_dir = root_dir.resolve() - resolved_file_path = file_path.resolve() - if not resolved_file_path.is_relative_to(resolved_root_dir): - raise ValueError(f"Path escapes root_dir: {rel_path!r}") - return file_path - - def _get_tmp_agent_root(app_root: Path, app_name: str) -> Path: - tmp_agent_root = app_root / "tmp" / app_name - resolved_tmp_agent_root = tmp_agent_root.resolve() - if not resolved_tmp_agent_root.is_relative_to(app_root): - raise ValueError(f"Invalid tmp path for app: {app_name!r}") - return tmp_agent_root - - def copy_dir_contents(source_dir: Path, dest_dir: Path) -> None: - dest_dir.mkdir(parents=True, exist_ok=True) - for source_path in source_dir.iterdir(): - if source_path.name == "tmp": - continue - - dest_path = dest_dir / source_path.name - if source_path.is_dir(): - if dest_path.exists() and dest_path.is_file(): - dest_path.unlink() - shutil.copytree(source_path, dest_path, dirs_exist_ok=True) - elif source_path.is_file(): - if dest_path.exists() and dest_path.is_dir(): - shutil.rmtree(dest_path) - shutil.copy2(source_path, dest_path) - - def cleanup_tmp(app_name: str) -> bool: - try: - app_root = _get_app_root(app_name) - except ValueError as exc: - logger.exception("Error in cleanup_tmp: %s", exc) - return False - - try: - tmp_agent_root = _get_tmp_agent_root(app_root, app_name) - except ValueError as exc: - logger.exception("Error in cleanup_tmp: %s", exc) - return False - - try: - shutil.rmtree(tmp_agent_root) - except FileNotFoundError: - pass - except OSError as exc: - logger.exception("Error deleting tmp agent root: %s", exc) - return False - - tmp_dir = app_root / "tmp" - resolved_tmp_dir = tmp_dir.resolve() - if not resolved_tmp_dir.is_relative_to(app_root): - logger.error( - "Refusing to delete tmp outside app_root: %s", resolved_tmp_dir - ) - return False - - try: - tmp_dir.rmdir() - except OSError: - pass - - return True - - def ensure_tmp_exists(app_name: str) -> bool: - try: - app_root = _get_app_root(app_name) - except ValueError as exc: - logger.exception("Error in ensure_tmp_exists: %s", exc) - return False - - if not app_root.is_dir(): - return False - - try: - tmp_agent_root = _get_tmp_agent_root(app_root, app_name) - except ValueError as exc: - logger.exception("Error in ensure_tmp_exists: %s", exc) - return False - - if tmp_agent_root.exists(): - return True - - try: - tmp_agent_root.mkdir(parents=True, exist_ok=True) - copy_dir_contents(app_root, tmp_agent_root) - except OSError as exc: - logger.exception("Error in ensure_tmp_exists: %s", exc) - return False - - return True - - @app.post("/builder/save", response_model_exclude_none=True) - async def builder_build( - files: list[UploadFile], tmp: bool | None = False - ) -> bool: - try: - # Phase 1: parse filenames and read content into memory. - app_names: set[str] = set() - uploads: list[tuple[str, bytes]] = [] - for file in files: - app_name, rel_path = _parse_upload_filename(file.filename) - app_names.add(app_name) - content = await file.read() - uploads.append((rel_path, content)) - - if len(app_names) != 1: - logger.error( - "Exactly one app name is required, found: %s", - sorted(app_names), - ) - return False - - app_name = next(iter(app_names)) - - # Phase 2: validate every file *before* writing anything to disk. - for rel_path, content in uploads: - _check_yaml_for_blocked_keys(content, f"{app_name}/{rel_path}") - - # Phase 3: write validated files to disk. - if tmp: - app_root = _get_app_root(app_name) - tmp_agent_root = _get_tmp_agent_root(app_root, app_name) - tmp_agent_root.mkdir(parents=True, exist_ok=True) - - for rel_path, content in uploads: - destination_path = _resolve_under_dir(tmp_agent_root, rel_path) - destination_path.parent.mkdir(parents=True, exist_ok=True) - destination_path.write_bytes(content) - - return True - - app_root = _get_app_root(app_name) - app_root.mkdir(parents=True, exist_ok=True) - - tmp_agent_root = _get_tmp_agent_root(app_root, app_name) - if tmp_agent_root.is_dir(): - copy_dir_contents(tmp_agent_root, app_root) - - for rel_path, content in uploads: - destination_path = _resolve_under_dir(app_root, rel_path) - destination_path.parent.mkdir(parents=True, exist_ok=True) - destination_path.write_bytes(content) - - return cleanup_tmp(app_name) - except ValueError as exc: - logger.exception("Error in builder_build: %s", exc) - raise HTTPException(status_code=400, detail=str(exc)) - except OSError as exc: - logger.exception("Error in builder_build: %s", exc) - return False - - @app.post( - "/builder/app/{app_name}/cancel", response_model_exclude_none=True - ) - async def builder_cancel(app_name: str) -> bool: - return cleanup_tmp(app_name) - - @app.get( - "/builder/app/{app_name}", - response_model_exclude_none=True, - response_class=PlainTextResponse, - ) - async def get_agent_builder( - app_name: str, - file_path: str | None = None, - tmp: bool | None = False, - ): - try: - app_root = _get_app_root(app_name) - except ValueError as exc: - logger.exception("Error in get_agent_builder: %s", exc) - return "" - - agent_dir = app_root - if tmp: - if not ensure_tmp_exists(app_name): - return "" - agent_dir = app_root / "tmp" / app_name - - if not file_path: - rel_path = "root_agent.yaml" - else: - try: - rel_path = _parse_file_path(file_path) - except ValueError as exc: - logger.exception("Error in get_agent_builder: %s", exc) - return "" - - try: - agent_file_path = _resolve_under_dir(agent_dir, rel_path) - except ValueError as exc: - logger.exception("Error in get_agent_builder: %s", exc) - return "" - - if not agent_file_path.is_file(): - return "" - - return FileResponse( - path=agent_file_path, - media_type="application/x-yaml", - filename=file_path or f"{app_name}.yaml", - headers={"Cache-Control": "no-store"}, - ) + _register_builder_endpoints(app, web, agents_dir) if a2a and a2a_task_store is not None: from a2a.server.apps import A2AStarletteApplication diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index b7911dc1f1..3fd942cd16 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -113,7 +113,7 @@ def _event_state_delta(state_delta: dict[str, Any]): # Define mocked async generator functions for the Runner -async def dummy_run_live(self, session, live_request_queue): +async def dummy_run_live(self, session, live_request_queue, **kwargs): yield _event_1() await asyncio.sleep(0) @@ -1443,6 +1443,45 @@ async def run_async_capture( assert captured_invocation_id["invocation_id"] == payload["invocation_id"] +def test_agent_run_passes_custom_metadata( + test_app, create_test_session, monkeypatch +): + """Test /run forwards custom_metadata via the run config.""" + info = create_test_session + captured: dict[str, Optional[RunConfig]] = {"run_config": None} + + async def run_async_capture( + self, + *, + user_id: str, + session_id: str, + invocation_id: Optional[str] = None, + new_message: Optional[types.Content] = None, + state_delta: Optional[dict[str, Any]] = None, + run_config: Optional[RunConfig] = None, + ): + del self, user_id, session_id, invocation_id, new_message, state_delta + captured["run_config"] = run_config + yield _event_1() + + monkeypatch.setattr(Runner, "run_async", run_async_capture) + + payload = { + "app_name": info["app_name"], + "user_id": info["user_id"], + "session_id": info["session_id"], + "new_message": {"role": "user", "parts": [{"text": "Hello"}]}, + "streaming": False, + "custom_metadata": {"tenant": "acme", "trace": "abc123"}, + } + + response = test_app.post("/run", json=payload) + + assert response.status_code == 200 + assert captured["run_config"] is not None + assert captured["run_config"].custom_metadata == payload["custom_metadata"] + + def test_agent_run_sse_splits_artifact_delta( test_app, create_test_session, monkeypatch ): @@ -1745,10 +1784,10 @@ def test_get_eval_set_result_not_found(test_app): assert response.status_code == 404 -def test_list_metrics_info(test_app): +def test_list_metrics_info(builder_test_client): """Test listing metrics info.""" - url = "/apps/test_app/metrics-info" - response = test_app.get(url) + url = "/dev/apps/test_app/metrics-info" + response = builder_test_client.get(url) # Verify the response assert response.status_code == 200 @@ -1768,7 +1807,7 @@ def test_debug_trace(test_app): """Test the debug trace endpoint.""" # This test will likely return 404 since we haven't set up trace data, # but it tests that the endpoint exists and handles missing traces correctly. - url = "/debug/trace/nonexistent-event" + url = "/dev/apps/test_app/debug/trace/nonexistent-event" response = test_app.get(url) # Verify we get a 404 for a nonexistent trace @@ -1783,56 +1822,6 @@ def test_openapi_json_schema_accessible(test_app): logger.info("OpenAPI /openapi.json endpoint is accessible") -def test_get_event_graph_returns_dot_src_for_app_agent(): - """Ensure graph endpoint unwraps App instances before building the graph.""" - from google.adk.cli.adk_web_server import AdkWebServer - - root_agent = DummyAgent(name="dummy_agent") - app_agent = App(name="test_app", root_agent=root_agent) - - class Loader: - - def load_agent(self, app_name): - return app_agent - - def list_agents(self): - return [app_agent.name] - - session_service = AsyncMock() - session = Session( - id="session_id", - app_name="test_app", - user_id="user", - state={}, - events=[Event(author="dummy_agent")], - ) - event_id = session.events[0].id - session_service.get_session.return_value = session - - adk_web_server = AdkWebServer( - agent_loader=Loader(), - session_service=session_service, - memory_service=MagicMock(), - artifact_service=MagicMock(), - credential_service=MagicMock(), - eval_sets_manager=MagicMock(), - eval_set_results_manager=MagicMock(), - agents_dir=".", - ) - - fast_api_app = adk_web_server.get_fast_api_app( - setup_observer=lambda _observer, _server: None, - tear_down_observer=lambda _observer, _server: None, - ) - - client = TestClient(fast_api_app) - response = client.get( - f"/apps/test_app/users/user/sessions/session_id/events/{event_id}/graph" - ) - assert response.status_code == 200 - assert "dotSrc" in response.json() - - def test_a2a_agent_discovery(test_app_with_a2a): """Test that A2A agents are properly discovered and configured.""" # This test mainly verifies that the A2A setup doesn't break the app @@ -2194,12 +2183,14 @@ def test_builder_final_save_preserves_files_and_cleans_tmp( ("app/sub_agent.yaml", b"name: sub\n", "application/x-yaml"), ), ] - response = builder_test_client.post("/builder/save?tmp=true", files=files) + response = builder_test_client.post( + "/dev/apps/app/builder/save?tmp=true", files=files + ) assert response.status_code == 200 assert response.json() is True response = builder_test_client.post( - "/builder/save", + "/dev/apps/app/builder/save", files=[( "files", ( @@ -2220,7 +2211,7 @@ def test_builder_final_save_preserves_files_and_cleans_tmp( def test_builder_save_rejects_cross_origin_post(builder_test_client, tmp_path): response = builder_test_client.post( - "/builder/save?tmp=true", + "/dev/apps/app/builder/save?tmp=true", headers={"origin": "https://evil.com"}, files=[( "files", @@ -2235,7 +2226,7 @@ def test_builder_save_rejects_cross_origin_post(builder_test_client, tmp_path): def test_builder_save_allows_same_origin_post(builder_test_client, tmp_path): response = builder_test_client.post( - "/builder/save?tmp=true", + "/dev/apps/app/builder/save?tmp=true", headers={"origin": "http://testserver"}, files=[( "files", @@ -2250,7 +2241,7 @@ def test_builder_save_allows_same_origin_post(builder_test_client, tmp_path): def test_builder_get_allows_cross_origin_get(builder_test_client): response = builder_test_client.get( - "/builder/app/missing?tmp=true", + "/dev/apps/missing/builder?tmp=true", headers={"origin": "https://evil.com"}, ) @@ -2263,12 +2254,12 @@ def test_builder_cancel_deletes_tmp_idempotent(builder_test_client, tmp_path): tmp_agent_root.mkdir(parents=True, exist_ok=True) (tmp_agent_root / "root_agent.yaml").write_text("name: app\n") - response = builder_test_client.post("/builder/app/app/cancel") + response = builder_test_client.post("/dev/apps/app/builder/cancel") assert response.status_code == 200 assert response.json() is True assert not (tmp_path / "app" / "tmp").exists() - response = builder_test_client.post("/builder/app/app/cancel") + response = builder_test_client.post("/dev/apps/app/builder/cancel") assert response.status_code == 200 assert response.json() is True assert not (tmp_path / "app" / "tmp").exists() @@ -2283,7 +2274,7 @@ def test_builder_get_tmp_true_recreates_tmp(builder_test_client, tmp_path): (nested_dir / "nested.yaml").write_text("nested: true\n") assert not (app_root / "tmp").exists() - response = builder_test_client.get("/builder/app/app?tmp=true") + response = builder_test_client.get("/dev/apps/app/builder?tmp=true") assert response.status_code == 200 assert response.text == "name: app\n" @@ -2292,7 +2283,7 @@ def test_builder_get_tmp_true_recreates_tmp(builder_test_client, tmp_path): assert (tmp_agent_root / "nested" / "nested.yaml").is_file() response = builder_test_client.get( - "/builder/app/app?tmp=true&file_path=nested/nested.yaml" + "/dev/apps/app/builder?tmp=true&file_path=nested/nested.yaml" ) assert response.status_code == 200 assert response.text == "nested: true\n" @@ -2301,7 +2292,7 @@ def test_builder_get_tmp_true_recreates_tmp(builder_test_client, tmp_path): def test_builder_get_tmp_true_missing_app_returns_empty( builder_test_client, tmp_path ): - response = builder_test_client.get("/builder/app/missing?tmp=true") + response = builder_test_client.get("/dev/apps/missing/builder?tmp=true") assert response.status_code == 200 assert response.text == "" assert not (tmp_path / "missing").exists() @@ -2309,7 +2300,7 @@ def test_builder_get_tmp_true_missing_app_returns_empty( def test_builder_save_rejects_traversal(builder_test_client, tmp_path): response = builder_test_client.post( - "/builder/save?tmp=true", + "/dev/apps/app/builder/save?tmp=true", files=[( "files", ("app/../escape.yaml", b"nope\n", "application/x-yaml"), @@ -2323,7 +2314,7 @@ def test_builder_save_rejects_traversal(builder_test_client, tmp_path): def test_builder_save_rejects_py_files(builder_test_client, tmp_path): """Uploading .py files via /builder/save is rejected.""" response = builder_test_client.post( - "/builder/save?tmp=true", + "/dev/apps/app/builder/save?tmp=true", files=[( "files", ("app/agent.py", b"import os\nos.system('id')\n", "text/plain"), @@ -2345,7 +2336,7 @@ def test_builder_save_rejects_non_yaml_extensions( (".pth", b"import os"), ]: response = builder_test_client.post( - "/builder/save?tmp=true", + "/dev/apps/app/builder/save?tmp=true", files=[( "files", (f"app/file{ext}", content, "application/octet-stream"), @@ -2357,7 +2348,7 @@ def test_builder_save_rejects_non_yaml_extensions( def test_builder_save_allows_yaml_files(builder_test_client, tmp_path): """Uploading .yaml and .yml files is allowed.""" response = builder_test_client.post( - "/builder/save?tmp=true", + "/dev/apps/app/builder/save?tmp=true", files=[( "files", ("app/root_agent.yaml", b"name: app\n", "application/x-yaml"), @@ -2367,7 +2358,7 @@ def test_builder_save_allows_yaml_files(builder_test_client, tmp_path): assert response.json() is True response = builder_test_client.post( - "/builder/save?tmp=true", + "/dev/apps/app/builder/save?tmp=true", files=[( "files", ("app/sub_agent.yml", b"name: sub\n", "application/x-yaml"), @@ -2385,7 +2376,7 @@ def test_builder_save_rejects_args_key(builder_test_client, tmp_path): key: value """ response = builder_test_client.post( - "/builder/save?tmp=true", + "/dev/apps/app/builder/save?tmp=true", files=[( "files", ("app/root_agent.yaml", yaml_with_args, "application/x-yaml"), @@ -2405,7 +2396,7 @@ def test_builder_save_rejects_nested_args_key(builder_test_client, tmp_path): param: value """ response = builder_test_client.post( - "/builder/save?tmp=true", + "/dev/apps/app/builder/save?tmp=true", files=[( "files", ("app/root_agent.yaml", yaml_with_nested_args, "application/x-yaml"), @@ -2416,7 +2407,7 @@ def test_builder_save_rejects_nested_args_key(builder_test_client, tmp_path): def test_builder_get_rejects_non_yaml_file_paths(builder_test_client, tmp_path): - """GET /builder/app/{app_name}?file_path=... rejects non-YAML extensions.""" + """GET /dev/apps/{app_name}/builder?file_path=... rejects non-YAML extensions.""" app_root = tmp_path / "app" app_root.mkdir(parents=True, exist_ok=True) (app_root / ".env").write_text("SECRET=supersecret\n") @@ -2425,26 +2416,26 @@ def test_builder_get_rejects_non_yaml_file_paths(builder_test_client, tmp_path): for file_path in [".env", "agent.py", "config.json"]: response = builder_test_client.get( - f"/builder/app/app?file_path={file_path}" + f"/dev/apps/app/builder?file_path={file_path}" ) assert response.status_code == 200, f"Expected 200 for {file_path}" assert response.text == "", f"Expected empty response for {file_path}" def test_builder_get_allows_yaml_file_paths(builder_test_client, tmp_path): - """GET /builder/app/{app_name}?file_path=... allows YAML extensions.""" + """GET /dev/apps/{app_name}/builder?file_path=... allows YAML extensions.""" app_root = tmp_path / "app" app_root.mkdir(parents=True, exist_ok=True) (app_root / "sub_agent.yaml").write_text("name: sub\n") (app_root / "tool.yml").write_text("name: tool\n") response = builder_test_client.get( - "/builder/app/app?file_path=sub_agent.yaml" + "/dev/apps/app/builder?file_path=sub_agent.yaml" ) assert response.status_code == 200 assert response.text == "name: sub\n" - response = builder_test_client.get("/builder/app/app?file_path=tool.yml") + response = builder_test_client.get("/dev/apps/app/builder?file_path=tool.yml") assert response.status_code == 200 assert response.text == "name: tool\n" @@ -2467,28 +2458,28 @@ def test_builder_endpoints_not_registered_without_web( mock_eval_set_results_manager, web=False, ) - # /builder/save should return 404/405, not 200 + # /dev/apps/app/builder/save should return 404/405, not 200 response = client.post( - "/builder/save", + "/dev/apps/app/builder/save", files=[ ("files", ("app/agent.yaml", b"name: test\n", "application/x-yaml")) ], ) assert response.status_code in (404, 405) - # /builder/app/{name}/cancel should also be absent - response = client.post("/builder/app/app/cancel") + # /dev/apps/{name}/builder/cancel should also be absent + response = client.post("/dev/apps/app/builder/cancel") assert response.status_code in (404, 405) - # /builder/app/{name} GET should also be absent - response = client.get("/builder/app/app") + # /dev/apps/{name}/builder GET should also be absent + response = client.get("/dev/apps/app/builder") assert response.status_code in (404, 405) def test_builder_endpoints_registered_with_web(builder_test_client): """Builder endpoints are available when web=True.""" response = builder_test_client.post( - "/builder/save?tmp=true", + "/dev/apps/app/builder/save?tmp=true", files=[ ("files", ("app/agent.yaml", b"name: test\n", "application/x-yaml")) ], @@ -2740,6 +2731,401 @@ async def run_async_capture( assert captured_visual_builder_values.get("yaml_app_after_sleep") == True +def test_default_app_name_middleware_and_resolution( + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_agent_loader, + mock_eval_sets_manager, + mock_eval_set_results_manager, + monkeypatch, +): + """Test that when ADK_DEFAULT_APP_NAME is set, path rewriting works for get_session and run.""" + # Set environment variable + monkeypatch.setenv("ADK_DEFAULT_APP_NAME", "test_app") + + test_app = _create_test_client( + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_agent_loader, + mock_eval_sets_manager, + mock_eval_set_results_manager, + ) + + # Create session for test_app + async def setup_session(): + await mock_session_service.create_session( + app_name="test_app", + user_id="test_user", + session_id="test_session", + state={}, + ) + + asyncio.run(setup_session()) + + # 1. Test path rewriting for GET /users/{user_id}/sessions/{session_id} + response = test_app.get("/users/test_user/sessions/test_session") + assert response.status_code == 200 + assert response.json()["id"] == "test_session" + + # 2. Test app_name omission in /run request payload + payload = { + "user_id": "test_user", + "session_id": "test_session", + "new_message": {"role": "user", "parts": [{"text": "Hello"}]}, + } + response = test_app.post("/run", json=payload) + assert response.status_code == 200 + assert isinstance(response.json(), list) + + +def test_default_app_name_not_set_raises_error(test_app, monkeypatch): + """Test that omitting app_name when ADK_DEFAULT_APP_NAME is not set raises 400/404.""" + # Make sure environment variable is NOT set + monkeypatch.delenv("ADK_DEFAULT_APP_NAME", raising=False) + + # 1. Accessing /users/{user_id}/sessions/{session_id} should return 404 because no rewrite happened + response = test_app.get("/users/test_user/sessions/test_session") + assert response.status_code == 404 + + # 2. Accessing /run with omitted app_name should return 400 + payload = { + "user_id": "test_user", + "session_id": "test_session", + "new_message": {"role": "user", "parts": [{"text": "Hello"}]}, + } + response = test_app.post("/run", json=payload) + assert response.status_code == 400 + assert "app_name is required" in response.json()["detail"] + + +def test_run_live_websocket_default_app_name( + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_agent_loader, + mock_eval_sets_manager, + mock_eval_set_results_manager, + monkeypatch, +): + """Test that /run_live websocket endpoint resolves app_name using ADK_DEFAULT_APP_NAME.""" + monkeypatch.setenv("ADK_DEFAULT_APP_NAME", "test_app") + + test_app = _create_test_client( + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_agent_loader, + mock_eval_sets_manager, + mock_eval_set_results_manager, + ) + + async def setup_session(): + await mock_session_service.create_session( + app_name="test_app", + user_id="user", + session_id="session", + state={}, + ) + + asyncio.run(setup_session()) + + url = "/run_live?user_id=user&session_id=session&modalities=AUDIO" + + with test_app.websocket_connect(url) as ws: + data = ws.receive_json() + assert data["author"] == "dummy agent" + + +def test_run_live_websocket_missing_app_name_raises_error( + test_app, monkeypatch +): + """Test that /run_live websocket connection fails when app_name and ADK_DEFAULT_APP_NAME are both missing.""" + from fastapi.websockets import WebSocketDisconnect + + monkeypatch.delenv("ADK_DEFAULT_APP_NAME", raising=False) + + url = "/run_live?user_id=user&session_id=session&modalities=AUDIO" + + with pytest.raises(WebSocketDisconnect) as exc_info: + with test_app.websocket_connect(url) as ws: + ws.receive_json() + assert exc_info.value.code == 1008 + + +def test_is_single_agent_directory(tmp_path): + """Verify that is_single_agent_directory only identifies directories with agent.py or root_agent.yaml.""" + from google.adk.cli.utils.agent_loader import is_single_agent_directory + + # Directory with agent.py (should be identified as agent) + agent_py_dir = tmp_path / "agent_py_dir" + agent_py_dir.mkdir() + (agent_py_dir / "agent.py").write_text("root_agent = 'dummy'") + assert is_single_agent_directory(str(agent_py_dir)) is True + + # Directory with root_agent.yaml (should be identified as agent) + yaml_dir = tmp_path / "yaml_dir" + yaml_dir.mkdir() + (yaml_dir / "root_agent.yaml").write_text("root_agent: dummy") + assert is_single_agent_directory(str(yaml_dir)) is True + + # Normal directory or standard package with __init__.py only (should NOT be identified as agent) + normal_pkg = tmp_path / "normal_pkg" + normal_pkg.mkdir() + (normal_pkg / "__init__.py").write_text( + "from .app import App\nimport something" + ) + assert is_single_agent_directory(str(normal_pkg)) is False + + +def test_agent_loader_single_agent_mode(tmp_path): + """Verify that AgentLoader automatically detects and configures single agent mode.""" + agent_folder = tmp_path / "my_test_agent" + agent_folder.mkdir() + (agent_folder / "agent.py").write_text("root_agent = 'dummy'") + + loader = fast_api_module.AgentLoader(str(agent_folder)) + + assert loader._is_single_agent is True + assert loader._single_agent_name == "my_test_agent" + assert loader.agents_dir == str(tmp_path) + assert loader.list_agents() == ["my_test_agent"] + + +def test_single_agent_mode_detection( + tmp_path, + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_eval_sets_manager, + mock_eval_set_results_manager, +): + """Verify that pointing agents_dir to a single agent folder enables single agent mode.""" + agent_folder = tmp_path / "my_only_agent" + agent_folder.mkdir() + (agent_folder / "agent.py").write_text("root_agent = None") + + with ( + patch.object(signal, "signal", autospec=True, return_value=None), + patch.object( + fast_api_module, + "create_session_service_from_options", + autospec=True, + return_value=mock_session_service, + ), + patch.object( + fast_api_module, + "create_artifact_service_from_options", + autospec=True, + return_value=mock_artifact_service, + ), + patch.object( + fast_api_module, + "create_memory_service_from_options", + autospec=True, + return_value=mock_memory_service, + ), + patch.object( + fast_api_module, + "LocalEvalSetsManager", + autospec=True, + return_value=mock_eval_sets_manager, + ), + patch.object( + fast_api_module, + "LocalEvalSetResultsManager", + autospec=True, + return_value=mock_eval_set_results_manager, + ), + ): + app = get_fast_api_app( + agents_dir=str(agent_folder), + web=True, + session_service_uri="", + artifact_service_uri="", + memory_service_uri="", + allow_origins=None, + a2a=False, + host="127.0.0.1", + port=8000, + ) + client = TestClient(app) + + response = client.get("/list-apps") + assert response.status_code == 200 + assert response.json() == ["my_only_agent"] + + +def test_single_agent_mode_sets_default_app( + tmp_path, + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_eval_sets_manager, + mock_eval_set_results_manager, + monkeypatch, +): + """Verify that in single agent mode, the agent is used as default app.""" + # Set environment variable to something else, but single mode should take precedence. + monkeypatch.setenv("ADK_DEFAULT_APP_NAME", "some_other_app") + + agent_folder = tmp_path / "my_only_agent" + agent_folder.mkdir() + (agent_folder / "agent.py").write_text("root_agent = None") + + # Setup session data in the in-memory service + async def setup_session(): + await mock_session_service.create_session( + app_name="my_only_agent", + user_id="test_user", + session_id="test_session", + state={}, + ) + + asyncio.run(setup_session()) + + with ( + patch.object(signal, "signal", autospec=True, return_value=None), + patch.object( + fast_api_module, + "create_session_service_from_options", + autospec=True, + return_value=mock_session_service, + ), + patch.object( + fast_api_module, + "create_artifact_service_from_options", + autospec=True, + return_value=mock_artifact_service, + ), + patch.object( + fast_api_module, + "create_memory_service_from_options", + autospec=True, + return_value=mock_memory_service, + ), + patch.object( + fast_api_module, + "LocalEvalSetsManager", + autospec=True, + return_value=mock_eval_sets_manager, + ), + patch.object( + fast_api_module, + "LocalEvalSetResultsManager", + autospec=True, + return_value=mock_eval_set_results_manager, + ), + ): + app = get_fast_api_app( + agents_dir=str(agent_folder), + web=True, + session_service_uri="", + artifact_service_uri="", + memory_service_uri="", + allow_origins=None, + a2a=False, + host="127.0.0.1", + port=8000, + ) + client = TestClient(app) + + # Accessing /users/{user_id}/sessions/{session_id} should work because of rewrite + response = client.get("/users/test_user/sessions/test_session") + assert response.status_code == 200 + assert response.json()["id"] == "test_session" + + +def test_agent_run_disconnect_aborts_run( + test_app, create_test_session, monkeypatch +): + """Test that /run endpoint aborts agent execution on client disconnect. + + Verifies that when the client connection is dropped during an active agent + run: + 1. The background agent execution generator task is cancelled. + 2. The endpoint returns a clean 499 (Client Closed Request) status code. + """ + import starlette.requests + + info = create_test_session + trigger_disconnect: dict[str, bool] = {"value": False} + was_cancelled: dict[str, bool] = {"value": False} + + async def run_async_mock( + self, + *, + user_id: str, + session_id: str, + invocation_id: Optional[str] = None, + new_message: Optional[types.Content] = None, + state_delta: Optional[dict[str, Any]] = None, + run_config: Optional[RunConfig] = None, + ): + del ( + self, + user_id, + session_id, + invocation_id, + new_message, + state_delta, + run_config, + ) + try: + # Yield first pulse event + yield _event_1() + # Simulate connection drop mid-run + trigger_disconnect["value"] = True + # Run a long async operation to allow the monitor to trigger cancellation + await asyncio.sleep(1.0) + yield _event_2() + except asyncio.CancelledError: + was_cancelled["value"] = True + raise + + monkeypatch.setattr(Runner, "run_async", run_async_mock) + + # Monkeypatch starlette.requests.Request.__init__ to inject simulated disconnect + original_init = starlette.requests.Request.__init__ + + def custom_init(self, *args, **kwargs): + original_init(self, *args, **kwargs) + original_receive = self._receive + call_count = 0 + + async def mock_receive(): + nonlocal call_count + call_count += 1 + if call_count == 1: + return await original_receive() + + # Subsequent calls block until simulated connection drop is triggered + while not trigger_disconnect["value"]: + await asyncio.sleep(0.01) + return {"type": "http.disconnect"} + + self._receive = mock_receive + self.__dict__["receive"] = mock_receive + + monkeypatch.setattr(starlette.requests.Request, "__init__", custom_init) + + payload = { + "app_name": info["app_name"], + "user_id": info["user_id"], + "session_id": info["session_id"], + "new_message": {"role": "user", "parts": [{"text": "Hello agent"}]}, + "streaming": False, + } + + # When standard /run POST request is initiated and mid-run connection drop occurs + response = test_app.post("/run", json=payload) + + # Then the response status should be 499 and the running generator was cancelled + assert response.status_code == 499 + assert was_cancelled["value"] is True + + ################################################# # Gemini Enterprise Tests ################################################# diff --git a/tests/unittests/cli/utils/test_cli_deploy.py b/tests/unittests/cli/utils/test_cli_deploy.py index b70c59dcaa..cae443099e 100644 --- a/tests/unittests/cli/utils/test_cli_deploy.py +++ b/tests/unittests/cli/utils/test_cli_deploy.py @@ -35,6 +35,8 @@ from google.adk.cli import cli_tools_click import pytest +import src.google.adk.cli.cli_deploy as cli_deploy + # Helpers class _Recorder: @@ -378,7 +380,7 @@ def mock_subprocess_run(*args, **kwargs): dockerfile_path = tmp_path / "Dockerfile" assert dockerfile_path.is_file() dockerfile_content = dockerfile_path.read_text() - assert "CMD adk web --port=9090" in dockerfile_content + assert "CMD adk api_server --with_ui --port=9090" in dockerfile_content assert "RUN pip install google-adk==1.2.0" in dockerfile_content assert len(run_recorder.calls) == 3, "Expected 3 subprocess calls" diff --git a/tests/unittests/telemetry/test_google_cloud.py b/tests/unittests/telemetry/test_google_cloud.py index c75397449d..7559f9f201 100644 --- a/tests/unittests/telemetry/test_google_cloud.py +++ b/tests/unittests/telemetry/test_google_cloud.py @@ -52,6 +52,18 @@ def test_get_gcp_exporters( "google.auth.default", auth_mock, ) + monkeypatch.setattr( + "google.adk.telemetry.google_cloud._get_gcp_span_exporter", + lambda credentials: mock.MagicMock(), + ) + monkeypatch.setattr( + "google.adk.telemetry.google_cloud._get_gcp_metrics_exporter", + lambda project_id: mock.MagicMock(), + ) + monkeypatch.setattr( + "google.adk.telemetry.google_cloud._get_gcp_logs_exporter", + lambda project_id: mock.MagicMock(), + ) # Act. otel_hooks = get_gcp_exporters( From d1aef6090fe21ee294f6368988bc1d5bd7485348 Mon Sep 17 00:00:00 2001 From: Yee Sian Ng Date: Tue, 2 Jun 2026 21:30:26 +0000 Subject: [PATCH 10/18] chore: Update agent_loader --- src/google/adk/cli/utils/agent_loader.py | 48 +++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/src/google/adk/cli/utils/agent_loader.py b/src/google/adk/cli/utils/agent_loader.py index d4bbfc88f6..c1cb2d03e8 100644 --- a/src/google/adk/cli/utils/agent_loader.py +++ b/src/google/adk/cli/utils/agent_loader.py @@ -39,6 +39,16 @@ logger = logging.getLogger("google_adk." + __name__) + +def is_single_agent_directory(path: Path | str) -> bool: + """Returns True if the directory contains a single agent configuration or file.""" + p = Path(path).resolve() + return ( + p.joinpath("agent.py").is_file() + or p.joinpath("root_agent.yaml").is_file() + ) + + # Special agents directory for agents with names starting with double underscore SPECIAL_AGENTS_DIR = os.path.join( os.path.dirname(__file__), "..", "built_in_agents" @@ -60,10 +70,36 @@ class AgentLoader(BaseAgentLoader): """ def __init__(self, agents_dir: str): - self.agents_dir = str(Path(agents_dir)) + agents_path = Path(agents_dir).resolve() + is_single_agent = is_single_agent_directory(agents_path) + if is_single_agent: + self._is_single_agent = True + self._single_agent_name = agents_path.name + self.agents_dir = str(agents_path.parent) + else: + self._is_single_agent = False + self._single_agent_name = None + self.agents_dir = str(agents_path) + self._original_sys_path = None self._agent_cache: dict[str, Union[BaseAgent, App]] = {} + @property + def is_single_agent(self) -> bool: + """Returns True if the loader is in single agent mode.""" + return self._is_single_agent + + @property + def single_agent_name(self) -> Optional[str]: + """Returns the name of the agent in single agent mode.""" + return self._single_agent_name + + def _set_single_agent_mode(self, name: str, agents_dir: str) -> None: + """Internal method to force single agent mode. Use with care.""" + self._is_single_agent = True + self._single_agent_name = name + self.agents_dir = agents_dir + def _load_from_module_or_package( self, agent_name: str ) -> Optional[Union[BaseAgent, App]]: @@ -204,6 +240,13 @@ def _validate_agent_name(self, agent_name: str) -> None: name_to_check = agent_name check_dir = self.agents_dir + if self._is_single_agent and not agent_name.startswith("__"): + if agent_name != self._single_agent_name: + raise ValueError( + f"Agent not found: {agent_name!r}. In single agent mode, only " + f"'{self._single_agent_name}' is accessible." + ) + if not self._VALID_AGENT_NAME_RE.match(name_to_check): raise ValueError( f"Invalid agent name: {agent_name!r}. Agent names must be valid" @@ -368,6 +411,8 @@ def load_agent(self, agent_name: str) -> Union[BaseAgent, App]: @override def list_agents(self) -> list[str]: """Lists all agents available in the agent loader (sorted alphabetically).""" + if self._is_single_agent: + return [self._single_agent_name] base_path = Path.cwd() / self.agents_dir agent_names = [ x @@ -439,3 +484,4 @@ def remove_agent_from_cache(self, agent_name: str): logger.debug("Deleting module %s", key) del sys.modules[key] self._agent_cache.pop(agent_name, None) + \ No newline at end of file From 2b06d4c95ff2d232f41f40242ee73d9c4a353f73 Mon Sep 17 00:00:00 2001 From: Yee Sian Ng Date: Wed, 3 Jun 2026 14:37:19 +0000 Subject: [PATCH 11/18] chore: relax the validation logic --- src/google/adk/cli/fast_api.py | 47 ++++++++++++++++++++++++---------- 1 file changed, 34 insertions(+), 13 deletions(-) diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 5186a8dd73..394057b5ec 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -746,6 +746,7 @@ async def _get_a2a_runner_async() -> Runner: import inspect import json + from pydantic import ValidationError as _ValidationError from google.adk.agents import Agent import google.auth @@ -842,22 +843,32 @@ async def context_propagation( response_model_exclude_none=True, response_class=JSONResponse, ) - async def query(request: _QueryRequest): + async def query(request: Request): + try: + body = await request.json() + except json.JSONDecodeError as exc: + raise HTTPException( + status_code=400, detail=f"Invalid JSON: {exc}" + ) + try: + parsed = _QueryRequest.model_validate(body) + except _ValidationError as exc: + raise HTTPException(status_code=400, detail=exc.errors()) if not adk_app._tmpl_attrs.get("runner"): adk_app._tmpl_attrs["runner"] = await adk_web_server.get_runner_async( app_name=gemini_enterprise_app_name ) - if request.class_method is None: + if parsed.class_method is None: raise HTTPException( status_code=400, detail="class_method cannot be None" ) - if request.class_method not in _ALLOWED_AGENT_ENGINE_CLASS_METHODS: + if parsed.class_method not in _ALLOWED_AGENT_ENGINE_CLASS_METHODS: raise HTTPException( status_code=400, - detail=f"class_method {request.class_method} is not allowed", + detail=f"class_method {parsed.class_method} is not allowed", ) - method = getattr(adk_app, request.class_method) - output = await _invoke_callable_or_raise(method, request.input or {}) + method = getattr(adk_app, parsed.class_method) + output = await _invoke_callable_or_raise(method, parsed.input or {}) try: json_serialized_content = jsonable_encoder({"output": output}) @@ -865,7 +876,7 @@ async def query(request: _QueryRequest): logging.exception( "FastAPI could not JSON-encode the response from invocation method" " %s. Error: %s. Invocation method's original response: %r", - request.class_method, + parsed.class_method, encoding_error, output, ) @@ -877,22 +888,32 @@ async def query(request: _QueryRequest): response_model_exclude_none=True, response_class=StreamingResponse, ) - async def stream_query(request: _QueryRequest): + async def stream_query(request: Request): + try: + body = await request.json() + except json.JSONDecodeError as exc: + raise HTTPException( + status_code=400, detail=f"Invalid JSON: {exc}" + ) + try: + parsed = _QueryRequest.model_validate(body) + except _ValidationError as exc: + raise HTTPException(status_code=400, detail=exc.errors()) if not adk_app._tmpl_attrs.get("runner"): adk_app._tmpl_attrs["runner"] = await adk_web_server.get_runner_async( app_name=gemini_enterprise_app_name ) - if request.class_method is None: + if parsed.class_method is None: raise HTTPException( status_code=400, detail="class_method cannot be None" ) - if request.class_method not in _ALLOWED_AGENT_ENGINE_CLASS_METHODS: + if parsed.class_method not in _ALLOWED_AGENT_ENGINE_CLASS_METHODS: raise HTTPException( status_code=400, - detail=f"class_method {request.class_method} is not allowed", + detail=f"class_method {parsed.class_method} is not allowed", ) - method = getattr(adk_app, request.class_method) - output = await _invoke_callable_or_raise(method, request.input or {}) + method = getattr(adk_app, parsed.class_method) + output = await _invoke_callable_or_raise(method, parsed.input or {}) if inspect.isgenerator(output): From a2d2e39e527dcb8ce4c307ede37e97c2d842dc18 Mon Sep 17 00:00:00 2001 From: Yee Sian Ng Date: Wed, 3 Jun 2026 18:24:04 +0000 Subject: [PATCH 12/18] chore: support deploying with express mode api key --- src/google/adk/cli/cli_deploy.py | 34 ++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/src/google/adk/cli/cli_deploy.py b/src/google/adk/cli/cli_deploy.py index 4927be385d..c600a12c91 100644 --- a/src/google/adk/cli/cli_deploy.py +++ b/src/google/adk/cli/cli_deploy.py @@ -838,7 +838,7 @@ def to_agent_engine( artifact_service_uri: Optional[str] = None, adk_version: Optional[str] = None, ): - """Deploys an agent to Vertex AI Agent Engine. + """Deploys an agent to Agent Platform Runtime. `agent_folder` should contain the following files: @@ -1104,19 +1104,33 @@ def to_agent_engine( from ..utils._google_client_headers import get_tracking_headers - if not project or not region: - click.echo('No project/region provided. Starting onboarding flow...') + if not (api_key or project or region): + click.echo( + 'No apikey/project/region provided. Starting onboarding flow...' + ) auth_info = _onboarding.handle_login_with_google() project = auth_info.project_id region = auth_info.region - click.echo('Initializing Vertex AI...') - client = vertexai.Client( - project=project, - location=region, - http_options={'headers': get_tracking_headers()}, - ) - click.echo('Vertex AI initialized.') + click.echo('Initializing Agent Platform client...') + if project and region: + client = vertexai.Client( + project=project, + location=region, + http_options={'headers': get_tracking_headers()}, + ) + elif api_key: + client = vertexai.Client( + api_key=api_key, + http_options={'headers': get_tracking_headers()}, + ) + else: + click.echo( + 'Failed to initialize Agent Platform client. Please provide an API' + 'key or project and region.' + ) + return + click.echo('Agent Platform client initialized.') if skip_agent_import_validation: warnings.warn( From e34ed0751cf098d7961e0f9ff97867ca0db38dfc Mon Sep 17 00:00:00 2001 From: Yee Sian Ng Date: Wed, 3 Jun 2026 19:32:22 +0000 Subject: [PATCH 13/18] fix: api_key name --- src/google/adk/cli/fast_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 394057b5ec..ef43e86b6b 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -763,7 +763,7 @@ async def _get_a2a_runner_async() -> Runner: ) adk_app._tmpl_attrs["project"] = None adk_app._tmpl_attrs["location"] = None - adk_app._tmpl_attrs["api_key"] = api_key + adk_app._tmpl_attrs["express_mode_api_key"] = api_key else: _, project_id = google.auth.default() location = os.environ.get( @@ -777,7 +777,7 @@ async def _get_a2a_runner_async() -> Runner: ) adk_app._tmpl_attrs["project"] = project_id adk_app._tmpl_attrs["location"] = location - adk_app._tmpl_attrs["api_key"] = None + adk_app._tmpl_attrs["express_mode_api_key"] = None adk_app._tmpl_attrs["runner"] = None adk_app._tmpl_attrs["app_name"] = gemini_enterprise_app_name adk_app._tmpl_attrs["session_service"] = session_service From d95271bcd710ad5a4566b199d37ef9b58187dd20 Mon Sep 17 00:00:00 2001 From: Yee Sian Ng Date: Wed, 3 Jun 2026 21:14:55 +0000 Subject: [PATCH 14/18] chore: rebrand to agent platform --- src/google/adk/cli/cli_deploy.py | 85 ++++++++----------- src/google/adk/cli/fast_api.py | 12 +-- src/google/adk/telemetry/google_cloud.py | 6 +- tests/unittests/cli/utils/test_cli_deploy.py | 2 +- .../unittests/telemetry/test_google_cloud.py | 2 +- 5 files changed, 44 insertions(+), 63 deletions(-) diff --git a/src/google/adk/cli/cli_deploy.py b/src/google/adk/cli/cli_deploy.py index c600a12c91..d685e4199b 100644 --- a/src/google/adk/cli/cli_deploy.py +++ b/src/google/adk/cli/cli_deploy.py @@ -41,7 +41,7 @@ def _ensure_agent_engine_dependency(requirements_txt_path: str) -> None: - """Ensures staged requirements include Agent Engine dependencies.""" + """Ensures staged requirements include Agent Platform dependencies.""" if not os.path.exists(requirements_txt_path): raise FileNotFoundError( f'requirements.txt not found at: {requirements_txt_path}' @@ -71,11 +71,6 @@ def _ensure_agent_engine_dependency(requirements_txt_path: str) -> None: FROM python:3.11-slim WORKDIR /app -RUN apt-get update && \ - apt-get upgrade -y && \ - apt-get install -y git && \ - apt -y autoremove - # Create a non-root user RUN adduser --disabled-password --gecos "" myuser @@ -92,7 +87,7 @@ def _ensure_agent_engine_dependency(requirements_txt_path: str) -> None: # Set up environment variables - End # Install ADK - Start -# RUN pip install google-adk=={adk_version} +RUN pip install google-adk=={adk_version} # Install ADK - End # Copy agent - Start @@ -468,7 +463,7 @@ def _validate_agent_import( This pre-deployment validation catches common issues like missing dependencies or import errors in custom BaseLlm implementations before - the agent is deployed to Agent Engine. This provides clearer error + the agent is deployed to Agent Platform. This provides clearer error messages and prevents deployments that would fail at runtime. Args: @@ -838,7 +833,7 @@ def to_agent_engine( artifact_service_uri: Optional[str] = None, adk_version: Optional[str] = None, ): - """Deploys an agent to Agent Platform Runtime. + """Deploys an agent to Gemini Enterprise Agent Platform. `agent_folder` should contain the following files: @@ -851,7 +846,7 @@ def to_agent_engine( Args: agent_folder (str): The folder (absolute path) containing the agent source code. - temp_folder (str): The temp folder for the generated Agent Engine source + temp_folder (str): The temp folder for the generated Agent Platform source files. It will be replaced with the generated files if it already exists. adk_app (str): Deprecated. This argument is no longer required or used. staging_bucket (str): Deprecated. This argument is no longer required or @@ -865,8 +860,8 @@ def to_agent_engine( will be used. It will only be used if GOOGLE_GENAI_USE_VERTEXAI is true. adk_app_object (str): Deprecated. This argument is no longer required or used. - agent_engine_id (str): Optional. The ID of the Agent Engine instance to - update. If not specified, a new Agent Engine instance will be created. + agent_engine_id (str): Optional. The ID of the Agent Runtime instance to + update. If not specified, a new Agent Runtime instance will be created. absolutize_imports (bool): Deprecated. This argument is no longer required or used. project (str): Optional. Google Cloud project id for the deployed agent. If @@ -875,15 +870,15 @@ def to_agent_engine( region (str): Optional. Google Cloud region for the deployed agent. If not specified, the region from the `GOOGLE_CLOUD_LOCATION` environment variable will be used. It will be ignored if `api_key` is specified. - display_name (str): Optional. The display name of the Agent Engine. - description (str): Optional. The description of the Agent Engine. + display_name (str): Optional. The display name of the Agent Runtime. + description (str): Optional. The description of the Agent Runtime. requirements_file (str): Deprecated. This argument is no longer required or used. env_file (str): Optional. The filepath to the `.env` file for environment variables. If not specified, the `.env` file in the `agent_folder` will be used. The values of `GOOGLE_CLOUD_PROJECT` and `GOOGLE_CLOUD_LOCATION` will be overridden by `project` and `region` if they are specified. - agent_engine_config_file (str): The filepath to the agent engine config file + agent_engine_config_file (str): The filepath to the agent platform config file to use. If not specified, the `.agent_engine_config.json` file in the `agent_folder` will be used. skip_agent_import_validation (bool): Deprecated. This argument is no longer @@ -898,7 +893,7 @@ def to_agent_engine( specified, the session service will be deployed to the same parent resource as the runtime. artifact_service_uri (str): Optional. The URI of the artifact service. - adk_version (str): Optional. The ADK version to use in Agent Engine + adk_version (str): Optional. The ADK version to use in Agent Platform deployment. If not specified, the version in the dev environment will be used. """ @@ -934,7 +929,7 @@ def to_agent_engine( did_change_cwd = False if parent_folder != original_cwd: click.echo( - 'Agent Engine deployment uses relative paths; temporarily switching ' + 'Agent Runtime deployment uses relative paths; temporarily switching ' f'working directory to: {parent_folder}' ) os.chdir(parent_folder) @@ -975,29 +970,33 @@ def to_agent_engine( agent_engine_config_file ): raise click.ClickException( - 'Agent engine config file not found: ' + 'Agent Platform config file not found: ' f'{parent_folder}/{agent_engine_config_file}' ) if not agent_engine_config_file: - # Attempt to read the agent engine config from .agent_engine_config.json in the dir (if any). + # Attempt to read the agent platform config from .agent_engine_config.json + # in the dir (if any). agent_engine_config_file = os.path.join( agent_folder, '.agent_engine_config.json' ) if os.path.exists(agent_engine_config_file): - click.echo(f'Reading agent engine config from {agent_engine_config_file}') + click.echo( + f'Reading agent platform config from {agent_engine_config_file}' + ) with open(agent_engine_config_file, 'r') as f: agent_config = json.load(f) if display_name: if 'display_name' in agent_config: click.echo( - 'Overriding display_name in agent engine config with' + 'Overriding display_name in agent platform config with' f' {display_name}' ) agent_config['display_name'] = display_name if description: if 'description' in agent_config: click.echo( - f'Overriding description in agent engine config with {description}' + 'Overriding description in agent platform config with' + f' {description}' ) agent_config['description'] = description @@ -1094,7 +1093,7 @@ def to_agent_engine( if env_vars: if 'env_vars' in agent_config: click.echo( - f'Overriding env_vars in agent engine config with {env_vars}' + f'Overriding env_vars in agent platform config with {env_vars}' ) agent_config['env_vars'] = env_vars # Set env_vars in agent_config to None if it is not set. @@ -1104,32 +1103,18 @@ def to_agent_engine( from ..utils._google_client_headers import get_tracking_headers - if not (api_key or project or region): - click.echo( - 'No apikey/project/region provided. Starting onboarding flow...' - ) + if not project or not region: + click.echo('No project/region provided. Starting onboarding flow...') auth_info = _onboarding.handle_login_with_google() project = auth_info.project_id region = auth_info.region click.echo('Initializing Agent Platform client...') - if project and region: - client = vertexai.Client( - project=project, - location=region, - http_options={'headers': get_tracking_headers()}, - ) - elif api_key: - client = vertexai.Client( - api_key=api_key, - http_options={'headers': get_tracking_headers()}, - ) - else: - click.echo( - 'Failed to initialize Agent Platform client. Please provide an API' - 'key or project and region.' - ) - return + client = vertexai.Client( + project=project, + location=region, + http_options={'headers': get_tracking_headers()}, + ) click.echo('Agent Platform client initialized.') if skip_agent_import_validation: @@ -1187,7 +1172,7 @@ def create_dockerfile_for_agent_engine(resource_name: str): DeprecationWarning, stacklevel=2, ) - click.echo('Deploying to agent engine...') + click.echo('Deploying to Agent Platform...') agent_config['source_packages'] = [f'agents/{app_name}', 'Dockerfile'] agent_config['image_spec'] = {} # Use the Dockerfile agent_config['class_methods'] = _AGENT_ENGINE_CLASS_METHODS @@ -1197,7 +1182,7 @@ def create_dockerfile_for_agent_engine(resource_name: str): if not resource_name: agent_engine = client.agent_engines.create() resource_name = agent_engine.api_resource.name - click.secho(f'Created a new agent engine: {resource_name}', fg='green') + click.secho(f'Created a new instance: {resource_name}', fg='green') elif project and region and not resource_name.startswith('projects/'): resource_name = f'projects/{project}/locations/{region}/reasoningEngines/{agent_engine_id}' click.echo('Creating Dockerfile...') @@ -1205,13 +1190,13 @@ def create_dockerfile_for_agent_engine(resource_name: str): click.echo(f'Dockerfile created at {os.getcwd()}/Dockerfile.') try: client.agent_engines.update(name=resource_name, config=agent_config) - click.secho(f'Deployed to agent engine: {resource_name}', fg='green') + click.secho(f'Deployed to Agent Platform: {resource_name}', fg='green') except Exception as e: - click.secho(f'Failed to deploy to agent engine: {e}', fg='red') - # Only delete the agent engine if it was newly created in this function. + click.secho(f'Failed to deploy to Agent Platform: {e}', fg='red') + # Only delete the instance if it was newly created in this function. if agent_engine_id is None: client.agent_engines.delete(name=resource_name) - click.secho(f'Cleaned up the agent engine: {resource_name}', fg='green') + click.secho(f'Cleaned up the instance: {resource_name}', fg='green') raise e _print_agent_engine_url(resource_name) finally: diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index ef43e86b6b..9d66a07b92 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -54,8 +54,8 @@ from ..telemetry._agent_engine import get_propagated_context from ..telemetry._agent_engine import TopSpanProcessor from .api_server import ApiServer -from .dev_server import DevServer from .cli_deploy import _AGENT_ENGINE_CLASS_METHODS +from .dev_server import DevServer from .service_registry import load_services_module from .utils import envs from .utils.agent_change_handler import AgentChangeEventHandler @@ -746,10 +746,10 @@ async def _get_a2a_runner_async() -> Runner: import inspect import json - from pydantic import ValidationError as _ValidationError from google.adk.agents import Agent import google.auth + from pydantic import ValidationError as _ValidationError from vertexai import agent_engines # The tmp agent will be replaced by the adk server's runner and services. @@ -847,9 +847,7 @@ async def query(request: Request): try: body = await request.json() except json.JSONDecodeError as exc: - raise HTTPException( - status_code=400, detail=f"Invalid JSON: {exc}" - ) + raise HTTPException(status_code=400, detail=f"Invalid JSON: {exc}") try: parsed = _QueryRequest.model_validate(body) except _ValidationError as exc: @@ -892,9 +890,7 @@ async def stream_query(request: Request): try: body = await request.json() except json.JSONDecodeError as exc: - raise HTTPException( - status_code=400, detail=f"Invalid JSON: {exc}" - ) + raise HTTPException(status_code=400, detail=f"Invalid JSON: {exc}") try: parsed = _QueryRequest.model_validate(body) except _ValidationError as exc: diff --git a/src/google/adk/telemetry/google_cloud.py b/src/google/adk/telemetry/google_cloud.py index c936105cc7..b7500906a4 100644 --- a/src/google/adk/telemetry/google_cloud.py +++ b/src/google/adk/telemetry/google_cloud.py @@ -252,9 +252,9 @@ def get_gcp_resource(project_id: Optional[str] = None) -> Resource: resource_attributes["cloud.resource.id"] = cloud_resource_id if agent_engine_id: - resource = Resource.create( - attributes=resource_attributes - ).merge(OTELResourceDetector().detect()) + resource = Resource.create(attributes=resource_attributes).merge( + OTELResourceDetector().detect() + ) return resource resource = Resource( diff --git a/tests/unittests/cli/utils/test_cli_deploy.py b/tests/unittests/cli/utils/test_cli_deploy.py index cae443099e..da2e4f6163 100644 --- a/tests/unittests/cli/utils/test_cli_deploy.py +++ b/tests/unittests/cli/utils/test_cli_deploy.py @@ -329,7 +329,7 @@ def test_to_agent_engine_raises_when_explicit_config_file_missing( adk_version="1.2.0", ) - assert "Agent engine config file not found" in str(exc_info.value) + assert "Agent Platform config file not found" in str(exc_info.value) assert expected_abs in str(exc_info.value) diff --git a/tests/unittests/telemetry/test_google_cloud.py b/tests/unittests/telemetry/test_google_cloud.py index 7559f9f201..8b57ac9dd1 100644 --- a/tests/unittests/telemetry/test_google_cloud.py +++ b/tests/unittests/telemetry/test_google_cloud.py @@ -62,7 +62,7 @@ def test_get_gcp_exporters( ) monkeypatch.setattr( "google.adk.telemetry.google_cloud._get_gcp_logs_exporter", - lambda project_id: mock.MagicMock(), + lambda project_id, credentials: mock.MagicMock(), ) # Act. From 69b32611a42324b542ed213d2a1393dba0ccfaa4 Mon Sep 17 00:00:00 2001 From: Yee Sian Ng Date: Wed, 3 Jun 2026 21:20:10 +0000 Subject: [PATCH 15/18] chore: final touches for e2e testing --- src/google/adk/cli/cli_deploy.py | 38 ++++++++++++++++++++++++-------- 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/src/google/adk/cli/cli_deploy.py b/src/google/adk/cli/cli_deploy.py index d685e4199b..03904ffaa3 100644 --- a/src/google/adk/cli/cli_deploy.py +++ b/src/google/adk/cli/cli_deploy.py @@ -71,6 +71,11 @@ def _ensure_agent_engine_dependency(requirements_txt_path: str) -> None: FROM python:3.11-slim WORKDIR /app +RUN apt-get update && \ + apt-get upgrade -y && \ + apt-get install -y git && \ + apt -y autoremove + # Create a non-root user RUN adduser --disabled-password --gecos "" myuser @@ -87,7 +92,7 @@ def _ensure_agent_engine_dependency(requirements_txt_path: str) -> None: # Set up environment variables - End # Install ADK - Start -RUN pip install google-adk=={adk_version} +# RUN pip install google-adk=={adk_version} # Install ADK - End # Copy agent - Start @@ -1103,19 +1108,34 @@ def to_agent_engine( from ..utils._google_client_headers import get_tracking_headers - if not project or not region: - click.echo('No project/region provided. Starting onboarding flow...') + if not (api_key or project or region): + click.echo( + 'No api_key/project/region provided. Starting onboarding flow...' + ) auth_info = _onboarding.handle_login_with_google() project = auth_info.project_id region = auth_info.region click.echo('Initializing Agent Platform client...') - client = vertexai.Client( - project=project, - location=region, - http_options={'headers': get_tracking_headers()}, - ) - click.echo('Agent Platform client initialized.') + if project and region: + client = vertexai.Client( + project=project, + location=region, + http_options={'headers': get_tracking_headers()}, + ) + click.echo('Agent Platform client initialized with project and region.') + elif api_key: + client = vertexai.Client( + api_key=api_key, + http_options={'headers': get_tracking_headers()}, + ) + click.echo('Agent Platform client initialized with ExpressMode API Key.') + else: + click.echo( + 'Failed to initialize Agent Platform client. Please provide an API' + 'key or project and region.' + ) + return if skip_agent_import_validation: warnings.warn( From af9ac74a62258e2994525740b67859c1e146e804 Mon Sep 17 00:00:00 2001 From: Yee Sian Ng Date: Wed, 3 Jun 2026 22:07:35 +0000 Subject: [PATCH 16/18] chore: final touches for review --- src/google/adk/cli/cli_deploy.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/google/adk/cli/cli_deploy.py b/src/google/adk/cli/cli_deploy.py index 03904ffaa3..3301ca4431 100644 --- a/src/google/adk/cli/cli_deploy.py +++ b/src/google/adk/cli/cli_deploy.py @@ -71,11 +71,6 @@ def _ensure_agent_engine_dependency(requirements_txt_path: str) -> None: FROM python:3.11-slim WORKDIR /app -RUN apt-get update && \ - apt-get upgrade -y && \ - apt-get install -y git && \ - apt -y autoremove - # Create a non-root user RUN adduser --disabled-password --gecos "" myuser @@ -92,7 +87,7 @@ def _ensure_agent_engine_dependency(requirements_txt_path: str) -> None: # Set up environment variables - End # Install ADK - Start -# RUN pip install google-adk=={adk_version} +RUN pip install google-adk=={adk_version} # Install ADK - End # Copy agent - Start @@ -942,7 +937,6 @@ def to_agent_engine( tmp_app_name = app_name + '_tmp' + datetime.now().strftime('%Y%m%d_%H%M%S') temp_folder = temp_folder or tmp_app_name agent_src_path = os.path.join(parent_folder, temp_folder, 'agents', app_name) - click.echo(f'Staging all files in: {agent_src_path}') # remove agent_src_path if it exists if os.path.exists(agent_src_path): click.echo('Removing existing files') From c65a2ae25bc4d0ee761685efd6d8711f0647cf13 Mon Sep 17 00:00:00 2001 From: Yee Sian Ng Date: Wed, 3 Jun 2026 22:10:51 +0000 Subject: [PATCH 17/18] chore: revert all changes to adk web server --- src/google/adk/cli/adk_web_server.py | 2228 +------------------------- 1 file changed, 10 insertions(+), 2218 deletions(-) diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py index 5b4cdc2233..b567ce949d 100644 --- a/src/google/adk/cli/adk_web_server.py +++ b/src/google/adk/cli/adk_web_server.py @@ -14,2230 +14,22 @@ from __future__ import annotations -import asyncio -from contextlib import asynccontextmanager -import importlib -import json import logging -import os -import re -import sys -import time -import traceback -import typing -from typing import Any -from typing import Callable -from typing import List -from typing import Literal -from typing import Optional -from fastapi import FastAPI -from fastapi import HTTPException -from fastapi import Query -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import RedirectResponse -from fastapi.responses import StreamingResponse -from fastapi.staticfiles import StaticFiles -from fastapi.websockets import WebSocket -from fastapi.websockets import WebSocketDisconnect -from google.genai import types -import graphviz -from opentelemetry import trace -import opentelemetry.sdk.environment_variables as otel_env -from opentelemetry.sdk.trace import export as export_lib -from opentelemetry.sdk.trace import ReadableSpan -from opentelemetry.sdk.trace import SpanProcessor -from opentelemetry.sdk.trace import TracerProvider -from pydantic import Field -from pydantic import ValidationError -from starlette.types import Lifespan from typing_extensions import deprecated -from typing_extensions import override -from watchdog.observers import Observer -import yaml -from . import agent_graph -from ..agents.base_agent import BaseAgent -from ..agents.live_request_queue import LiveRequest -from ..agents.live_request_queue import LiveRequestQueue -from ..agents.llm_agent import LlmAgent -from ..agents.run_config import RunConfig -from ..agents.run_config import StreamingMode -from ..apps.app import App -from ..artifacts.base_artifact_service import ArtifactVersion -from ..artifacts.base_artifact_service import BaseArtifactService -from ..auth.credential_service.base_credential_service import BaseCredentialService -from ..errors.already_exists_error import AlreadyExistsError -from ..errors.input_validation_error import InputValidationError -from ..errors.not_found_error import NotFoundError -from ..errors.session_not_found_error import SessionNotFoundError -from ..evaluation.base_eval_service import InferenceConfig -from ..evaluation.base_eval_service import InferenceRequest -from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE -from ..evaluation.eval_case import EvalCase -from ..evaluation.eval_case import SessionInput -from ..evaluation.eval_metrics import EvalMetric -from ..evaluation.eval_metrics import EvalMetricResult -from ..evaluation.eval_metrics import EvalMetricResultPerInvocation -from ..evaluation.eval_metrics import EvalStatus -from ..evaluation.eval_metrics import MetricInfo -from ..evaluation.eval_result import EvalSetResult -from ..evaluation.eval_set import EvalSet -from ..evaluation.eval_set_results_manager import EvalSetResultsManager -from ..evaluation.eval_sets_manager import EvalSetsManager -from ..events.event import Event -from ..memory.base_memory_service import BaseMemoryService -from ..plugins.base_plugin import BasePlugin -from ..runners import Runner -from ..sessions.base_session_service import BaseSessionService -from ..sessions.session import Session -from ..utils.agent_info import AgentInfo -from ..utils.agent_info import get_agents_dict -from ..utils.context_utils import Aclosing -from ..utils.feature_decorator import experimental -from ..version import __version__ -from .cli_eval import EVAL_SESSION_ID_PREFIX -from .utils import cleanup -from .utils import common -from .utils import envs -from .utils import evals -from .utils.base_agent_loader import BaseAgentLoader -from .utils.shared_value import SharedValue -from .utils.state import create_empty_state +from .api_server import _parse_cors_origins +from .api_server import RunAgentRequest +from .dev_server import DevServer logger = logging.getLogger("google_adk." + __name__) -_EVAL_SET_FILE_EXTENSION = ".evalset.json" -TAG_DEBUG = "Debug" -TAG_EVALUATION = "Evaluation" +@deprecated( + "AdkWebServer is deprecated and has been refactored into ApiServer and" + " DevServer. Use DevServer instead." +) +class AdkWebServer(DevServer): + """Deprecated wrapper class around DevServer for backward compatibility.""" -_REGEX_PREFIX = "regex:" - - -def _parse_cors_origins( - allow_origins: list[str], -) -> tuple[list[str], Optional[str]]: - """Parse allow_origins into literal origins and a combined regex pattern. - - Args: - allow_origins: List of origin strings. Entries prefixed with 'regex:' are - treated as regex patterns; all others are treated as literal origins. - - Returns: - A tuple of (literal_origins, combined_regex) where combined_regex is None - if no regex patterns were provided, or a single pattern joining all regex - patterns with '|'. - """ - literal_origins = [] - regex_patterns = [] - for origin in allow_origins: - if origin.startswith(_REGEX_PREFIX): - pattern = origin[len(_REGEX_PREFIX) :] - if pattern: - regex_patterns.append(pattern) - else: - literal_origins.append(origin) - - combined_regex = "|".join(regex_patterns) if regex_patterns else None - return literal_origins, combined_regex - - -def _is_origin_allowed( - origin: str, - allowed_literal_origins: list[str], - allowed_origin_regex: Optional[re.Pattern[str]], -) -> bool: - """Check whether the given origin matches the allowed origins.""" - if "*" in allowed_literal_origins: - return True - if origin in allowed_literal_origins: - return True - if allowed_origin_regex is not None: - return allowed_origin_regex.fullmatch(origin) is not None - return False - - -def _normalize_origin_scheme(scheme: str) -> str: - """Normalize request schemes to the browser Origin scheme space.""" - if scheme == "ws": - return "http" - if scheme == "wss": - return "https" - return scheme - - -def _strip_optional_quotes(value: str) -> str: - """Strip a single pair of wrapping quotes from a header value.""" - if len(value) >= 2 and value[0] == '"' and value[-1] == '"': - return value[1:-1] - return value - - -def _get_scope_header( - scope: dict[str, Any], header_name: bytes -) -> Optional[str]: - """Return the first matching header value from an ASGI scope.""" - for candidate_name, candidate_value in scope.get("headers", []): - if candidate_name == header_name: - return candidate_value.decode("latin-1").split(",", 1)[0].strip() - return None - - -def _get_request_origin(scope: dict[str, Any]) -> Optional[str]: - """Compute the effective origin for the current HTTP/WebSocket request.""" - forwarded = _get_scope_header(scope, b"forwarded") - if forwarded is not None: - proto = None - host = None - for element in forwarded.split(",", 1)[0].split(";"): - if "=" not in element: - continue - name, value = element.split("=", 1) - if name.strip().lower() == "proto": - proto = _strip_optional_quotes(value.strip()) - elif name.strip().lower() == "host": - host = _strip_optional_quotes(value.strip()) - if proto is not None and host is not None: - return f"{_normalize_origin_scheme(proto)}://{host}" - - host = _get_scope_header(scope, b"x-forwarded-host") - if host is None: - host = _get_scope_header(scope, b"host") - if host is None: - return None - - proto = _get_scope_header(scope, b"x-forwarded-proto") - if proto is None: - proto = scope.get("scheme", "http") - return f"{_normalize_origin_scheme(proto)}://{host}" - - -def _is_request_origin_allowed( - origin: str, - scope: dict[str, Any], - allowed_literal_origins: list[str], - allowed_origin_regex: Optional[re.Pattern[str]], - has_configured_allowed_origins: bool, -) -> bool: - """Validate an Origin header against explicit config or same-origin.""" - if has_configured_allowed_origins and _is_origin_allowed( - origin, allowed_literal_origins, allowed_origin_regex - ): - return True - - request_origin = _get_request_origin(scope) - if request_origin is None: - return False - return origin == request_origin - - -_SAFE_HTTP_METHODS = frozenset({"GET", "HEAD", "OPTIONS"}) - - -class _OriginCheckMiddleware: - """ASGI middleware that blocks cross-origin state-changing requests.""" - - def __init__( - self, - app: Any, - has_configured_allowed_origins: bool, - allowed_origins: list[str], - allowed_origin_regex: Optional[re.Pattern[str]], - ) -> None: - self._app = app - self._has_configured_allowed_origins = has_configured_allowed_origins - self._allowed_origins = allowed_origins - self._allowed_origin_regex = allowed_origin_regex - - async def __call__( - self, - scope: dict[str, Any], - receive: Any, - send: Any, - ) -> None: - if scope["type"] != "http": - await self._app(scope, receive, send) - return - - method = scope.get("method", "GET") - if method in _SAFE_HTTP_METHODS: - await self._app(scope, receive, send) - return - - origin = _get_scope_header(scope, b"origin") - if origin is None: - await self._app(scope, receive, send) - return - - if _is_request_origin_allowed( - origin, - scope, - self._allowed_origins, - self._allowed_origin_regex, - self._has_configured_allowed_origins, - ): - await self._app(scope, receive, send) - return - - response_body = b"Forbidden: origin not allowed" - await send({ - "type": "http.response.start", - "status": 403, - "headers": [ - (b"content-type", b"text/plain"), - (b"content-length", str(len(response_body)).encode()), - ], - }) - await send({ - "type": "http.response.body", - "body": response_body, - }) - - -class ApiServerSpanExporter(export_lib.SpanExporter): - - def __init__(self, trace_dict): - self.trace_dict = trace_dict - - def export( - self, spans: typing.Sequence[ReadableSpan] - ) -> export_lib.SpanExportResult: - for span in spans: - if ( - span.name == "call_llm" - or span.name == "send_data" - or span.name.startswith("execute_tool") - ): - attributes = dict(span.attributes) - attributes["trace_id"] = span.get_span_context().trace_id - attributes["span_id"] = span.get_span_context().span_id - if attributes.get("gcp.vertex.agent.event_id", None): - self.trace_dict[attributes["gcp.vertex.agent.event_id"]] = attributes - return export_lib.SpanExportResult.SUCCESS - - def force_flush(self, timeout_millis: int = 30000) -> bool: - return True - - -class InMemoryExporter(export_lib.SpanExporter): - - def __init__(self, trace_dict): - super().__init__() - self._spans = [] - self.trace_dict = trace_dict - - @override - def export( - self, spans: typing.Sequence[ReadableSpan] - ) -> export_lib.SpanExportResult: - for span in spans: - trace_id = span.context.trace_id - if span.name == "call_llm": - attributes = dict(span.attributes) - session_id = attributes.get("gcp.vertex.agent.session_id", None) - if session_id: - if session_id not in self.trace_dict: - self.trace_dict[session_id] = [trace_id] - else: - self.trace_dict[session_id] += [trace_id] - self._spans.extend(spans) - return export_lib.SpanExportResult.SUCCESS - - @override - def force_flush(self, timeout_millis: int = 30000) -> bool: - return True - - def get_finished_spans(self, session_id: str): - trace_ids = self.trace_dict.get(session_id, None) - if trace_ids is None or not trace_ids: - return [] - return [x for x in self._spans if x.context.trace_id in trace_ids] - - def clear(self): - self._spans.clear() - - -class RunAgentRequest(common.BaseModel): - app_name: str - user_id: str - session_id: str - new_message: Optional[types.Content] = None - streaming: bool = False - state_delta: Optional[dict[str, Any]] = None - # for long-running function resume requests (e.g., OAuth callback) - function_call_event_id: Optional[str] = None - # for resume long-running functions - invocation_id: Optional[str] = None - custom_metadata: Optional[dict[str, Any]] = None - - -class CreateSessionRequest(common.BaseModel): - session_id: Optional[str] = Field( - default=None, - description=( - "The ID of the session to create. If not provided, a random session" - " ID will be generated." - ), - ) - state: Optional[dict[str, Any]] = Field( - default=None, description="The initial state of the session." - ) - events: Optional[list[Event]] = Field( - default=None, - description="A list of events to initialize the session with.", - ) - - -class SaveArtifactRequest(common.BaseModel): - """Request payload for saving a new artifact.""" - - filename: str = Field(description="Artifact filename.") - artifact: types.Part = Field( - description="Artifact payload encoded as google.genai.types.Part." - ) - custom_metadata: Optional[dict[str, Any]] = Field( - default=None, - description="Optional metadata to associate with the artifact version.", - ) - - -class AddSessionToEvalSetRequest(common.BaseModel): - eval_id: str - session_id: str - user_id: str - - -class RunEvalRequest(common.BaseModel): - eval_ids: list[str] = Field( - deprecated=True, - default_factory=list, - description="This field is deprecated, use eval_case_ids instead.", - ) - eval_case_ids: list[str] = Field( - default_factory=list, - description=( - "List of eval case ids to evaluate. if empty, then all eval cases in" - " the eval set are run." - ), - ) - eval_metrics: list[EvalMetric] - - -class UpdateMemoryRequest(common.BaseModel): - """Request to add a session to the memory service.""" - - session_id: str - """The ID of the session to add to memory.""" - - -class UpdateSessionRequest(common.BaseModel): - """Request to update session state without running the agent.""" - - state_delta: dict[str, Any] - """The state changes to apply to the session.""" - - -class RunEvalResult(common.BaseModel): - eval_set_file: str - eval_set_id: str - eval_id: str - final_eval_status: EvalStatus - eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]] = Field( - deprecated=True, - default=[], - description=( - "This field is deprecated, use overall_eval_metric_results instead." - ), - ) - overall_eval_metric_results: list[EvalMetricResult] - eval_metric_result_per_invocation: list[EvalMetricResultPerInvocation] - user_id: str - session_id: str - - -class RunEvalResponse(common.BaseModel): - run_eval_results: list[RunEvalResult] - - -class GetEventGraphResult(common.BaseModel): - dot_src: str - - -class CreateEvalSetRequest(common.BaseModel): - eval_set: EvalSet - - -class ListEvalSetsResponse(common.BaseModel): - eval_set_ids: list[str] - - -class EvalResult(EvalSetResult): - """This class has no field intentionally. - - The goal here is to just give a new name to the class to align with the API - endpoint. - """ - - -class ListEvalResultsResponse(common.BaseModel): - eval_result_ids: list[str] - - -class ListMetricsInfoResponse(common.BaseModel): - metrics_info: list[MetricInfo] - - -class AppInfo(common.BaseModel): - name: str - root_agent_name: str - description: str - language: Literal["yaml", "python"] - is_computer_use: bool = False - agents: Optional[dict[str, AgentInfo]] = None - - -class ListAppsResponse(common.BaseModel): - apps: list[AppInfo] - - -def _setup_telemetry( - otel_to_cloud: bool = False, - internal_exporters: Optional[list[SpanProcessor]] = None, -): - # TODO - remove the else branch here once maybe_set_otel_providers is no - # longer experimental. - if otel_to_cloud: - _setup_gcp_telemetry(internal_exporters=internal_exporters) - elif _otel_env_vars_enabled(): - _setup_telemetry_from_env(internal_exporters=internal_exporters) - else: - # Old logic - to be removed when above leaves experimental. - tracer_provider = TracerProvider() - if internal_exporters is not None: - for exporter in internal_exporters: - tracer_provider.add_span_processor(exporter) - trace.set_tracer_provider(tracer_provider=tracer_provider) - - -def _otel_env_vars_enabled() -> bool: - return any([ - os.getenv(endpoint_var) - for endpoint_var in [ - otel_env.OTEL_EXPORTER_OTLP_ENDPOINT, - otel_env.OTEL_EXPORTER_OTLP_TRACES_ENDPOINT, - otel_env.OTEL_EXPORTER_OTLP_METRICS_ENDPOINT, - otel_env.OTEL_EXPORTER_OTLP_LOGS_ENDPOINT, - ] - ]) - - -def _setup_gcp_telemetry( - internal_exporters: list[SpanProcessor] = None, -): - if typing.TYPE_CHECKING: - from ..telemetry.setup import OTelHooks - - otel_hooks_to_add: list[OTelHooks] = [] - - if internal_exporters: - from ..telemetry.setup import OTelHooks - - # Register ADK-specific exporters in trace provider. - otel_hooks_to_add.append(OTelHooks(span_processors=internal_exporters)) - - import google.auth - - from ..telemetry.google_cloud import get_gcp_exporters - from ..telemetry.google_cloud import get_gcp_resource - from ..telemetry.setup import maybe_set_otel_providers - - credentials, project_id = google.auth.default() - - otel_hooks_to_add.append( - get_gcp_exporters( - # TODO - use trace_to_cloud here as well once otel_to_cloud is no - # longer experimental. - enable_cloud_tracing=True, - # TODO - re-enable metrics once errors during shutdown are fixed. - enable_cloud_metrics=False, - enable_cloud_logging=True, - google_auth=(credentials, project_id), - ) - ) - otel_resource = get_gcp_resource(project_id) - - maybe_set_otel_providers( - otel_hooks_to_setup=otel_hooks_to_add, - otel_resource=otel_resource, - ) - _setup_instrumentation_lib_if_installed() - - -def _setup_telemetry_from_env( - internal_exporters: list[SpanProcessor] = None, -): - from ..telemetry.setup import maybe_set_otel_providers - - otel_hooks_to_add = [] - - if internal_exporters: - from ..telemetry.setup import OTelHooks - - # Register ADK-specific exporters in trace provider. - otel_hooks_to_add.append(OTelHooks(span_processors=internal_exporters)) - - maybe_set_otel_providers(otel_hooks_to_setup=otel_hooks_to_add) - _setup_instrumentation_lib_if_installed() - - -def _setup_instrumentation_lib_if_installed(): - # Set instrumentation to enable emitting OTel data from GenAISDK - # Currently the instrumentation lib is in extras dependencies, make sure to - # warn the user if it's not installed. - try: - from opentelemetry.instrumentation.google_genai import GoogleGenAiSdkInstrumentor - - GoogleGenAiSdkInstrumentor().instrument() - except (ImportError, AttributeError): - logger.warning( - "Unable to import GoogleGenAiSdkInstrumentor - some" - " telemetry will be disabled. Make sure to install google-adk[otel-gcp]" - ) - if os.getenv("GOOGLE_CLOUD_AGENT_ENGINE_ID"): - # Set up HTTPX and gRPC instrumentation for A2A multi-agent observability. - try: - from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor - - HTTPXClientInstrumentor().instrument() - except (ImportError, AttributeError): - logger.warning( - "telemetry enabled but proceeding without HTTPX instrumentation," - " because google-adk[otel-gcp] has not been installed" - ) - try: - from opentelemetry.instrumentation.grpc import GrpcInstrumentorClient - - GrpcInstrumentorClient().instrument() - except (ImportError, AttributeError): - logger.warning( - "telemetry enabled but proceeding without gRPC instrumentation," - " because google-adk[otel-gcp] has not been installed" - ) - - -class AdkWebServer: - """Helper class for setting up and running the ADK web server on FastAPI. - - You construct this class with all the Services required to run ADK agents and - can then call the get_fast_api_app method to get a FastAPI app instance that - can will use your provided service instances, static assets, and agent loader. - If you pass in a web_assets_dir, the static assets will be served under - /dev-ui in addition to the API endpoints created by default. - - You can add additional API endpoints by modifying the FastAPI app - instance returned by get_fast_api_app as this class exposes the agent runners - and most other bits of state retained during the lifetime of the server. - - Attributes: - agent_loader: An instance of BaseAgentLoader for loading agents. - session_service: An instance of BaseSessionService for managing sessions. - memory_service: An instance of BaseMemoryService for managing memory. - artifact_service: An instance of BaseArtifactService for managing - artifacts. - credential_service: An instance of BaseCredentialService for managing - credentials. - eval_sets_manager: An instance of EvalSetsManager for managing evaluation - sets. - eval_set_results_manager: An instance of EvalSetResultsManager for - managing evaluation set results. - agents_dir: Root directory containing subdirs for agents with those - containing resources (e.g. .env files, eval sets, etc.) for the agents. - extra_plugins: A list of fully qualified names of extra plugins to load. - logo_text: Text to display in the logo of the UI. - logo_image_url: URL of an image to display as logo of the UI. - runners_to_clean: Set of runner names marked for cleanup. - current_app_name_ref: A shared reference to the latest ran app name. - runner_dict: A dict of instantiated runners for each app. - """ - - def __init__( - self, - *, - agent_loader: BaseAgentLoader, - session_service: BaseSessionService, - memory_service: BaseMemoryService, - artifact_service: BaseArtifactService, - credential_service: BaseCredentialService, - eval_sets_manager: EvalSetsManager, - eval_set_results_manager: EvalSetResultsManager, - agents_dir: str, - extra_plugins: Optional[list[str]] = None, - logo_text: Optional[str] = None, - logo_image_url: Optional[str] = None, - url_prefix: Optional[str] = None, - auto_create_session: bool = False, - trigger_sources: Optional[list[str]] = None, - ): - self.agent_loader = agent_loader - self.session_service = session_service - self.memory_service = memory_service - self.artifact_service = artifact_service - self.credential_service = credential_service - self.eval_sets_manager = eval_sets_manager - self.eval_set_results_manager = eval_set_results_manager - self.agents_dir = agents_dir - self.extra_plugins = extra_plugins or [] - self.logo_text = logo_text - self.logo_image_url = logo_image_url - # Internal properties we want to allow being modified from callbacks. - self.runners_to_clean: set[str] = set() - self.current_app_name_ref: SharedValue[str] = SharedValue(value="") - self.runner_dict = {} - self.url_prefix = url_prefix - self.auto_create_session = auto_create_session - self.trigger_sources = trigger_sources - - async def get_runner_async(self, app_name: str) -> Runner: - """Returns the cached runner for the given app.""" - # Handle cleanup - if app_name in self.runners_to_clean: - self.runners_to_clean.remove(app_name) - runner = self.runner_dict.pop(app_name, None) - await cleanup.close_runners(list([runner])) - - # Return cached runner if exists - if app_name in self.runner_dict: - return self.runner_dict[app_name] - - # Create new runner - envs.load_dotenv_for_agent(os.path.basename(app_name), self.agents_dir) - agent_or_app = self.agent_loader.load_agent(app_name) - - # Instantiate extra plugins if configured - extra_plugins_instances = self._instantiate_extra_plugins() - - plugins_yaml_path = os.path.join(self.agents_dir, app_name, "plugins.yaml") - bq_analytics_config = None - if os.path.exists(plugins_yaml_path): - with open(plugins_yaml_path, "r", encoding="utf-8") as f: - plugins_config = yaml.safe_load(f) - if plugins_config and isinstance(plugins_config, dict): - bq_analytics_config = plugins_config.get("bigquery_agent_analytics") - - # All YAML agents are treated as visual builder agents. - is_visual_builder_agent = os.path.exists( - os.path.join(self.agents_dir, app_name, "root_agent.yaml") - ) - - if isinstance(agent_or_app, BaseAgent): - plugins = extra_plugins_instances - - # Handle BigQuery Analytics Plugin injection - if bq_analytics_config and all([ - bq_analytics_config.get("project_id"), - bq_analytics_config.get("dataset_id"), - bq_analytics_config.get("dataset_location"), - ]): - from ..plugins.bigquery_agent_analytics_plugin import BigQueryAgentAnalyticsPlugin - - plugins.append( - BigQueryAgentAnalyticsPlugin( - project_id=bq_analytics_config.get("project_id"), - dataset_id=bq_analytics_config.get("dataset_id"), - table_id=bq_analytics_config.get("table_id"), - location=bq_analytics_config.get("dataset_location"), - ) - ) - - agentic_app = App( - name=app_name, - root_agent=agent_or_app, - plugins=plugins, - ) - else: - # Combine existing plugins with extra plugins - agent_or_app.plugins = agent_or_app.plugins + extra_plugins_instances - agentic_app = agent_or_app - - # If the root agent was loaded from YAML, we treat it as being from Visual Builder - if is_visual_builder_agent: - object.__setattr__(agentic_app, "_is_visual_builder_app", True) - - runner = self._create_runner(agentic_app) - self.runner_dict[app_name] = runner - return runner - - def _get_root_agent(self, agent_or_app: BaseAgent | App) -> BaseAgent: - """Extract root agent from either a BaseAgent or App object.""" - if isinstance(agent_or_app, App): - return agent_or_app.root_agent - return agent_or_app - - def _create_runner(self, agentic_app: App) -> Runner: - """Create a runner with common services.""" - return Runner( - app=agentic_app, - artifact_service=self.artifact_service, - session_service=self.session_service, - memory_service=self.memory_service, - credential_service=self.credential_service, - auto_create_session=self.auto_create_session, - ) - - def _instantiate_extra_plugins(self) -> list[BasePlugin]: - """Instantiate extra plugins from the configured list. - - Returns: - List of instantiated BasePlugin objects. - """ - extra_plugins_instances = [] - for qualified_name in self.extra_plugins: - try: - plugin_obj = self._import_plugin_object(qualified_name) - if isinstance(plugin_obj, BasePlugin): - extra_plugins_instances.append(plugin_obj) - elif issubclass(plugin_obj, BasePlugin): - extra_plugins_instances.append(plugin_obj(name=qualified_name)) - except Exception as e: - logger.error("Failed to load plugin %s: %s", qualified_name, e) - return extra_plugins_instances - - def _import_plugin_object(self, qualified_name: str) -> Any: - """Import a plugin object (class or instance) from a fully qualified name. - - Args: - qualified_name: Fully qualified name (e.g., - 'my_package.my_plugin.MyPlugin') - - Returns: - The imported object, which can be either a class or an instance. - - Raises: - ImportError: If the module cannot be imported. - AttributeError: If the object doesn't exist in the module. - """ - module_name, obj_name = qualified_name.rsplit(".", 1) - module = importlib.import_module(module_name) - return getattr(module, obj_name) - - def _setup_runtime_config(self, web_assets_dir: str): - """Sets up the runtime config for the web server.""" - # Read existing runtime config file. - runtime_config_path = os.path.join( - web_assets_dir, "assets", "config", "runtime-config.json" - ) - runtime_config = {} - try: - with open(runtime_config_path, "r") as f: - runtime_config = json.load(f) - except FileNotFoundError: - logger.info( - "File not found: %s. A new runtime config file will be created.", - runtime_config_path, - ) - except json.JSONDecodeError: - logger.warning( - "Failed to decode JSON from %s. The file content will be" - " overwritten.", - runtime_config_path, - ) - runtime_config["backendUrl"] = self.url_prefix if self.url_prefix else "" - - # Set custom logo config. - if self.logo_text or self.logo_image_url: - if not self.logo_text or not self.logo_image_url: - raise ValueError( - "Both --logo-text and --logo-image-url must be defined when using" - " logo config." - ) - runtime_config["logo"] = { - "text": self.logo_text, - "imageUrl": self.logo_image_url, - } - elif "logo" in runtime_config: - del runtime_config["logo"] - - # Write the runtime config file. - try: - os.makedirs(os.path.dirname(runtime_config_path), exist_ok=True) - with open(runtime_config_path, "w") as f: - json.dump(runtime_config, f, indent=2) - f.write("\n") - except IOError as e: - logger.error( - "Failed to write runtime config file %s: %s", runtime_config_path, e - ) - - async def _create_session( - self, - *, - app_name: str, - user_id: str, - session_id: Optional[str] = None, - state: Optional[dict[str, Any]] = None, - ) -> Session: - try: - session = await self.session_service.create_session( - app_name=app_name, - user_id=user_id, - state=state, - session_id=session_id, - ) - logger.info("New session created: %s", session.id) - return session - except AlreadyExistsError as e: - raise HTTPException( - status_code=409, detail=f"Session already exists: {session_id}" - ) from e - except Exception as e: - logger.error( - "Internal server error during session creation: %s", e, exc_info=True - ) - raise HTTPException(status_code=500, detail=str(e)) from e - - def get_fast_api_app( - self, - lifespan: Optional[Lifespan[FastAPI]] = None, - allow_origins: Optional[list[str]] = None, - web_assets_dir: Optional[str] = None, - setup_observer: Callable[ - [Observer, "AdkWebServer"], None - ] = lambda o, s: None, - tear_down_observer: Callable[ - [Observer, "AdkWebServer"], None - ] = lambda o, s: None, - register_processors: Callable[[TracerProvider], None] = lambda o: None, - otel_to_cloud: bool = False, - with_ui: bool = False, - ): - """Creates a FastAPI app for the ADK web server. - - By default it'll just return a FastAPI instance with the API server - endpoints, - but if you specify a web_assets_dir, it'll also serve the static web assets - from that directory. - - Args: - lifespan: The lifespan of the FastAPI app. - allow_origins: The origins that are allowed to make cross-origin requests. - Entries can be literal origins (e.g., 'https://example.com') or regex - patterns prefixed with 'regex:' (e.g., - 'regex:https://.*\\.example\\.com'). - web_assets_dir: The directory containing the web assets to serve. - setup_observer: Callback for setting up the file system observer. - tear_down_observer: Callback for cleaning up the file system observer. - register_processors: Callback for additional Span processors to be added - to the TracerProvider. - otel_to_cloud: Whether to enable Cloud Trace and Cloud Logging - integrations. - - Returns: - A FastAPI app instance. - """ - # Properties we don't need to modify from callbacks - trace_dict = {} - session_trace_dict = {} - # Set up a file system watcher to detect changes in the agents directory. - observer = Observer() - setup_observer(observer, self) - - @asynccontextmanager - async def internal_lifespan(app: FastAPI): - try: - if lifespan: - async with lifespan(app) as lifespan_context: - yield lifespan_context - else: - yield - finally: - tear_down_observer(observer, self) - # Create tasks for all runner closures to run concurrently - await cleanup.close_runners(list(self.runner_dict.values())) - - memory_exporter = InMemoryExporter(session_trace_dict) - - _setup_telemetry( - otel_to_cloud=otel_to_cloud, - internal_exporters=[ - export_lib.SimpleSpanProcessor(ApiServerSpanExporter(trace_dict)), - export_lib.SimpleSpanProcessor(memory_exporter), - ], - ) - if web_assets_dir: - self._setup_runtime_config(web_assets_dir) - - # TODO - register_processors to be removed once --otel_to_cloud is no - # longer experimental. - tracer_provider = trace.get_tracer_provider() - register_processors(tracer_provider) - - # Run the FastAPI server. - app = FastAPI(lifespan=internal_lifespan) - - has_configured_allowed_origins = bool(allow_origins) - if allow_origins: - literal_origins, combined_regex = _parse_cors_origins(allow_origins) - compiled_origin_regex = ( - re.compile(combined_regex) if combined_regex is not None else None - ) - app.add_middleware( - CORSMiddleware, - allow_origins=literal_origins, - allow_origin_regex=combined_regex, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) - else: - literal_origins = [] - compiled_origin_regex = None - - app.add_middleware( - _OriginCheckMiddleware, - has_configured_allowed_origins=has_configured_allowed_origins, - allowed_origins=literal_origins, - allowed_origin_regex=compiled_origin_regex, - ) - - @app.get("/health") - async def health() -> dict[str, str]: - return {"status": "ok"} - - @app.get("/version") - async def version() -> dict[str, str]: - return { - "version": __version__, - "language": "python", - "language_version": ( - f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" - ), - } - - @app.get("/list-apps") - async def list_apps( - detailed: bool = Query( - default=False, description="Return detailed app information" - ) - ) -> list[str] | ListAppsResponse: - if detailed: - apps_info = self.agent_loader.list_agents_detailed() - return ListAppsResponse(apps=[AppInfo(**app) for app in apps_info]) - return self.agent_loader.list_agents() - - @experimental - @app.get("/apps/{app_name}/app-info", response_model_exclude_none=True) - async def get_adk_app_info(app_name: str) -> AppInfo: - """Returns the detailed info for a given ADK app.""" - agent_or_app = self.agent_loader.load_agent(app_name) - root_agent = self._get_root_agent(agent_or_app) - if isinstance(root_agent, LlmAgent): - return AppInfo( - name=app_name, - root_agent_name=root_agent.name, - description=root_agent.description, - language="python", - agents=get_agents_dict(root_agent), - ) - else: - raise HTTPException( - status_code=400, detail="Root agent is not an LlmAgent" - ) - - @app.get("/debug/trace/{event_id}", tags=[TAG_DEBUG]) - async def get_trace_dict(event_id: str) -> Any: - event_dict = trace_dict.get(event_id, None) - if event_dict is None: - raise HTTPException(status_code=404, detail="Trace not found") - return event_dict - - if web_assets_dir: - - @app.get("/dev/build_graph/{app_name}") - async def get_app_info(app_name: str) -> Any: - runner = await self.get_runner_async(app_name) - - if not runner.app: - raise HTTPException( - status_code=404, detail=f"App not found: {app_name}" - ) - - def serialize_agent(agent: BaseAgent) -> dict[str, Any]: - """Recursively serialize an agent, excluding non-serializable fields.""" - agent_dict = {} - - for field_name, field_info in agent.__class__.model_fields.items(): - # Skip non-serializable fields - if field_name in [ - "parent_agent", - "before_agent_callback", - "after_agent_callback", - "before_model_callback", - "after_model_callback", - "on_model_error_callback", - "before_tool_callback", - "after_tool_callback", - "on_tool_error_callback", - ]: - continue - - value = getattr(agent, field_name, None) - - # Handle sub_agents recursively - if field_name == "sub_agents" and value: - agent_dict[field_name] = [ - serialize_agent(sub_agent) for sub_agent in value - ] - elif value is None or field_name == "tools": - continue - else: - try: - if isinstance(value, (str, int, float, bool, list, dict)): - agent_dict[field_name] = value - elif hasattr(value, "model_dump"): - agent_dict[field_name] = value.model_dump( - mode="python", exclude_none=True - ) - else: - agent_dict[field_name] = str(value) - except Exception: - pass - - return agent_dict - - app_info = { - "name": runner.app.name, - "root_agent": serialize_agent(runner.app.root_agent), - } - - # Add optional fields if present - if runner.app.plugins: - app_info["plugins"] = [ - {"name": getattr(plugin, "name", type(plugin).__name__)} - for plugin in runner.app.plugins - ] - - if runner.app.context_cache_config: - try: - app_info["context_cache_config"] = ( - runner.app.context_cache_config.model_dump( - mode="python", exclude_none=True - ) - ) - except Exception: - pass - - if runner.app.resumability_config: - try: - app_info["resumability_config"] = ( - runner.app.resumability_config.model_dump( - mode="python", exclude_none=True - ) - ) - except Exception: - pass - - return app_info - - @app.get("/debug/trace/session/{session_id}", tags=[TAG_DEBUG]) - async def get_session_trace(session_id: str) -> Any: - spans = memory_exporter.get_finished_spans(session_id) - if not spans: - return [] - return [ - { - "name": s.name, - "span_id": s.context.span_id, - "trace_id": s.context.trace_id, - "start_time": s.start_time, - "end_time": s.end_time, - "attributes": dict(s.attributes), - "parent_span_id": s.parent.span_id if s.parent else None, - } - for s in spans - ] - - @app.get( - "/apps/{app_name}/users/{user_id}/sessions/{session_id}", - response_model_exclude_none=True, - ) - async def get_session( - app_name: str, user_id: str, session_id: str - ) -> Session: - session = await self.session_service.get_session( - app_name=app_name, user_id=user_id, session_id=session_id - ) - if not session: - raise HTTPException(status_code=404, detail="Session not found") - self.current_app_name_ref.value = app_name - return session - - @app.get( - "/apps/{app_name}/users/{user_id}/sessions", - response_model_exclude_none=True, - ) - async def list_sessions(app_name: str, user_id: str) -> list[Session]: - list_sessions_response = await self.session_service.list_sessions( - app_name=app_name, user_id=user_id - ) - return [ - session - for session in list_sessions_response.sessions - # Remove sessions that were generated as a part of Eval. - if not session.id.startswith(EVAL_SESSION_ID_PREFIX) - ] - - @deprecated( - "Please use create_session instead. This will be removed in future" - " releases." - ) - @app.post( - "/apps/{app_name}/users/{user_id}/sessions/{session_id}", - response_model_exclude_none=True, - ) - async def create_session_with_id( - app_name: str, - user_id: str, - session_id: str, - state: Optional[dict[str, Any]] = None, - ) -> Session: - return await self._create_session( - app_name=app_name, - user_id=user_id, - state=state, - session_id=session_id, - ) - - @app.post( - "/apps/{app_name}/users/{user_id}/sessions", - response_model_exclude_none=True, - ) - async def create_session( - app_name: str, - user_id: str, - req: Optional[CreateSessionRequest] = None, - ) -> Session: - if not req: - return await self._create_session(app_name=app_name, user_id=user_id) - - session = await self._create_session( - app_name=app_name, - user_id=user_id, - state=req.state, - session_id=req.session_id, - ) - - if req.events: - for event in req.events: - await self.session_service.append_event(session=session, event=event) - - return session - - @app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}") - async def delete_session( - app_name: str, user_id: str, session_id: str - ) -> None: - await self.session_service.delete_session( - app_name=app_name, user_id=user_id, session_id=session_id - ) - - @app.patch( - "/apps/{app_name}/users/{user_id}/sessions/{session_id}", - response_model_exclude_none=True, - ) - async def update_session( - app_name: str, - user_id: str, - session_id: str, - req: UpdateSessionRequest, - ) -> Session: - """Updates session state without running the agent. - - Args: - app_name: The name of the application. - user_id: The ID of the user. - session_id: The ID of the session to update. - req: The patch request containing state changes. - - Returns: - The updated session. - - Raises: - HTTPException: If the session is not found. - """ - session = await self.session_service.get_session( - app_name=app_name, user_id=user_id, session_id=session_id - ) - if not session: - raise HTTPException(status_code=404, detail="Session not found") - - # Create an event to record the state change - import uuid - - from ..events.event import Event - from ..events.event import EventActions - - state_update_event = Event( - invocation_id="p-" + str(uuid.uuid4()), - author="user", - actions=EventActions(state_delta=req.state_delta), - ) - - # Append the event to the session - # This will automatically update the session state through __update_session_state - await self.session_service.append_event( - session=session, event=state_update_event - ) - - return session - - @app.post( - "/apps/{app_name}/eval-sets", - response_model_exclude_none=True, - tags=[TAG_EVALUATION], - ) - async def create_eval_set( - app_name: str, create_eval_set_request: CreateEvalSetRequest - ) -> EvalSet: - try: - return self.eval_sets_manager.create_eval_set( - app_name=app_name, - eval_set_id=create_eval_set_request.eval_set.eval_set_id, - ) - except ValueError as ve: - raise HTTPException( - status_code=400, - detail=str(ve), - ) from ve - - # TODO - remove after migration - @deprecated( - "Please use create_eval_set instead. This will be removed in future" - " releases." - ) - @app.post( - "/apps/{app_name}/eval_sets/{eval_set_id}", - response_model_exclude_none=True, - tags=[TAG_EVALUATION], - ) - async def create_eval_set_legacy( - app_name: str, - eval_set_id: str, - ): - """Creates an eval set, given the id.""" - await create_eval_set( - app_name=app_name, - create_eval_set_request=CreateEvalSetRequest( - eval_set=EvalSet(eval_set_id=eval_set_id, eval_cases=[]) - ), - ) - - # TODO - remove after migration - @deprecated( - "Please use list_eval_sets instead. This will be removed in future" - " releases." - ) - @app.get( - "/apps/{app_name}/eval_sets", - response_model_exclude_none=True, - tags=[TAG_EVALUATION], - ) - async def list_eval_sets_legacy(app_name: str) -> list[str]: - list_eval_sets_response = await list_eval_sets(app_name) - return list_eval_sets_response.eval_set_ids - - # TODO - remove after migration - @deprecated( - "Please use run_eval instead. This will be removed in future releases." - ) - @app.post( - "/apps/{app_name}/eval_sets/{eval_set_id}/run_eval", - response_model_exclude_none=True, - tags=[TAG_EVALUATION], - ) - async def run_eval_legacy( - app_name: str, eval_set_id: str, req: RunEvalRequest - ) -> list[RunEvalResult]: - run_eval_response = await run_eval( - app_name=app_name, eval_set_id=eval_set_id, req=req - ) - return run_eval_response.run_eval_results - - # TODO - remove after migration - @deprecated( - "Please use get_eval_result instead. This will be removed in future" - " releases." - ) - @app.get( - "/apps/{app_name}/eval_results/{eval_result_id}", - response_model_exclude_none=True, - tags=[TAG_EVALUATION], - ) - async def get_eval_result_legacy( - app_name: str, - eval_result_id: str, - ) -> EvalSetResult: - try: - return self.eval_set_results_manager.get_eval_set_result( - app_name, eval_result_id - ) - except ValueError as ve: - raise HTTPException(status_code=404, detail=str(ve)) from ve - except ValidationError as ve: - raise HTTPException(status_code=500, detail=str(ve)) from ve - - # TODO - remove after migration - @deprecated( - "Please use list_eval_results instead. This will be removed in future" - " releases." - ) - @app.get( - "/apps/{app_name}/eval_results", - response_model_exclude_none=True, - tags=[TAG_EVALUATION], - ) - async def list_eval_results_legacy(app_name: str) -> list[str]: - list_eval_results_response = await list_eval_results(app_name) - return list_eval_results_response.eval_result_ids - - @app.get( - "/apps/{app_name}/eval-sets", - response_model_exclude_none=True, - tags=[TAG_EVALUATION], - ) - async def list_eval_sets(app_name: str) -> ListEvalSetsResponse: - """Lists all eval sets for the given app.""" - eval_sets = [] - try: - eval_sets = self.eval_sets_manager.list_eval_sets(app_name) - except NotFoundError as e: - logger.warning(e) - - return ListEvalSetsResponse(eval_set_ids=eval_sets) - - @app.post( - "/apps/{app_name}/eval-sets/{eval_set_id}/add-session", - response_model_exclude_none=True, - tags=[TAG_EVALUATION], - ) - @app.post( - "/apps/{app_name}/eval_sets/{eval_set_id}/add_session", - response_model_exclude_none=True, - tags=[TAG_EVALUATION], - ) - async def add_session_to_eval_set( - app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest - ): - # Get the session - session = await self.session_service.get_session( - app_name=app_name, user_id=req.user_id, session_id=req.session_id - ) - assert session, "Session not found." - - # Convert the session data to eval invocations - invocations = evals.convert_session_to_eval_invocations(session) - - # Populate the session with initial session state. - agent_or_app = self.agent_loader.load_agent(app_name) - root_agent = self._get_root_agent(agent_or_app) - initial_session_state = create_empty_state(root_agent) - - new_eval_case = EvalCase( - eval_id=req.eval_id, - conversation=invocations, - session_input=SessionInput( - app_name=app_name, - user_id=req.user_id, - state=initial_session_state, - ), - creation_timestamp=time.time(), - ) - - try: - self.eval_sets_manager.add_eval_case( - app_name, eval_set_id, new_eval_case - ) - except ValueError as ve: - raise HTTPException(status_code=400, detail=str(ve)) from ve - - @app.get( - "/apps/{app_name}/eval_sets/{eval_set_id}/evals", - response_model_exclude_none=True, - tags=[TAG_EVALUATION], - ) - async def list_evals_in_eval_set( - app_name: str, - eval_set_id: str, - ) -> list[str]: - """Lists all evals in an eval set.""" - eval_set_data = self.eval_sets_manager.get_eval_set(app_name, eval_set_id) - - if not eval_set_data: - raise HTTPException( - status_code=400, detail=f"Eval set `{eval_set_id}` not found." - ) - - return sorted([x.eval_id for x in eval_set_data.eval_cases]) - - @app.get( - "/apps/{app_name}/eval-sets/{eval_set_id}/eval-cases/{eval_case_id}", - response_model_exclude_none=True, - tags=[TAG_EVALUATION], - ) - @app.get( - "/apps/{app_name}/eval_sets/{eval_set_id}/evals/{eval_case_id}", - response_model_exclude_none=True, - tags=[TAG_EVALUATION], - ) - async def get_eval( - app_name: str, eval_set_id: str, eval_case_id: str - ) -> EvalCase: - """Gets an eval case in an eval set.""" - eval_case_to_find = self.eval_sets_manager.get_eval_case( - app_name, eval_set_id, eval_case_id - ) - - if eval_case_to_find: - return eval_case_to_find - - raise HTTPException( - status_code=404, - detail=( - f"Eval set `{eval_set_id}` or Eval `{eval_case_id}` not found." - ), - ) - - @app.put( - "/apps/{app_name}/eval-sets/{eval_set_id}/eval-cases/{eval_case_id}", - response_model_exclude_none=True, - tags=[TAG_EVALUATION], - ) - @app.put( - "/apps/{app_name}/eval_sets/{eval_set_id}/evals/{eval_case_id}", - response_model_exclude_none=True, - tags=[TAG_EVALUATION], - ) - async def update_eval( - app_name: str, - eval_set_id: str, - eval_case_id: str, - updated_eval_case: EvalCase, - ): - if ( - updated_eval_case.eval_id - and updated_eval_case.eval_id != eval_case_id - ): - raise HTTPException( - status_code=400, - detail=( - "Eval id in EvalCase should match the eval id in the API route." - ), - ) - - # Overwrite the value. We are either overwriting the same value or an empty - # field. - updated_eval_case.eval_id = eval_case_id - try: - self.eval_sets_manager.update_eval_case( - app_name, eval_set_id, updated_eval_case - ) - except NotFoundError as nfe: - raise HTTPException(status_code=404, detail=str(nfe)) from nfe - - @app.delete( - "/apps/{app_name}/eval-sets/{eval_set_id}/eval-cases/{eval_case_id}", - tags=[TAG_EVALUATION], - ) - @app.delete( - "/apps/{app_name}/eval_sets/{eval_set_id}/evals/{eval_case_id}", - tags=[TAG_EVALUATION], - ) - async def delete_eval( - app_name: str, eval_set_id: str, eval_case_id: str - ) -> None: - try: - self.eval_sets_manager.delete_eval_case( - app_name, eval_set_id, eval_case_id - ) - except NotFoundError as nfe: - raise HTTPException(status_code=404, detail=str(nfe)) from nfe - - @app.post( - "/apps/{app_name}/eval-sets/{eval_set_id}/run", - response_model_exclude_none=True, - tags=[TAG_EVALUATION], - ) - async def run_eval( - app_name: str, eval_set_id: str, req: RunEvalRequest - ) -> RunEvalResponse: - """Runs an eval given the details in the eval request.""" - # Create a mapping from eval set file to all the evals that needed to be - # run. - try: - from ..evaluation.local_eval_service import LocalEvalService - from .cli_eval import _collect_eval_results - from .cli_eval import _collect_inferences - - eval_set = self.eval_sets_manager.get_eval_set(app_name, eval_set_id) - - if not eval_set: - raise HTTPException( - status_code=400, detail=f"Eval set `{eval_set_id}` not found." - ) - - agent_or_app = self.agent_loader.load_agent(app_name) - root_agent = self._get_root_agent(agent_or_app) - - eval_case_results = [] - - eval_service = LocalEvalService( - root_agent=root_agent, - eval_sets_manager=self.eval_sets_manager, - eval_set_results_manager=self.eval_set_results_manager, - session_service=self.session_service, - artifact_service=self.artifact_service, - ) - inference_request = InferenceRequest( - app_name=app_name, - eval_set_id=eval_set.eval_set_id, - eval_case_ids=req.eval_case_ids or req.eval_ids, - inference_config=InferenceConfig(), - ) - inference_results = await _collect_inferences( - inference_requests=[inference_request], eval_service=eval_service - ) - - eval_case_results = await _collect_eval_results( - inference_results=inference_results, - eval_service=eval_service, - eval_metrics=req.eval_metrics, - ) - except ModuleNotFoundError as e: - logger.exception("%s", e) - raise HTTPException( - status_code=400, detail=MISSING_EVAL_DEPENDENCIES_MESSAGE - ) from e - - run_eval_results = [] - for eval_case_result in eval_case_results: - run_eval_results.append( - RunEvalResult( - eval_set_file=eval_case_result.eval_set_file, - eval_set_id=eval_set_id, - eval_id=eval_case_result.eval_id, - final_eval_status=eval_case_result.final_eval_status, - overall_eval_metric_results=eval_case_result.overall_eval_metric_results, - eval_metric_result_per_invocation=eval_case_result.eval_metric_result_per_invocation, - user_id=eval_case_result.user_id, - session_id=eval_case_result.session_id, - ) - ) - - return RunEvalResponse(run_eval_results=run_eval_results) - - @app.get( - "/apps/{app_name}/eval-results/{eval_result_id}", - response_model_exclude_none=True, - tags=[TAG_EVALUATION], - ) - async def get_eval_result( - app_name: str, - eval_result_id: str, - ) -> EvalResult: - """Gets the eval result for the given eval id.""" - try: - eval_set_result = self.eval_set_results_manager.get_eval_set_result( - app_name, eval_result_id - ) - return EvalResult(**eval_set_result.model_dump()) - except ValueError as ve: - raise HTTPException(status_code=404, detail=str(ve)) from ve - except ValidationError as ve: - raise HTTPException(status_code=500, detail=str(ve)) from ve - - @app.get( - "/apps/{app_name}/eval-results", - response_model_exclude_none=True, - tags=[TAG_EVALUATION], - ) - async def list_eval_results(app_name: str) -> ListEvalResultsResponse: - """Lists all eval results for the given app.""" - eval_result_ids = self.eval_set_results_manager.list_eval_set_results( - app_name - ) - return ListEvalResultsResponse(eval_result_ids=eval_result_ids) - - @app.get( - "/apps/{app_name}/metrics-info", - response_model_exclude_none=True, - tags=[TAG_EVALUATION], - ) - async def list_metrics_info(app_name: str) -> ListMetricsInfoResponse: - """Lists all eval metrics for the given app.""" - try: - from ..evaluation.metric_evaluator_registry import DEFAULT_METRIC_EVALUATOR_REGISTRY - - # Right now we ignore the app_name as eval metrics are not tied to the - # app_name, but they could be moving forward. - metrics_info = ( - DEFAULT_METRIC_EVALUATOR_REGISTRY.get_registered_metrics() - ) - return ListMetricsInfoResponse(metrics_info=metrics_info) - except ModuleNotFoundError as e: - logger.exception("%s\n%s", MISSING_EVAL_DEPENDENCIES_MESSAGE, e) - raise HTTPException( - status_code=400, detail=MISSING_EVAL_DEPENDENCIES_MESSAGE - ) from e - - @app.get( - "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}", - response_model_exclude_none=True, - ) - async def load_artifact( - app_name: str, - user_id: str, - session_id: str, - artifact_name: str, - version: Optional[int] = Query(None), - ) -> Optional[types.Part]: - artifact = await self.artifact_service.load_artifact( - app_name=app_name, - user_id=user_id, - session_id=session_id, - filename=artifact_name, - version=version, - ) - if not artifact: - raise HTTPException(status_code=404, detail="Artifact not found") - return artifact - - @app.get( - "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions/metadata", - response_model=list[ArtifactVersion], - response_model_exclude_none=True, - ) - async def list_artifact_versions_metadata( - app_name: str, - user_id: str, - session_id: str, - artifact_name: str, - ) -> list[ArtifactVersion]: - return await self.artifact_service.list_artifact_versions( - app_name=app_name, - user_id=user_id, - session_id=session_id, - filename=artifact_name, - ) - - @app.get( - "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions/{version_id}", - response_model_exclude_none=True, - ) - async def load_artifact_version( - app_name: str, - user_id: str, - session_id: str, - artifact_name: str, - version_id: int, - ) -> Optional[types.Part]: - artifact = await self.artifact_service.load_artifact( - app_name=app_name, - user_id=user_id, - session_id=session_id, - filename=artifact_name, - version=version_id, - ) - if not artifact: - raise HTTPException(status_code=404, detail="Artifact not found") - return artifact - - @app.post( - "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts", - response_model=ArtifactVersion, - response_model_exclude_none=True, - ) - async def save_artifact( - app_name: str, - user_id: str, - session_id: str, - req: SaveArtifactRequest, - ) -> ArtifactVersion: - try: - version = await self.artifact_service.save_artifact( - app_name=app_name, - user_id=user_id, - session_id=session_id, - filename=req.filename, - artifact=req.artifact, - custom_metadata=req.custom_metadata, - ) - except InputValidationError as ive: - raise HTTPException(status_code=400, detail=str(ive)) from ive - except Exception as exc: # pylint: disable=broad-exception-caught - logger.error( - "Internal error while saving artifact %s for app=%s user=%s" - " session=%s: %s", - req.filename, - app_name, - user_id, - session_id, - exc, - exc_info=True, - ) - raise HTTPException(status_code=500, detail=str(exc)) from exc - artifact_version = await self.artifact_service.get_artifact_version( - app_name=app_name, - user_id=user_id, - session_id=session_id, - filename=req.filename, - version=version, - ) - if artifact_version is None: - raise HTTPException( - status_code=500, detail="Artifact metadata unavailable" - ) - return artifact_version - - @app.get( - "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions/{version_id}/metadata", - response_model=ArtifactVersion, - response_model_exclude_none=True, - ) - async def get_artifact_version_metadata( - app_name: str, - user_id: str, - session_id: str, - artifact_name: str, - version_id: int, - ) -> ArtifactVersion: - artifact_version = await self.artifact_service.get_artifact_version( - app_name=app_name, - user_id=user_id, - session_id=session_id, - filename=artifact_name, - version=version_id, - ) - if not artifact_version: - raise HTTPException( - status_code=404, detail="Artifact version not found" - ) - return artifact_version - - @app.get( - "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts", - response_model_exclude_none=True, - ) - async def list_artifact_names( - app_name: str, user_id: str, session_id: str - ) -> list[str]: - return await self.artifact_service.list_artifact_keys( - app_name=app_name, user_id=user_id, session_id=session_id - ) - - @app.get( - "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions", - response_model_exclude_none=True, - ) - async def list_artifact_versions( - app_name: str, user_id: str, session_id: str, artifact_name: str - ) -> list[int]: - return await self.artifact_service.list_versions( - app_name=app_name, - user_id=user_id, - session_id=session_id, - filename=artifact_name, - ) - - @app.delete( - "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}", - ) - async def delete_artifact( - app_name: str, user_id: str, session_id: str, artifact_name: str - ) -> None: - await self.artifact_service.delete_artifact( - app_name=app_name, - user_id=user_id, - session_id=session_id, - filename=artifact_name, - ) - - @app.patch("/apps/{app_name}/users/{user_id}/memory") - async def patch_memory( - app_name: str, user_id: str, update_memory_request: UpdateMemoryRequest - ) -> None: - """Adds all events from a given session to the memory service. - - Args: - app_name: The name of the application. - user_id: The ID of the user. - update_memory_request: The memory request for the update - - Raises: - HTTPException: If the memory service is not configured or the request - is invalid. - """ - if not self.memory_service: - raise HTTPException( - status_code=400, detail="Memory service is not configured." - ) - if ( - update_memory_request is None - or update_memory_request.session_id is None - ): - raise HTTPException( - status_code=400, detail="Update memory request is invalid." - ) - - session = await self.session_service.get_session( - app_name=app_name, - user_id=user_id, - session_id=update_memory_request.session_id, - ) - if not session: - raise HTTPException(status_code=404, detail="Session not found") - await self.memory_service.add_session_to_memory(session) - - def _set_telemetry_context_if_needed(runner: Runner): - """Helper to set contextvars for the current request task.""" - app = getattr(runner, "app", None) - from ..utils._telemetry_context import _is_visual_builder - - if app and getattr(app, "_is_visual_builder_app", False): - _is_visual_builder.set(True) - else: - _is_visual_builder.set(False) - - @app.post("/run", response_model_exclude_none=True) - async def run_agent(req: RunAgentRequest) -> list[Event]: - self.current_app_name_ref.value = req.app_name - runner = await self.get_runner_async(req.app_name) - _set_telemetry_context_if_needed(runner) - run_config = ( - RunConfig(custom_metadata=req.custom_metadata) - if req.custom_metadata - else None - ) - try: - async with Aclosing( - runner.run_async( - user_id=req.user_id, - session_id=req.session_id, - new_message=req.new_message, - state_delta=req.state_delta, - invocation_id=req.invocation_id, - run_config=run_config, - ) - ) as agen: - events = [event async for event in agen] - except SessionNotFoundError as e: - raise HTTPException(status_code=404, detail=str(e)) from e - logger.info("Generated %s events in agent run", len(events)) - logger.debug("Events generated: %s", events) - return events - - @app.post("/run_sse") - async def run_agent_sse(req: RunAgentRequest) -> StreamingResponse: - self.current_app_name_ref.value = req.app_name - stream_mode = StreamingMode.SSE if req.streaming else StreamingMode.NONE - runner = await self.get_runner_async(req.app_name) - _set_telemetry_context_if_needed(runner) - - # Validate session existence before starting the stream. - # We check directly here instead of eagerly advancing the - # runner's async generator with anext(), because splitting - # generator consumption across two asyncio Tasks (request - # handler vs StreamingResponse) breaks OpenTelemetry context - # detachment. - if not runner.auto_create_session: - session = await self.session_service.get_session( - app_name=req.app_name, - user_id=req.user_id, - session_id=req.session_id, - ) - if not session: - raise HTTPException( - status_code=404, - detail=f"Session not found: {req.session_id}", - ) - - # Convert the events to properly formatted SSE - async def event_generator(): - run_config = RunConfig( - streaming_mode=stream_mode, - custom_metadata=req.custom_metadata, - ) - async with Aclosing( - runner.run_async( - user_id=req.user_id, - session_id=req.session_id, - new_message=req.new_message, - state_delta=req.state_delta, - run_config=run_config, - invocation_id=req.invocation_id, - ) - ) as agen: - try: - async for event in agen: - # ADK Web renders artifacts from `actions.artifactDelta` - # during part processing *and* during action processing - # 1) the original event with `artifactDelta` cleared (content) - # 2) a content-less "action-only" event carrying `artifactDelta` - events_to_stream = [event] - if ( - not req.function_call_event_id - and event.actions.artifact_delta - and event.content - and event.content.parts - ): - content_event = event.model_copy(deep=True) - content_event.actions.artifact_delta = {} - artifact_event = event.model_copy(deep=True) - artifact_event.content = None - events_to_stream = [content_event, artifact_event] - - for event_to_stream in events_to_stream: - sse_event = event_to_stream.model_dump_json( - exclude_none=True, - by_alias=True, - ) - logger.debug( - "Generated event in agent run streaming: %s", sse_event - ) - yield f"data: {sse_event}\n\n" - except Exception as e: - logger.exception("Error in event_generator: %s", e) - yield f"data: {json.dumps({'error': str(e)})}\n\n" - - # Returns a streaming response with the proper media type for SSE - return StreamingResponse( - event_generator(), - media_type="text/event-stream", - ) - - @app.get( - "/dev/{app_name}/graph", - response_model_exclude_none=True, - tags=[TAG_DEBUG], - ) - async def get_app_graph_dot( - app_name: str, dark_mode: bool = False - ) -> GetEventGraphResult | dict: - """Returns the base agent graph in DOT format without any highlights. - - This endpoint allows the frontend to fetch the graph structure once - and compute highlights client-side for better performance. - - Args: - app_name: The name of the agent/app - dark_mode: Whether to use dark theme background color - """ - agent_or_app = self.agent_loader.load_agent(app_name) - root_agent = self._get_root_agent(agent_or_app) - - # Get graph with NO highlights (empty list) and specified theme - dot_graph = await agent_graph.get_agent_graph( - root_agent, [], dark_mode=dark_mode - ) - - if dot_graph and isinstance(dot_graph, graphviz.Digraph): - return GetEventGraphResult(dot_src=dot_graph.source) - else: - return {} - - # TODO: This endpoint can be removed once we update adk web to stop consuming it - @app.get( - "/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}/graph", - response_model_exclude_none=True, - tags=[TAG_DEBUG], - ) - async def get_event_graph( - app_name: str, user_id: str, session_id: str, event_id: str - ): - session = await self.session_service.get_session( - app_name=app_name, user_id=user_id, session_id=session_id - ) - session_events = session.events if session else [] - event = next((x for x in session_events if x.id == event_id), None) - if not event: - return {} - - function_calls = event.get_function_calls() - function_responses = event.get_function_responses() - agent_or_app = self.agent_loader.load_agent(app_name) - root_agent = self._get_root_agent(agent_or_app) - dot_graph = None - if function_calls: - function_call_highlights = [] - for function_call in function_calls: - from_name = event.author - to_name = function_call.name - function_call_highlights.append((from_name, to_name)) - dot_graph = await agent_graph.get_agent_graph( - root_agent, function_call_highlights - ) - elif function_responses: - function_responses_highlights = [] - for function_response in function_responses: - from_name = function_response.name - to_name = event.author - function_responses_highlights.append((from_name, to_name)) - dot_graph = await agent_graph.get_agent_graph( - root_agent, function_responses_highlights - ) - else: - from_name = event.author - to_name = "" - dot_graph = await agent_graph.get_agent_graph( - root_agent, [(from_name, to_name)] - ) - if dot_graph and isinstance(dot_graph, graphviz.Digraph): - return GetEventGraphResult(dot_src=dot_graph.source) - else: - return {} - - @app.websocket("/run_live") - async def run_agent_live( - websocket: WebSocket, - app_name: str, - user_id: str, - session_id: str, - modalities: List[Literal["TEXT", "AUDIO"]] = Query( - default=["AUDIO"] - ), # Only allows "TEXT" or "AUDIO" - proactive_audio: bool | None = Query(default=None), - enable_affective_dialog: bool | None = Query(default=None), - enable_session_resumption: bool | None = Query(default=None), - save_live_blob: bool = Query(default=False), - ) -> None: - ws_origin = websocket.headers.get("origin") - if ws_origin is not None and not _is_request_origin_allowed( - ws_origin, - websocket.scope, - literal_origins, - compiled_origin_regex, - has_configured_allowed_origins, - ): - await websocket.close(code=1008, reason="Origin not allowed") - return - - await websocket.accept() - self.current_app_name_ref.value = app_name - runner_for_context = await self.get_runner_async(app_name) - _set_telemetry_context_if_needed(runner_for_context) - - session = await self.session_service.get_session( - app_name=app_name, user_id=user_id, session_id=session_id - ) - if not session: - await websocket.close(code=1002, reason="Session not found") - return - - live_request_queue = LiveRequestQueue() - - async def forward_events(): - runner = await self.get_runner_async(app_name) - run_config = RunConfig( - response_modalities=modalities, - proactivity=( - types.ProactivityConfig(proactive_audio=proactive_audio) - if proactive_audio is not None - else None - ), - enable_affective_dialog=enable_affective_dialog, - session_resumption=( - types.SessionResumptionConfig( - transparent=enable_session_resumption - ) - if enable_session_resumption is not None - else None - ), - save_live_blob=save_live_blob, - ) - async with Aclosing( - runner.run_live( - session=session, - live_request_queue=live_request_queue, - run_config=run_config, - ) - ) as agen: - async for event in agen: - await websocket.send_text( - event.model_dump_json(exclude_none=True, by_alias=True) - ) - - async def process_messages(): - try: - while True: - data = await websocket.receive_text() - # Validate and send the received message to the live queue. - live_request_queue.send(LiveRequest.model_validate_json(data)) - except ValidationError as ve: - logger.error("Validation error in process_messages: %s", ve) - - # Run both tasks concurrently and cancel all if one fails. - tasks = [ - asyncio.create_task(forward_events()), - asyncio.create_task(process_messages()), - ] - done, pending = await asyncio.wait( - tasks, return_when=asyncio.FIRST_EXCEPTION - ) - try: - # This will re-raise any exception from the completed tasks. - for task in done: - task.result() - except WebSocketDisconnect: - # Disconnection could happen when receive or send text via websocket - logger.info("Client disconnected during live session.") - except Exception as e: - logger.exception("Error during live websocket communication: %s", e) - traceback.print_exc() - WEBSOCKET_INTERNAL_ERROR_CODE = 1011 - WEBSOCKET_MAX_BYTES_FOR_REASON = 123 - await websocket.close( - code=WEBSOCKET_INTERNAL_ERROR_CODE, - reason=str(e)[:WEBSOCKET_MAX_BYTES_FOR_REASON], - ) - finally: - for task in pending: - task.cancel() - - # Register /trigger/* endpoints when enabled. - if self.trigger_sources: - from .trigger_routes import TriggerRouter - - trigger_router = TriggerRouter(self, trigger_sources=self.trigger_sources) - trigger_router.register(app) - - if web_assets_dir: - import mimetypes - - mimetypes.add_type("application/javascript", ".js", True) - mimetypes.add_type("text/javascript", ".js", True) - - redirect_dev_ui_url = ( - self.url_prefix + "/dev-ui/" if self.url_prefix else "/dev-ui/" - ) - - @app.get("/dev-ui/config") - async def get_ui_config(): - return { - "logo_text": self.logo_text, - "logo_image_url": self.logo_image_url, - } - - @app.get("/") - async def redirect_root_to_dev_ui(): - return RedirectResponse(redirect_dev_ui_url) - - @app.get("/dev-ui") - async def redirect_dev_ui_add_slash(): - return RedirectResponse(redirect_dev_ui_url) - - app.mount( - "/dev-ui/", - StaticFiles(directory=web_assets_dir, html=True, follow_symlink=True), - name="static", - ) - - return app + pass From db928b27ede482508cced6b6dfa5dcc05cc30e18 Mon Sep 17 00:00:00 2001 From: Yee Sian Ng Date: Wed, 3 Jun 2026 22:18:47 +0000 Subject: [PATCH 18/18] chore: revert all changes to outdated files --- src/google/adk/cli/utils/_telemetry.py | 106 ----------------------- src/google/adk/cli/utils/agent_loader.py | 48 +--------- 2 files changed, 1 insertion(+), 153 deletions(-) delete mode 100644 src/google/adk/cli/utils/_telemetry.py diff --git a/src/google/adk/cli/utils/_telemetry.py b/src/google/adk/cli/utils/_telemetry.py deleted file mode 100644 index 070cb45526..0000000000 --- a/src/google/adk/cli/utils/_telemetry.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright 2026 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from typing import Mapping -from typing import Optional - -import fastapi -from opentelemetry import baggage -from opentelemetry import context -from opentelemetry.sdk import trace -from opentelemetry.trace.propagation import tracecontext - -_GOOGLE_AE_TRACEPARENT_HEADER = "Google-Agent-Engine-Traceparent" -_TRACEPARENT_BAGGAGE_KEY = "traceparent" -_GOOGLE_TRACEPARENT_HEADER = "traceparent" -_GOOGLE_TRACEPARENT_BAGGAGE_KEY = "google_traceparent" -_GOOGLE_TRACEPARENT_SUPPORT_ATTRIBUTE_KEY = "supportID" - - -def get_propagated_context(request: fastapi.Request) -> context.Context: - """Propagates context from the request headers.""" - ctx = context.get_current() - - if _GOOGLE_TRACEPARENT_HEADER in request.headers: - original_traceparent = request.headers[_GOOGLE_TRACEPARENT_HEADER] - ctx = baggage.set_baggage( - _GOOGLE_TRACEPARENT_BAGGAGE_KEY, - original_traceparent, - context=ctx, - ) - - if _GOOGLE_AE_TRACEPARENT_HEADER in request.headers: - carrier = {"traceparent": request.headers[_GOOGLE_AE_TRACEPARENT_HEADER]} - ctx = baggage.set_baggage( - _TRACEPARENT_BAGGAGE_KEY, - request.headers[_GOOGLE_AE_TRACEPARENT_HEADER], - context=ctx, - ) - ctx = tracecontext.TraceContextTextMapPropagator().extract( - carrier=carrier, context=ctx - ) - - return ctx - - -class TopSpanProcessor(trace.SpanProcessor): - """Top span processor.""" - - def on_start( - self, span: trace.Span, parent_context: Optional[context.Context] = None - ): - """Adds support ID to the top span.""" - baggage_items = baggage.get_all(context=parent_context) - if self._is_top_span(span, baggage_items) and ( - baggage_trace_header := baggage_items.get( - _GOOGLE_TRACEPARENT_BAGGAGE_KEY - ) - ): - span.set_attribute( - _GOOGLE_TRACEPARENT_SUPPORT_ATTRIBUTE_KEY, baggage_trace_header - ) - - def on_end(self, span: trace.ReadableSpan) -> None: - pass - - def shutdown(self) -> None: - pass - - def force_flush(self, timeout_millis: int = 30000) -> bool: - return True - - def _is_top_span( - self, span: trace.Span, baggage_items: Mapping[str, object] - ) -> bool: - """Returns true if the span is a top span. - - Args: - span: The span to check. - baggage_items: The baggage items that carry the context. - - Top span (e.g. "Invocation" span) is defined as the first span generated in - trace generation. - Top span could have an empty parent or the parent could be the span - provided by traceparent propagation. - """ - if span.parent is None or span.parent.span_id == 0: - return True - if _TRACEPARENT_BAGGAGE_KEY in baggage_items: - parent_id_hex = str(baggage_items[_TRACEPARENT_BAGGAGE_KEY]).split("-")[2] - parent_id_int = int(parent_id_hex, 16) - if span.parent.span_id == parent_id_int: - return True - return False diff --git a/src/google/adk/cli/utils/agent_loader.py b/src/google/adk/cli/utils/agent_loader.py index c1cb2d03e8..d4bbfc88f6 100644 --- a/src/google/adk/cli/utils/agent_loader.py +++ b/src/google/adk/cli/utils/agent_loader.py @@ -39,16 +39,6 @@ logger = logging.getLogger("google_adk." + __name__) - -def is_single_agent_directory(path: Path | str) -> bool: - """Returns True if the directory contains a single agent configuration or file.""" - p = Path(path).resolve() - return ( - p.joinpath("agent.py").is_file() - or p.joinpath("root_agent.yaml").is_file() - ) - - # Special agents directory for agents with names starting with double underscore SPECIAL_AGENTS_DIR = os.path.join( os.path.dirname(__file__), "..", "built_in_agents" @@ -70,36 +60,10 @@ class AgentLoader(BaseAgentLoader): """ def __init__(self, agents_dir: str): - agents_path = Path(agents_dir).resolve() - is_single_agent = is_single_agent_directory(agents_path) - if is_single_agent: - self._is_single_agent = True - self._single_agent_name = agents_path.name - self.agents_dir = str(agents_path.parent) - else: - self._is_single_agent = False - self._single_agent_name = None - self.agents_dir = str(agents_path) - + self.agents_dir = str(Path(agents_dir)) self._original_sys_path = None self._agent_cache: dict[str, Union[BaseAgent, App]] = {} - @property - def is_single_agent(self) -> bool: - """Returns True if the loader is in single agent mode.""" - return self._is_single_agent - - @property - def single_agent_name(self) -> Optional[str]: - """Returns the name of the agent in single agent mode.""" - return self._single_agent_name - - def _set_single_agent_mode(self, name: str, agents_dir: str) -> None: - """Internal method to force single agent mode. Use with care.""" - self._is_single_agent = True - self._single_agent_name = name - self.agents_dir = agents_dir - def _load_from_module_or_package( self, agent_name: str ) -> Optional[Union[BaseAgent, App]]: @@ -240,13 +204,6 @@ def _validate_agent_name(self, agent_name: str) -> None: name_to_check = agent_name check_dir = self.agents_dir - if self._is_single_agent and not agent_name.startswith("__"): - if agent_name != self._single_agent_name: - raise ValueError( - f"Agent not found: {agent_name!r}. In single agent mode, only " - f"'{self._single_agent_name}' is accessible." - ) - if not self._VALID_AGENT_NAME_RE.match(name_to_check): raise ValueError( f"Invalid agent name: {agent_name!r}. Agent names must be valid" @@ -411,8 +368,6 @@ def load_agent(self, agent_name: str) -> Union[BaseAgent, App]: @override def list_agents(self) -> list[str]: """Lists all agents available in the agent loader (sorted alphabetically).""" - if self._is_single_agent: - return [self._single_agent_name] base_path = Path.cwd() / self.agents_dir agent_names = [ x @@ -484,4 +439,3 @@ def remove_agent_from_cache(self, agent_name: str): logger.debug("Deleting module %s", key) del sys.modules[key] self._agent_cache.pop(agent_name, None) - \ No newline at end of file