diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d6cc9ffb3..b29d7f3ab 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,7 +21,7 @@ jobs: runs-on: ${{ github.repository == 'stainless-sdks/agentex-sdk-python' && 'depot-ubuntu-24.04' || 'ubuntu-latest' }} if: (github.event_name == 'push' || github.event.pull_request.head.repo.fork) && (github.event_name != 'push' || github.event.head_commit.message != 'codegen metadata') steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Install Rye run: | @@ -46,7 +46,7 @@ jobs: id-token: write runs-on: ${{ github.repository == 'stainless-sdks/agentex-sdk-python' && 'depot-ubuntu-24.04' || 'ubuntu-latest' }} steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Install Rye run: | @@ -67,7 +67,7 @@ jobs: github.repository == 'stainless-sdks/agentex-sdk-python' && !startsWith(github.ref, 'refs/heads/stl/') id: github-oidc - uses: actions/github-script@v8 + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: script: core.setOutput('github_token', await core.getIDToken()); @@ -87,7 +87,7 @@ jobs: runs-on: ${{ github.repository == 'stainless-sdks/agentex-sdk-python' && 'depot-ubuntu-24.04' || 'ubuntu-latest' }} if: github.event_name == 'push' || github.event.pull_request.head.repo.fork steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Install Rye run: | diff --git a/.github/workflows/publish-pypi.yml b/.github/workflows/publish-pypi.yml index 6b83ad2aa..864901da6 100644 --- a/.github/workflows/publish-pypi.yml +++ b/.github/workflows/publish-pypi.yml @@ -14,7 +14,7 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Install Rye run: | diff --git a/.github/workflows/release-doctor.yml b/.github/workflows/release-doctor.yml index dd2aefa66..a20022ce7 100644 --- a/.github/workflows/release-doctor.yml +++ b/.github/workflows/release-doctor.yml @@ -12,7 +12,7 @@ jobs: if: github.repository == 'scaleapi/scale-agentex-python' && (github.event_name == 'push' || github.event_name == 'workflow_dispatch' || startsWith(github.head_ref, 'release-please') || github.head_ref == 'next') steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Check release environment run: | diff --git a/.release-please-manifest.json b/.release-please-manifest.json index 78e7f271d..8032c17e8 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "0.11.0" + ".": "0.12.0" } diff --git a/CHANGELOG.md b/CHANGELOG.md index 6d9dc1f90..b7abdeff1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,27 @@ # Changelog +## 0.12.0 (2026-05-13) + +Full Changelog: [v0.11.0...v0.12.0](https://github.com/scaleapi/scale-agentex-python/compare/v0.11.0...v0.12.0) + +### ⚠ BREAKING CHANGES + +* remove AgentexTracingProcessor from default tracing processors ([#349](https://github.com/scaleapi/scale-agentex-python/issues/349)) + +### Features + +* **internal/types:** support eagerly validating pydantic iterators ([2c528c6](https://github.com/scaleapi/scale-agentex-python/commit/2c528c6db24cb64b7fffadafe3e8c46f316f0d56)) +* remove AgentexTracingProcessor from default tracing processors ([#349](https://github.com/scaleapi/scale-agentex-python/issues/349)) ([73eca7a](https://github.com/scaleapi/scale-agentex-python/commit/73eca7ad620a7e0a8bd0180b9dee02a7dde12dbb)) +* **streaming:** emit OTel metrics for ttft, tps, token counts ([#347](https://github.com/scaleapi/scale-agentex-python/issues/347)) ([3bf7d1f](https://github.com/scaleapi/scale-agentex-python/commit/3bf7d1f32f95e1346cdc823e3d1f4f027635e2dd)) + + +### Bug Fixes + +* **client:** add missing f-string prefix in file type error message ([dcb1cb4](https://github.com/scaleapi/scale-agentex-python/commit/dcb1cb489bc565828c16c327c5ab6b678b13c0fa)) +* render .env.example template in agentex init ([#351](https://github.com/scaleapi/scale-agentex-python/issues/351)) ([6092595](https://github.com/scaleapi/scale-agentex-python/commit/6092595fa8a267b2c305baba09e2682c04d593b3)) +* **tracing:** make SGP processor stateless to stop dropping span closes ([#354](https://github.com/scaleapi/scale-agentex-python/issues/354)) ([5e9f28d](https://github.com/scaleapi/scale-agentex-python/commit/5e9f28d2f1453b3b6faf993acf9f67a6fd098952)) +* wire SGP_CLIENT_BASE_URL and silence openai-agents tracer in templates ([#352](https://github.com/scaleapi/scale-agentex-python/issues/352)) ([870324e](https://github.com/scaleapi/scale-agentex-python/commit/870324e7bb87cefc20a79dc344d8603a836ca9b5)) + ## 0.11.0 (2026-05-07) Full Changelog: [v0.10.5...v0.11.0](https://github.com/scaleapi/scale-agentex-python/compare/v0.10.5...v0.11.0) diff --git a/pyproject.toml b/pyproject.toml index 547fc9cf9..aff030d70 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "agentex-sdk" -version = "0.11.0" +version = "0.12.0" description = "The official Python library for the agentex API" dynamic = ["readme"] license = "Apache-2.0" diff --git a/src/agentex/_files.py b/src/agentex/_files.py index 0fdce17bf..76da9e085 100644 --- a/src/agentex/_files.py +++ b/src/agentex/_files.py @@ -99,7 +99,7 @@ async def async_to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles elif is_sequence_t(files): files = [(key, await _async_transform_file(file)) for key, file in files] else: - raise TypeError("Unexpected file type input {type(files)}, expected mapping or sequence") + raise TypeError(f"Unexpected file type input {type(files)}, expected mapping or sequence") return files diff --git a/src/agentex/_models.py b/src/agentex/_models.py index 29070e055..8c5ab2602 100644 --- a/src/agentex/_models.py +++ b/src/agentex/_models.py @@ -25,7 +25,9 @@ ClassVar, Protocol, Required, + Annotated, ParamSpec, + TypeAlias, TypedDict, TypeGuard, final, @@ -79,7 +81,15 @@ from ._constants import RAW_RESPONSE_HEADER if TYPE_CHECKING: + from pydantic import GetCoreSchemaHandler, ValidatorFunctionWrapHandler + from pydantic_core import CoreSchema, core_schema from pydantic_core.core_schema import ModelField, ModelSchema, LiteralSchema, ModelFieldsSchema +else: + try: + from pydantic_core import CoreSchema, core_schema + except ImportError: + CoreSchema = None + core_schema = None __all__ = ["BaseModel", "GenericModel"] @@ -396,6 +406,76 @@ def model_dump_json( ) +class _EagerIterable(list[_T], Generic[_T]): + """ + Accepts any Iterable[T] input (including generators), consumes it + eagerly, and validates all items upfront. + + Validation preserves the original container type where possible + (e.g. a set[T] stays a set[T]). Serialization (model_dump / JSON) + always emits a list — round-tripping through model_dump() will not + restore the original container type. + """ + + @classmethod + def __get_pydantic_core_schema__( + cls, + source_type: Any, + handler: GetCoreSchemaHandler, + ) -> CoreSchema: + (item_type,) = get_args(source_type) or (Any,) + item_schema: CoreSchema = handler.generate_schema(item_type) + list_of_items_schema: CoreSchema = core_schema.list_schema(item_schema) + + return core_schema.no_info_wrap_validator_function( + cls._validate, + list_of_items_schema, + serialization=core_schema.plain_serializer_function_ser_schema( + cls._serialize, + info_arg=False, + ), + ) + + @staticmethod + def _validate(v: Iterable[_T], handler: "ValidatorFunctionWrapHandler") -> Any: + original_type: type[Any] = type(v) + + # Normalize to list so list_schema can validate each item + if isinstance(v, list): + items: list[_T] = v + else: + try: + items = list(v) + except TypeError as e: + raise TypeError("Value is not iterable") from e + + # Validate items against the inner schema + validated: list[_T] = handler(items) + + # Reconstruct original container type + if original_type is list: + return validated + # str(list) produces the list's repr, not a string built from items, + # so skip reconstruction for str and its subclasses. + if issubclass(original_type, str): + return validated + try: + return original_type(validated) + except (TypeError, ValueError): + # If the type cannot be reconstructed, just return the validated list + return validated + + @staticmethod + def _serialize(v: Iterable[_T]) -> list[_T]: + """Always serialize as a list so Pydantic's JSON encoder is happy.""" + if isinstance(v, list): + return v + return list(v) + + +EagerIterable: TypeAlias = Annotated[Iterable[_T], _EagerIterable] + + def _construct_field(value: object, field: FieldInfo, key: str) -> object: if value is None: return field_get_default(field) diff --git a/src/agentex/_version.py b/src/agentex/_version.py index 59720802a..80a336119 100644 --- a/src/agentex/_version.py +++ b/src/agentex/_version.py @@ -1,4 +1,4 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. __title__ = "agentex" -__version__ = "0.11.0" # x-release-please-version +__version__ = "0.12.0" # x-release-please-version diff --git a/src/agentex/lib/cli/templates/default-langgraph/.env.example.j2 b/src/agentex/lib/cli/templates/default-langgraph/.env.example.j2 index 1e81b15dd..015f49ef7 100644 --- a/src/agentex/lib/cli/templates/default-langgraph/.env.example.j2 +++ b/src/agentex/lib/cli/templates/default-langgraph/.env.example.j2 @@ -10,3 +10,4 @@ LITELLM_API_KEY= # SGP Configuration (optional - for tracing) # SGP_API_KEY= # SGP_ACCOUNT_ID= +# SGP_CLIENT_BASE_URL= diff --git a/src/agentex/lib/cli/templates/default/.env.example.j2 b/src/agentex/lib/cli/templates/default/.env.example.j2 index 1e81b15dd..015f49ef7 100644 --- a/src/agentex/lib/cli/templates/default/.env.example.j2 +++ b/src/agentex/lib/cli/templates/default/.env.example.j2 @@ -10,3 +10,4 @@ LITELLM_API_KEY= # SGP Configuration (optional - for tracing) # SGP_API_KEY= # SGP_ACCOUNT_ID= +# SGP_CLIENT_BASE_URL= diff --git a/src/agentex/lib/cli/templates/sync-langgraph/.env.example.j2 b/src/agentex/lib/cli/templates/sync-langgraph/.env.example.j2 index 1e81b15dd..015f49ef7 100644 --- a/src/agentex/lib/cli/templates/sync-langgraph/.env.example.j2 +++ b/src/agentex/lib/cli/templates/sync-langgraph/.env.example.j2 @@ -10,3 +10,4 @@ LITELLM_API_KEY= # SGP Configuration (optional - for tracing) # SGP_API_KEY= # SGP_ACCOUNT_ID= +# SGP_CLIENT_BASE_URL= diff --git a/src/agentex/lib/cli/templates/sync-openai-agents/.env.example.j2 b/src/agentex/lib/cli/templates/sync-openai-agents/.env.example.j2 index 1e81b15dd..015f49ef7 100644 --- a/src/agentex/lib/cli/templates/sync-openai-agents/.env.example.j2 +++ b/src/agentex/lib/cli/templates/sync-openai-agents/.env.example.j2 @@ -10,3 +10,4 @@ LITELLM_API_KEY= # SGP Configuration (optional - for tracing) # SGP_API_KEY= # SGP_ACCOUNT_ID= +# SGP_CLIENT_BASE_URL= diff --git a/src/agentex/lib/cli/templates/sync-openai-agents/project/acp.py.j2 b/src/agentex/lib/cli/templates/sync-openai-agents/project/acp.py.j2 index 2295eccb3..0b3b482fe 100644 --- a/src/agentex/lib/cli/templates/sync-openai-agents/project/acp.py.j2 +++ b/src/agentex/lib/cli/templates/sync-openai-agents/project/acp.py.j2 @@ -13,7 +13,12 @@ from agentex.types.task_message_update import TaskMessageUpdate, StreamTaskMessa from agentex.types.task_message_content import TaskMessageContent from agentex.types.text_content import TextContent from agentex.lib.utils.logging import make_logger -from agents import Agent, Runner, RunConfig, function_tool +from agents import Agent, Runner, RunConfig, function_tool, set_tracing_disabled + +# Disable the openai-agents SDK's native tracer so it doesn't ship traces to +# api.openai.com using OPENAI_API_KEY (which may be a LiteLLM proxy key). +# SGP tracing below still runs via the Agentex tracing manager. +set_tracing_disabled(True) logger = make_logger(__name__) @@ -25,12 +30,14 @@ if _litellm_key: SGP_API_KEY = os.environ.get("SGP_API_KEY", "") SGP_ACCOUNT_ID = os.environ.get("SGP_ACCOUNT_ID", "") +SGP_CLIENT_BASE_URL = os.environ.get("SGP_CLIENT_BASE_URL", "") if SGP_API_KEY and SGP_ACCOUNT_ID: add_tracing_processor_config( SGPTracingProcessorConfig( sgp_api_key=SGP_API_KEY, sgp_account_id=SGP_ACCOUNT_ID, + sgp_base_url=SGP_CLIENT_BASE_URL, ) ) diff --git a/src/agentex/lib/cli/templates/sync/.env.example.j2 b/src/agentex/lib/cli/templates/sync/.env.example.j2 index 1e81b15dd..015f49ef7 100644 --- a/src/agentex/lib/cli/templates/sync/.env.example.j2 +++ b/src/agentex/lib/cli/templates/sync/.env.example.j2 @@ -10,3 +10,4 @@ LITELLM_API_KEY= # SGP Configuration (optional - for tracing) # SGP_API_KEY= # SGP_ACCOUNT_ID= +# SGP_CLIENT_BASE_URL= diff --git a/src/agentex/lib/cli/templates/temporal-openai-agents/.env.example.j2 b/src/agentex/lib/cli/templates/temporal-openai-agents/.env.example.j2 index 1e81b15dd..015f49ef7 100644 --- a/src/agentex/lib/cli/templates/temporal-openai-agents/.env.example.j2 +++ b/src/agentex/lib/cli/templates/temporal-openai-agents/.env.example.j2 @@ -10,3 +10,4 @@ LITELLM_API_KEY= # SGP Configuration (optional - for tracing) # SGP_API_KEY= # SGP_ACCOUNT_ID= +# SGP_CLIENT_BASE_URL= diff --git a/src/agentex/lib/cli/templates/temporal-openai-agents/project/workflow.py.j2 b/src/agentex/lib/cli/templates/temporal-openai-agents/project/workflow.py.j2 index 94b5221fb..5b95d4479 100644 --- a/src/agentex/lib/cli/templates/temporal-openai-agents/project/workflow.py.j2 +++ b/src/agentex/lib/cli/templates/temporal-openai-agents/project/workflow.py.j2 @@ -10,7 +10,13 @@ from agentex.lib.core.temporal.types.workflow import SignalName from agentex.lib.utils.logging import make_logger from agentex.types.text_content import TextContent from agentex.lib.environment_variables import EnvironmentVariables -from agents import Agent, Runner +from agents import Agent, Runner, set_tracing_disabled + +# Disable the openai-agents SDK's native tracer so it doesn't ship traces to +# api.openai.com using OPENAI_API_KEY (which may be a LiteLLM proxy key). +# SGP tracing below still runs via the Agentex tracing manager. +set_tracing_disabled(True) + from agentex.lib.core.temporal.plugins.openai_agents.hooks.hooks import TemporalStreamingHooks from pydantic import BaseModel from typing import List, Dict, Any @@ -39,6 +45,7 @@ add_tracing_processor_config( SGPTracingProcessorConfig( sgp_api_key=os.environ.get("SGP_API_KEY", ""), sgp_account_id=os.environ.get("SGP_ACCOUNT_ID", ""), + sgp_base_url=os.environ.get("SGP_CLIENT_BASE_URL", ""), ) ) diff --git a/src/agentex/lib/cli/templates/temporal/.env.example.j2 b/src/agentex/lib/cli/templates/temporal/.env.example.j2 index 1e81b15dd..015f49ef7 100644 --- a/src/agentex/lib/cli/templates/temporal/.env.example.j2 +++ b/src/agentex/lib/cli/templates/temporal/.env.example.j2 @@ -10,3 +10,4 @@ LITELLM_API_KEY= # SGP Configuration (optional - for tracing) # SGP_API_KEY= # SGP_ACCOUNT_ID= +# SGP_CLIENT_BASE_URL= diff --git a/src/agentex/lib/core/observability/__init__.py b/src/agentex/lib/core/observability/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/agentex/lib/core/observability/llm_metrics.py b/src/agentex/lib/core/observability/llm_metrics.py new file mode 100644 index 000000000..b15e83824 --- /dev/null +++ b/src/agentex/lib/core/observability/llm_metrics.py @@ -0,0 +1,121 @@ +"""OTel metrics for LLM calls. + +Single source of truth for LLM-call instrumentation across all agentex code +paths — temporal+openai_agents streaming today, sync ACP and the Claude SDK +plugin in future PRs. Centralizing the instrument definitions here means +those follow-ups don't need to redefine the metric names, units, or +description strings; they import ``get_llm_metrics()`` and record values. + +The meter is no-op when the application hasn't configured a ``MeterProvider``, +so importing this module is safe for runtimes that don't use OTel. Instruments +are created lazily on first ``get_llm_metrics()`` call so a ``MeterProvider`` +configured *after* this module is imported still binds correctly. + +Cardinality is bounded: +- All metrics carry only ``model`` (the LLM model name). +- ``requests`` additionally carries ``status``, drawn from a small fixed set + (see ``classify_status``). + +Resource attributes (``service.name``, ``k8s.*``, etc.) come from the +application's OTel resource configuration and are added to every series +automatically. +""" + +from __future__ import annotations + +from typing import Optional + +from opentelemetry import metrics + + +class LLMMetrics: + """Lazily-created OTel instruments for LLM call telemetry.""" + + def __init__(self) -> None: + meter = metrics.get_meter("agentex.llm") + self.requests = meter.create_counter( + name="agentex.llm.requests", + unit="1", + description=( + "LLM call count tagged with status (success / rate_limit / " + "server_error / client_error / timeout / network_error / " + "other_error). Use to alert on 429s, 5xxs, etc." + ), + ) + self.ttft_ms = meter.create_histogram( + name="agentex.llm.ttft", + unit="ms", + description="Time from request submission to first content token (ms)", + ) + # ttat (time-to-first-answering-token) is distinct from ttft for reasoning + # models: ttft fires on the first reasoning chunk (which arrives quickly), + # while ttat fires on the first user-visible answer token (text or tool + # call). For non-reasoning models the two are equal. + self.ttat_ms = meter.create_histogram( + name="agentex.llm.ttat", + unit="ms", + description="Time from request submission to first answering token (text or tool-call delta) — excludes reasoning chunks", + ) + # Note: TPS denominator is the model-generation window + # (last_token_time - first_token_time), not total stream wall time. + # This isolates raw model throughput from event-loop / tool-call latency. + self.tps = meter.create_histogram( + name="agentex.llm.tps", + unit="tokens/s", + description="Output tokens per second over the generation window", + ) + self.input_tokens = meter.create_counter( + name="agentex.llm.input_tokens", + unit="tokens", + description="Total input tokens sent to the LLM", + ) + self.output_tokens = meter.create_counter( + name="agentex.llm.output_tokens", + unit="tokens", + description="Total output tokens returned by the LLM", + ) + self.cached_input_tokens = meter.create_counter( + name="agentex.llm.cached_input_tokens", + unit="tokens", + description="Subset of input tokens served from prompt cache", + ) + self.reasoning_tokens = meter.create_counter( + name="agentex.llm.reasoning_tokens", + unit="tokens", + description="Output tokens spent on reasoning (subset of output_tokens)", + ) + + +_llm_metrics: Optional[LLMMetrics] = None + + +def get_llm_metrics() -> LLMMetrics: + """Return the LLM metrics singleton, creating it on first use.""" + global _llm_metrics + if _llm_metrics is None: + _llm_metrics = LLMMetrics() + return _llm_metrics + + +def classify_status(exc: Optional[BaseException]) -> str: + """Categorize an LLM call's outcome into a small fixed set of status labels. + + A successful call returns ``"success"``. Exceptions are mapped by type name + so we don't depend on a specific provider SDK's exception class hierarchy: + OpenAI, Anthropic, and other providers all use names like ``RateLimitError``, + ``APITimeoutError``, ``InternalServerError``, etc. + """ + if exc is None: + return "success" + name = type(exc).__name__ + if "RateLimit" in name: + return "rate_limit" + if "Timeout" in name: + return "timeout" + if any(s in name for s in ("ServerError", "InternalServer", "ServiceUnavailable", "BadGateway")): + return "server_error" + if "Connection" in name: + return "network_error" + if any(s in name for s in ("BadRequest", "Authentication", "Permission", "NotFound", "Conflict", "UnprocessableEntity")): + return "client_error" + return "other_error" diff --git a/src/agentex/lib/core/observability/llm_metrics_hooks.py b/src/agentex/lib/core/observability/llm_metrics_hooks.py new file mode 100644 index 000000000..fce4b29ba --- /dev/null +++ b/src/agentex/lib/core/observability/llm_metrics_hooks.py @@ -0,0 +1,57 @@ +"""``RunHooks`` adapter that emits per-call LLM metrics. + +Used by the sync ACP path and as a base class for ``TemporalStreamingHooks`` +on the async path, so token / request / cache metrics emit consistently +across both. Streaming-only metrics (ttft, ttat, tps) are emitted from the +streaming model itself, not here — hooks don't see individual chunks. +""" + +from __future__ import annotations + +from typing import Any +from typing_extensions import override + +from agents import Agent, RunHooks, ModelResponse, RunContextWrapper + +from agentex.lib.core.observability.llm_metrics import classify_status, get_llm_metrics + + +class LLMMetricsHooks(RunHooks): + """Emits ``agentex.llm.requests`` + token counters on every LLM call.""" + + @override + async def on_llm_end( + self, + context: RunContextWrapper[Any], + agent: Agent[Any], + response: ModelResponse, + ) -> None: + del context # part of the RunHooks contract; unused here + m = get_llm_metrics() + attrs = {"model": str(agent.model) if agent.model else "unknown"} + # Request counter only depends on agent.model, so emit it first and + # outside the usage-extraction try block. Token counters reach into + # nested optional fields and are best-effort: a non-OpenAI provider + # (litellm-routed Anthropic, etc.) may return a Usage shape missing + # input_tokens_details / output_tokens_details — we emit zeros where + # we can and skip the rest rather than crash the caller. + try: + m.requests.add(1, {**attrs, "status": "success"}) + except Exception: + pass + try: + usage = response.usage + m.input_tokens.add(usage.input_tokens or 0, attrs) + m.output_tokens.add(usage.output_tokens or 0, attrs) + m.cached_input_tokens.add(usage.input_tokens_details.cached_tokens or 0, attrs) + m.reasoning_tokens.add(usage.output_tokens_details.reasoning_tokens or 0, attrs) + except Exception: + pass + + +def record_llm_failure(model: str, exc: BaseException) -> None: + """Best-effort counter bump for an LLM call that raised before ``on_llm_end``.""" + try: + get_llm_metrics().requests.add(1, {"model": model, "status": classify_status(exc)}) + except Exception: + pass diff --git a/src/agentex/lib/core/observability/tests/__init__.py b/src/agentex/lib/core/observability/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/agentex/lib/core/observability/tests/test_llm_metrics.py b/src/agentex/lib/core/observability/tests/test_llm_metrics.py new file mode 100644 index 000000000..d8ab62eba --- /dev/null +++ b/src/agentex/lib/core/observability/tests/test_llm_metrics.py @@ -0,0 +1,83 @@ +"""Tests for ``agentex.lib.core.observability.llm_metrics``.""" + +from __future__ import annotations + +import agentex.lib.core.observability.llm_metrics as llm_metrics +from agentex.lib.core.observability.llm_metrics import ( + LLMMetrics, + classify_status, + get_llm_metrics, +) + + +class TestClassifyStatus: + def test_none_is_success(self): + assert classify_status(None) == "success" + + def test_rate_limit(self): + class RateLimitError(Exception): + pass + + assert classify_status(RateLimitError()) == "rate_limit" + + def test_timeout(self): + class APITimeoutError(Exception): + pass + + assert classify_status(APITimeoutError()) == "timeout" + + def test_server_error(self): + class InternalServerError(Exception): + pass + + assert classify_status(InternalServerError()) == "server_error" + + class ServiceUnavailable(Exception): + pass + + assert classify_status(ServiceUnavailable()) == "server_error" + + def test_network_error(self): + class APIConnectionError(Exception): + pass + + assert classify_status(APIConnectionError()) == "network_error" + + def test_client_error(self): + for cls_name in ("BadRequestError", "AuthenticationError", "PermissionError"): + cls = type(cls_name, (Exception,), {}) + assert classify_status(cls()) == "client_error" + + def test_unknown_falls_back(self): + class WeirdProviderException(Exception): + pass + + assert classify_status(WeirdProviderException()) == "other_error" + + +class TestGetLLMMetrics: + def test_returns_llm_metrics_instance(self, monkeypatch): + monkeypatch.setattr(llm_metrics, "_llm_metrics", None) + m = get_llm_metrics() + assert isinstance(m, LLMMetrics) + + def test_singleton_returns_same_instance(self, monkeypatch): + monkeypatch.setattr(llm_metrics, "_llm_metrics", None) + first = get_llm_metrics() + second = get_llm_metrics() + assert first is second + + def test_instruments_exist(self, monkeypatch): + monkeypatch.setattr(llm_metrics, "_llm_metrics", None) + m = get_llm_metrics() + for name in ( + "requests", + "ttft_ms", + "ttat_ms", + "tps", + "input_tokens", + "output_tokens", + "cached_input_tokens", + "reasoning_tokens", + ): + assert hasattr(m, name), f"missing instrument: {name}" diff --git a/src/agentex/lib/core/observability/tests/test_llm_metrics_hooks.py b/src/agentex/lib/core/observability/tests/test_llm_metrics_hooks.py new file mode 100644 index 000000000..a2cef95b8 --- /dev/null +++ b/src/agentex/lib/core/observability/tests/test_llm_metrics_hooks.py @@ -0,0 +1,215 @@ +"""Tests for ``agentex.lib.core.observability.llm_metrics_hooks``.""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + +import agentex.lib.core.observability.llm_metrics_hooks as hooks_module +from agentex.lib.core.observability.llm_metrics_hooks import ( + LLMMetricsHooks, + record_llm_failure, +) + + +def _mock_response( + *, + input_tokens: int = 100, + output_tokens: int = 50, + cached_tokens: int = 30, + reasoning_tokens: int = 10, +) -> MagicMock: + response = MagicMock() + response.usage.input_tokens = input_tokens + response.usage.output_tokens = output_tokens + response.usage.input_tokens_details.cached_tokens = cached_tokens + response.usage.output_tokens_details.reasoning_tokens = reasoning_tokens + return response + + +def _mock_agent(model: str = "gpt-5") -> MagicMock: + agent = MagicMock() + agent.model = model + return agent + + +class TestLLMMetricsHooksOnLLMEnd: + @pytest.mark.asyncio + async def test_emits_success_request_counter(self, monkeypatch): + m = MagicMock() + monkeypatch.setattr(hooks_module, "get_llm_metrics", lambda: m) + + await LLMMetricsHooks().on_llm_end( + context=MagicMock(), + agent=_mock_agent("gpt-5"), + response=_mock_response(), + ) + + m.requests.add.assert_called_once_with(1, {"model": "gpt-5", "status": "success"}) + + @pytest.mark.asyncio + async def test_emits_token_counters(self, monkeypatch): + m = MagicMock() + monkeypatch.setattr(hooks_module, "get_llm_metrics", lambda: m) + + await LLMMetricsHooks().on_llm_end( + context=MagicMock(), + agent=_mock_agent("gpt-5"), + response=_mock_response( + input_tokens=200, + output_tokens=75, + cached_tokens=50, + reasoning_tokens=20, + ), + ) + + attrs = {"model": "gpt-5"} + m.input_tokens.add.assert_called_once_with(200, attrs) + m.output_tokens.add.assert_called_once_with(75, attrs) + m.cached_input_tokens.add.assert_called_once_with(50, attrs) + m.reasoning_tokens.add.assert_called_once_with(20, attrs) + + @pytest.mark.asyncio + async def test_zero_tokens_emit_zero_not_skip(self, monkeypatch): + m = MagicMock() + monkeypatch.setattr(hooks_module, "get_llm_metrics", lambda: m) + + await LLMMetricsHooks().on_llm_end( + context=MagicMock(), + agent=_mock_agent(), + response=_mock_response(input_tokens=0, output_tokens=0, cached_tokens=0, reasoning_tokens=0), + ) + + m.input_tokens.add.assert_called_once_with(0, {"model": "gpt-5"}) + m.output_tokens.add.assert_called_once_with(0, {"model": "gpt-5"}) + + @pytest.mark.asyncio + async def test_unknown_model_falls_back(self, monkeypatch): + m = MagicMock() + monkeypatch.setattr(hooks_module, "get_llm_metrics", lambda: m) + + agent = MagicMock() + agent.model = None + + await LLMMetricsHooks().on_llm_end( + context=MagicMock(), + agent=agent, + response=_mock_response(), + ) + + m.requests.add.assert_called_once_with(1, {"model": "unknown", "status": "success"}) + + @pytest.mark.asyncio + async def test_swallows_exporter_failure(self, monkeypatch): + m = MagicMock() + m.requests.add.side_effect = RuntimeError("exporter exploded") + monkeypatch.setattr(hooks_module, "get_llm_metrics", lambda: m) + + # Should not raise — caller's flow must not break on metric failure. + await LLMMetricsHooks().on_llm_end( + context=MagicMock(), + agent=_mock_agent(), + response=_mock_response(), + ) + + @pytest.mark.asyncio + async def test_missing_usage_still_emits_request_counter(self, monkeypatch): + """Provider returns a response without `usage` — caller shouldn't crash, + and we should still record the success request counter.""" + m = MagicMock() + monkeypatch.setattr(hooks_module, "get_llm_metrics", lambda: m) + + class _Response: + @property + def usage(self): + raise AttributeError("no usage") + + await LLMMetricsHooks().on_llm_end( + context=MagicMock(), + agent=_mock_agent(), + response=_Response(), # type: ignore[arg-type] + ) + + m.requests.add.assert_called_once_with(1, {"model": "gpt-5", "status": "success"}) + m.input_tokens.add.assert_not_called() + m.output_tokens.add.assert_not_called() + + @pytest.mark.asyncio + async def test_missing_token_details_skips_those_counters(self, monkeypatch): + """Provider returns Usage without input_tokens_details (e.g. some + litellm wrappers / non-OpenAI providers): top-level token counts + still emit; the nested cached/reasoning counters are skipped.""" + m = MagicMock() + monkeypatch.setattr(hooks_module, "get_llm_metrics", lambda: m) + + class _Usage: + input_tokens = 100 + output_tokens = 50 + + @property + def input_tokens_details(self): + raise AttributeError("no details") + + class _Response: + usage = _Usage() + + await LLMMetricsHooks().on_llm_end( + context=MagicMock(), + agent=_mock_agent(), + response=_Response(), # type: ignore[arg-type] + ) + + # Request counter still fires (it's outside the usage-extraction try). + m.requests.add.assert_called_once_with(1, {"model": "gpt-5", "status": "success"}) + # input_tokens.add fires before the nested attribute access. + m.input_tokens.add.assert_called_once_with(100, {"model": "gpt-5"}) + # cached_input_tokens / reasoning_tokens skipped — the AttributeError + # bailed before they could be called. + m.cached_input_tokens.add.assert_not_called() + m.reasoning_tokens.add.assert_not_called() + + @pytest.mark.asyncio + async def test_none_token_values_emit_as_zero(self, monkeypatch): + """Some providers report None instead of 0 for fields they don't track.""" + m = MagicMock() + monkeypatch.setattr(hooks_module, "get_llm_metrics", lambda: m) + + response = MagicMock() + response.usage.input_tokens = None + response.usage.output_tokens = None + response.usage.input_tokens_details.cached_tokens = None + response.usage.output_tokens_details.reasoning_tokens = None + + await LLMMetricsHooks().on_llm_end( + context=MagicMock(), + agent=_mock_agent(), + response=response, + ) + + attrs = {"model": "gpt-5"} + m.input_tokens.add.assert_called_once_with(0, attrs) + m.output_tokens.add.assert_called_once_with(0, attrs) + m.cached_input_tokens.add.assert_called_once_with(0, attrs) + m.reasoning_tokens.add.assert_called_once_with(0, attrs) + + +class TestRecordLLMFailure: + def test_emits_classified_status(self, monkeypatch): + m = MagicMock() + monkeypatch.setattr(hooks_module, "get_llm_metrics", lambda: m) + + class RateLimitError(Exception): + pass + + record_llm_failure("gpt-5", RateLimitError()) + + m.requests.add.assert_called_once_with(1, {"model": "gpt-5", "status": "rate_limit"}) + + def test_swallows_exporter_failure(self, monkeypatch): + m = MagicMock() + m.requests.add.side_effect = RuntimeError("exporter exploded") + monkeypatch.setattr(hooks_module, "get_llm_metrics", lambda: m) + + # Should not raise. + record_llm_failure("gpt-5", Exception("upstream")) diff --git a/src/agentex/lib/core/temporal/plugins/openai_agents/hooks/hooks.py b/src/agentex/lib/core/temporal/plugins/openai_agents/hooks/hooks.py index cc27006fc..758b0db27 100644 --- a/src/agentex/lib/core/temporal/plugins/openai_agents/hooks/hooks.py +++ b/src/agentex/lib/core/temporal/plugins/openai_agents/hooks/hooks.py @@ -8,18 +8,19 @@ from typing import Any, override from datetime import timedelta -from agents import Tool, Agent, RunHooks, RunContextWrapper +from agents import Tool, Agent, RunContextWrapper from temporalio import workflow from agents.tool_context import ToolContext from agentex.types.text_content import TextContent from agentex.types.task_message_content import ToolRequestContent, ToolResponseContent +from agentex.lib.core.observability.llm_metrics_hooks import LLMMetricsHooks from agentex.lib.core.temporal.plugins.openai_agents.hooks.activities import stream_lifecycle_content logger = logging.getLogger(__name__) -class TemporalStreamingHooks(RunHooks): +class TemporalStreamingHooks(LLMMetricsHooks): """Convenience hooks class for streaming OpenAI Agent lifecycle events to the AgentEx UI. This class automatically streams agent lifecycle events (tool calls, handoffs) to the diff --git a/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py b/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py index 4f18ae379..7ccc6627a 100644 --- a/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py +++ b/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py @@ -1,6 +1,7 @@ """Custom Temporal Model Provider with streaming support for OpenAI agents.""" from __future__ import annotations +import time import uuid from typing import Any, List, Union, Optional, override @@ -31,6 +32,8 @@ # Re-export the canonical StreamingMode literal from the streaming service so # all layers share a single definition. from agentex.lib.core.services.adk.streaming import StreamingMode as StreamingMode +from agentex.lib.core.observability.llm_metrics import get_llm_metrics +from agentex.lib.core.observability.llm_metrics_hooks import record_llm_failure try: from agents.tool import ShellTool # type: ignore[attr-defined] @@ -78,6 +81,11 @@ logger = make_logger("agentex.temporal.streaming") +# LLM metrics live in agentex.lib.core.observability.llm_metrics so other +# code paths (sync ACP, Claude SDK plugin, future provider integrations) +# can share the same instrument definitions without redefining names. + + def _serialize_item(item: Any) -> dict[str, Any]: """ Universal serializer for any item type from OpenAI Agents SDK. @@ -592,7 +600,11 @@ async def get_response( # endpoints recognize this parameter, so we don't auto-inject a default. prompt_cache_key = extra_args.pop("prompt_cache_key", NOT_GIVEN) - # Create the response stream using Responses API + # Create the response stream using Responses API. + # Bookmark request start *before* the await so ttft captures the full + # user-perceived latency (HTTP round-trip + model TTFB), not just the + # post-connect event-loop delay. + stream_start_perf = time.perf_counter() logger.debug(f"[TemporalStreamingModel] Creating response stream with Responses API") stream = await self.client.responses.create( # type: ignore[call-overload] @@ -642,6 +654,16 @@ async def get_response( reasoning_summaries = [] reasoning_contents = [] event_count = 0 + # ttft / ttat / tps instrumentation. ``stream_start_perf`` is set + # above, before the responses.create() await, so it captures the full + # request-to-first-token latency. ``first_token_at`` and + # ``last_token_at`` bracket the model-generation window for tps. + # ``first_answer_at`` is set on the first user-visible answer token + # (text or tool-call delta) and excludes reasoning chunks, so ttat + # measures the latency users actually perceive on reasoning models. + first_token_at: Optional[float] = None + last_token_at: Optional[float] = None + first_answer_at: Optional[float] = None # We expect task_id to always be provided for streaming if not task_id: @@ -656,6 +678,28 @@ async def get_response( # Log event type logger.debug(f"[TemporalStreamingModel] Event {event_count}: {type(event).__name__}") + # Bookmark first/last token-producing events for ttft and tps. + # Includes function-call argument deltas so the generation window + # covers every event type whose tokens land in usage.output_tokens. + if isinstance(event, ( + ResponseTextDeltaEvent, + ResponseReasoningTextDeltaEvent, + ResponseReasoningSummaryTextDeltaEvent, + ResponseFunctionCallArgumentsDeltaEvent, + )): + now_perf = time.perf_counter() + if first_token_at is None: + first_token_at = now_perf + last_token_at = now_perf + # ttat: first user-visible answer token (text or tool call), + # excluding reasoning chunks. Equal to ttft for non-reasoning + # models; differs by reasoning duration for reasoning models. + if first_answer_at is None and isinstance(event, ( + ResponseTextDeltaEvent, + ResponseFunctionCallArgumentsDeltaEvent, + )): + first_answer_at = now_perf + # Handle different event types using isinstance for type safety if isinstance(event, ResponseOutputItemAddedEvent): # New output item (reasoning, function call, or message) @@ -983,6 +1027,25 @@ async def get_response( span.output = output_data + # Streaming-only metrics. Token counters and the success request + # counter are emitted by LLMMetricsHooks.on_llm_end so they fire + # consistently across streaming and non-streaming paths. + m = get_llm_metrics() + metric_attrs = {"model": self.model_name} + if first_token_at is not None: + m.ttft_ms.record((first_token_at - stream_start_perf) * 1000, metric_attrs) + if first_answer_at is not None: + m.ttat_ms.record((first_answer_at - stream_start_perf) * 1000, metric_attrs) + # Single-token responses collapse the generation window to 0; tps + # is undefined and skipped. + if ( + first_token_at is not None + and last_token_at is not None + and last_token_at > first_token_at + and (usage.output_tokens or 0) > 0 + ): + m.tps.record(usage.output_tokens / (last_token_at - first_token_at), metric_attrs) + # Return the response. response_id is the server-issued id from # ResponseCompletedEvent.response.id, or None when the stream ended # without a completed event (error path) — matching the documented @@ -998,6 +1061,10 @@ async def get_response( except Exception as e: logger.error(f"Error using Responses API: {e}") + # LLMMetricsHooks.on_llm_end doesn't fire on error, so emit the + # failure counter here. Best-effort so the typed LLM exception + # always propagates intact for retry / circuit-breaker logic. + record_llm_failure(self.model_name, e) raise # The _get_response_with_responses_api method has been merged into get_response above diff --git a/src/agentex/lib/core/tracing/processors/sgp_tracing_processor.py b/src/agentex/lib/core/tracing/processors/sgp_tracing_processor.py index 187dedcbc..3a1c96c1b 100644 --- a/src/agentex/lib/core/tracing/processors/sgp_tracing_processor.py +++ b/src/agentex/lib/core/tracing/processors/sgp_tracing_processor.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import override +from typing import cast, override import scale_gp_beta.lib.tracing as tracing from scale_gp_beta import SGPClient, AsyncSGPClient @@ -27,6 +27,39 @@ def _get_span_type(span: Span) -> str: return "STANDALONE" +def _add_source_to_span(span: Span, env_vars: EnvironmentVariables) -> None: + if span.data is None: + span.data = {} + if isinstance(span.data, dict): + span.data["__source__"] = "agentex" + if env_vars.ACP_TYPE is not None: + span.data["__acp_type__"] = env_vars.ACP_TYPE + if env_vars.AGENT_NAME is not None: + span.data["__agent_name__"] = env_vars.AGENT_NAME + if env_vars.AGENT_ID is not None: + span.data["__agent_id__"] = env_vars.AGENT_ID + + +def _build_sgp_span(span: Span, env_vars: EnvironmentVariables) -> SGPSpan: + """Build an SGPSpan from an agentex Span. Idempotent on span_id at the SGP backend.""" + _add_source_to_span(span, env_vars) + sgp_span = cast( + SGPSpan, + create_span( + name=span.name, + span_type=_get_span_type(span), + span_id=span.id, + parent_id=span.parent_id, + trace_id=span.trace_id, + input=span.input, + output=span.output, + metadata=span.data, + ), + ) + sgp_span.start_time = span.start_time.isoformat() # type: ignore[union-attr] + return sgp_span + + class SGPSyncTracingProcessor(SyncTracingProcessor): def __init__(self, config: SGPTracingProcessorConfig): disabled = config.sgp_api_key == "" or config.sgp_account_id == "" @@ -38,63 +71,27 @@ def __init__(self, config: SGPTracingProcessorConfig): ), disabled=disabled, ) - self._spans: dict[str, SGPSpan] = {} self.env_vars = EnvironmentVariables.refresh() - def _add_source_to_span(self, span: Span) -> None: - if span.data is None: - span.data = {} - if isinstance(span.data, dict): - span.data["__source__"] = "agentex" - if self.env_vars.ACP_TYPE is not None: - span.data["__acp_type__"] = self.env_vars.ACP_TYPE - if self.env_vars.AGENT_NAME is not None: - span.data["__agent_name__"] = self.env_vars.AGENT_NAME - if self.env_vars.AGENT_ID is not None: - span.data["__agent_id__"] = self.env_vars.AGENT_ID - @override def on_span_start(self, span: Span) -> None: - self._add_source_to_span(span) - - sgp_span = create_span( - name=span.name, - span_type=_get_span_type(span), - span_id=span.id, - parent_id=span.parent_id, - trace_id=span.trace_id, - input=span.input, - output=span.output, - metadata=span.data, - ) - sgp_span.start_time = span.start_time.isoformat() # type: ignore[union-attr] + sgp_span = _build_sgp_span(span, self.env_vars) sgp_span.flush(blocking=False) - self._spans[span.id] = sgp_span - @override def on_span_end(self, span: Span) -> None: - sgp_span = self._spans.pop(span.id, None) - if sgp_span is None: - logger.warning(f"Span {span.id} not found in stored spans, skipping span end") - return - - self._add_source_to_span(span) - sgp_span.output = span.output # type: ignore[assignment] - sgp_span.metadata = span.data # type: ignore[assignment] + sgp_span = _build_sgp_span(span, self.env_vars) sgp_span.end_time = span.end_time.isoformat() # type: ignore[union-attr] sgp_span.flush(blocking=False) @override def shutdown(self) -> None: - self._spans.clear() flush_queue() class SGPAsyncTracingProcessor(AsyncTracingProcessor): def __init__(self, config: SGPTracingProcessorConfig): self.disabled = config.sgp_api_key == "" or config.sgp_account_id == "" - self._spans: dict[str, SGPSpan] = {} import httpx # Disable keepalive so each HTTP call gets a fresh TCP connection, @@ -113,18 +110,6 @@ def __init__(self, config: SGPTracingProcessorConfig): ) self.env_vars = EnvironmentVariables.refresh() - def _add_source_to_span(self, span: Span) -> None: - if span.data is None: - span.data = {} - if isinstance(span.data, dict): - span.data["__source__"] = "agentex" - if self.env_vars.ACP_TYPE is not None: - span.data["__acp_type__"] = self.env_vars.ACP_TYPE - if self.env_vars.AGENT_NAME is not None: - span.data["__agent_name__"] = self.env_vars.AGENT_NAME - if self.env_vars.AGENT_ID is not None: - span.data["__agent_id__"] = self.env_vars.AGENT_ID - @override async def on_span_start(self, span: Span) -> None: await self.on_spans_start([span]) @@ -138,22 +123,7 @@ async def on_spans_start(self, spans: list[Span]) -> None: if not spans: return - sgp_spans: list[SGPSpan] = [] - for span in spans: - self._add_source_to_span(span) - sgp_span = create_span( - name=span.name, - span_type=_get_span_type(span), - span_id=span.id, - parent_id=span.parent_id, - trace_id=span.trace_id, - input=span.input, - output=span.output, - metadata=span.data, - ) - sgp_span.start_time = span.start_time.isoformat() # type: ignore[union-attr] - self._spans[span.id] = sgp_span - sgp_spans.append(sgp_span) + sgp_spans = [_build_sgp_span(span, self.env_vars) for span in spans] if self.disabled: logger.warning("SGP is disabled, skipping span upsert") @@ -167,29 +137,18 @@ async def on_spans_end(self, spans: list[Span]) -> None: if not spans: return - to_upsert: list[SGPSpan] = [] + sgp_spans: list[SGPSpan] = [] for span in spans: - sgp_span = self._spans.pop(span.id, None) - if sgp_span is None: - logger.warning(f"Span {span.id} not found in stored spans, skipping span end") - continue - - self._add_source_to_span(span) - sgp_span.input = span.input # type: ignore[assignment] - sgp_span.output = span.output # type: ignore[assignment] - sgp_span.metadata = span.data # type: ignore[assignment] + sgp_span = _build_sgp_span(span, self.env_vars) sgp_span.end_time = span.end_time.isoformat() # type: ignore[union-attr] - to_upsert.append(sgp_span) + sgp_spans.append(sgp_span) - if self.disabled or not to_upsert: + if self.disabled: return await self.sgp_async_client.spans.upsert_batch( # type: ignore[union-attr] - items=[s.to_request_params() for s in to_upsert] + items=[s.to_request_params() for s in sgp_spans] ) @override async def shutdown(self) -> None: - await self.sgp_async_client.spans.upsert_batch( # type: ignore[union-attr] - items=[sgp_span.to_request_params() for sgp_span in self._spans.values()] - ) - self._spans.clear() + pass diff --git a/src/agentex/lib/core/tracing/tracing_processor_manager.py b/src/agentex/lib/core/tracing/tracing_processor_manager.py index 14b0ce39b..07c440313 100644 --- a/src/agentex/lib/core/tracing/tracing_processor_manager.py +++ b/src/agentex/lib/core/tracing/tracing_processor_manager.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING from threading import Lock -from agentex.lib.types.tracing import TracingProcessorConfig, AgentexTracingProcessorConfig +from agentex.lib.types.tracing import TracingProcessorConfig from agentex.lib.core.tracing.processors.sgp_tracing_processor import ( SGPSyncTracingProcessor, SGPAsyncTracingProcessor, @@ -73,22 +73,8 @@ def get_async_processors(self) -> list[AsyncTracingProcessor]: add_tracing_processor_config = GLOBAL_TRACING_PROCESSOR_MANAGER.add_processor_config set_tracing_processor_configs = GLOBAL_TRACING_PROCESSOR_MANAGER.set_processor_configs -# Lazy initialization to avoid circular imports -_default_initialized = False - -def _ensure_default_initialized(): - """Ensure default processor is initialized (lazy to avoid circular imports).""" - global _default_initialized - if not _default_initialized: - add_tracing_processor_config(AgentexTracingProcessorConfig()) - _default_initialized = True - def get_sync_tracing_processors(): - """Get sync processors, initializing defaults if needed.""" - _ensure_default_initialized() return GLOBAL_TRACING_PROCESSOR_MANAGER.get_sync_processors() def get_async_tracing_processors(): - """Get async processors, initializing defaults if needed.""" - _ensure_default_initialized() return GLOBAL_TRACING_PROCESSOR_MANAGER.get_async_processors() diff --git a/tests/lib/core/tracing/processors/test_sgp_tracing_processor.py b/tests/lib/core/tracing/processors/test_sgp_tracing_processor.py index 50d615e0d..4614fe540 100644 --- a/tests/lib/core/tracing/processors/test_sgp_tracing_processor.py +++ b/tests/lib/core/tracing/processors/test_sgp_tracing_processor.py @@ -41,18 +41,16 @@ def _make_mock_sgp_span() -> MagicMock: # --------------------------------------------------------------------------- -class TestSGPSyncTracingProcessorMemoryLeak: +class TestSGPSyncTracingProcessor: @staticmethod def _make_processor(): mock_env = MagicMock() mock_env.refresh.return_value = MagicMock(ACP_TYPE=None, AGENT_NAME=None, AGENT_ID=None) mock_create_span = MagicMock(side_effect=lambda **kwargs: _make_mock_sgp_span()) - with patch(f"{MODULE}.EnvironmentVariables", mock_env), \ - patch(f"{MODULE}.SGPClient"), \ - patch(f"{MODULE}.tracing"), \ - patch(f"{MODULE}.flush_queue"), \ - patch(f"{MODULE}.create_span", mock_create_span): + with patch(f"{MODULE}.EnvironmentVariables", mock_env), patch(f"{MODULE}.SGPClient"), patch( + f"{MODULE}.tracing" + ), patch(f"{MODULE}.flush_queue"), patch(f"{MODULE}.create_span", mock_create_span): from agentex.lib.core.tracing.processors.sgp_tracing_processor import ( SGPSyncTracingProcessor, ) @@ -61,41 +59,50 @@ def _make_processor(): return processor, mock_create_span - def test_spans_not_leaked_after_completed_lifecycle(self): + def test_processor_holds_no_per_span_state(self): + """Stateless processor must not retain any per-span dict between lifecycle events.""" processor, _ = self._make_processor() + assert not hasattr(processor, "_spans") - with patch(f"{MODULE}.create_span", side_effect=lambda **kw: _make_mock_sgp_span()): + def test_span_lifecycle_produces_two_flushes(self): + """Each span produces one flush on start and one on end.""" + processor, _ = self._make_processor() + + with patch(f"{MODULE}.create_span", side_effect=lambda **kw: _make_mock_sgp_span()) as mock_cs: for _ in range(100): span = _make_span() processor.on_span_start(span) span.end_time = datetime.now(UTC) processor.on_span_end(span) - assert len(processor._spans) == 0, ( - f"Expected 0 spans after 100 complete lifecycles, got {len(processor._spans)} — memory leak!" - ) + # 100 spans × (1 start + 1 end) = 200 build calls. + assert mock_cs.call_count == 200 + + def test_span_end_without_prior_start_still_flushes(self): + """Cross-pod Temporal case: END activity lands on a pod that never saw START. - def test_spans_present_during_active_lifecycle(self): + Today this used to be a silent no-op. After the stateless refactor it + must still flush a complete span (start_time + end_time + payload). + """ processor, _ = self._make_processor() - with patch(f"{MODULE}.create_span", side_effect=lambda **kw: _make_mock_sgp_span()): - span = _make_span() - processor.on_span_start(span) - assert len(processor._spans) == 1, "Span should be tracked while active" + captured_spans: list[MagicMock] = [] + def capture_create_span(**kwargs): + sgp_span = _make_mock_sgp_span() + captured_spans.append(sgp_span) + return sgp_span + + with patch(f"{MODULE}.create_span", side_effect=capture_create_span): + span = _make_span() span.end_time = datetime.now(UTC) + # No on_span_start — END lands here for the first time. processor.on_span_end(span) - assert len(processor._spans) == 0, "Span should be removed after end" - def test_span_end_for_unknown_span_is_noop(self): - processor, _ = self._make_processor() - - span = _make_span() - # End a span that was never started — should not raise - span.end_time = datetime.now(UTC) - processor.on_span_end(span) - - assert len(processor._spans) == 0 + assert len(captured_spans) == 1 + assert captured_spans[0].flush.called + assert captured_spans[0].start_time is not None + assert captured_spans[0].end_time is not None # --------------------------------------------------------------------------- @@ -103,7 +110,7 @@ def test_span_end_for_unknown_span_is_noop(self): # --------------------------------------------------------------------------- -class TestSGPAsyncTracingProcessorMemoryLeak: +class TestSGPAsyncTracingProcessor: @staticmethod def _make_processor(): mock_env = MagicMock() @@ -113,9 +120,9 @@ def _make_processor(): mock_async_client = MagicMock() mock_async_client.spans.upsert_batch = AsyncMock() - with patch(f"{MODULE}.EnvironmentVariables", mock_env), \ - patch(f"{MODULE}.create_span", mock_create_span), \ - patch(f"{MODULE}.AsyncSGPClient", return_value=mock_async_client): + with patch(f"{MODULE}.EnvironmentVariables", mock_env), patch(f"{MODULE}.create_span", mock_create_span), patch( + f"{MODULE}.AsyncSGPClient", return_value=mock_async_client + ): from agentex.lib.core.tracing.processors.sgp_tracing_processor import ( SGPAsyncTracingProcessor, ) @@ -125,69 +132,78 @@ def _make_processor(): # Wire up the mock client after construction (constructor stores it) processor.sgp_async_client = mock_async_client - # Keep create_span mock active for on_span_start calls return processor, mock_create_span - async def test_spans_not_leaked_after_completed_lifecycle(self): + def test_processor_holds_no_per_span_state(self): + """Stateless processor must not retain any per-span dict between lifecycle events.""" processor, _ = self._make_processor() + assert not hasattr(processor, "_spans") - with patch(f"{MODULE}.create_span", side_effect=lambda **kw: _make_mock_sgp_span()): - for _ in range(100): - span = _make_span() - await processor.on_span_start(span) - span.end_time = datetime.now(UTC) - await processor.on_span_end(span) - - assert len(processor._spans) == 0, ( - f"Expected 0 spans after 100 complete lifecycles, got {len(processor._spans)} — memory leak!" - ) - - async def test_spans_present_during_active_lifecycle(self): + async def test_span_lifecycle_produces_two_upserts(self): + """Each span produces one upsert_batch call on start and one on end.""" processor, _ = self._make_processor() with patch(f"{MODULE}.create_span", side_effect=lambda **kw: _make_mock_sgp_span()): span = _make_span() await processor.on_span_start(span) - assert len(processor._spans) == 1, "Span should be tracked while active" - span.end_time = datetime.now(UTC) await processor.on_span_end(span) - assert len(processor._spans) == 0, "Span should be removed after end" - async def test_span_end_for_unknown_span_is_noop(self): + assert processor.sgp_async_client.spans.upsert_batch.call_count == 2 + + async def test_span_end_without_prior_start_still_upserts(self): + """Cross-pod Temporal case: END activity lands on a pod that never saw START. + + Today this used to be a silent no-op. After the stateless refactor it + must still upsert a complete span via upsert_batch. + """ processor, _ = self._make_processor() - span = _make_span() - span.end_time = datetime.now(UTC) - await processor.on_span_end(span) + with patch(f"{MODULE}.create_span", side_effect=lambda **kw: _make_mock_sgp_span()): + span = _make_span() + span.end_time = datetime.now(UTC) + # No on_span_start — END lands here for the first time. + await processor.on_span_end(span) - assert len(processor._spans) == 0 + assert processor.sgp_async_client.spans.upsert_batch.call_count == 1 + items = processor.sgp_async_client.spans.upsert_batch.call_args.kwargs["items"] + assert len(items) == 1 - async def test_sgp_span_input_updated_on_end(self): - """on_span_end should update sgp_span.input from the incoming span.""" + async def test_sgp_span_input_and_output_propagated_on_end(self): + """on_span_end should send the span's current input and output via upsert_batch.""" processor, _ = self._make_processor() - with patch(f"{MODULE}.create_span", side_effect=lambda **kw: _make_mock_sgp_span()): + captured: list[MagicMock] = [] + + def capture_create_span(**kwargs): + sgp_span = _make_mock_sgp_span() + captured.append(sgp_span) + return sgp_span + + mock_create_span = MagicMock(side_effect=capture_create_span) + with patch(f"{MODULE}.create_span", mock_create_span): span = _make_span() span.input = {"messages": [{"role": "user", "content": "hello"}]} await processor.on_span_start(span) - assert len(processor._spans) == 1 - - # Simulate modified input at end time - updated_input: dict[str, object] = {"messages": [ - {"role": "user", "content": "hello"}, - {"role": "assistant", "content": "hi"}, - ]} - span.input = updated_input - span.output = {"response": "hi"} - span.end_time = datetime.now(UTC) - await processor.on_span_end(span) - - # Span should be removed after end - assert len(processor._spans) == 0 - # The end upsert should have been called + span.input = { + "messages": [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hi"}, + ] + } + span.output = {"response": "hi"} + span.end_time = datetime.now(UTC) + await processor.on_span_end(span) + assert processor.sgp_async_client.spans.upsert_batch.call_count == 2 # start + end + # The end-time SGPSpan should have end_time populated. + end_span = captured[-1] + assert end_span.end_time is not None + # Verify the updated input/output reached create_span on the end call. + end_call_kwargs = mock_create_span.call_args_list[-1].kwargs + assert end_call_kwargs["input"]["messages"][-1]["role"] == "assistant" + assert end_call_kwargs["output"] == {"response": "hi"} async def test_on_spans_start_sends_single_upsert_for_batch(self): """Given N spans at once, on_spans_start should make ONE upsert_batch HTTP call.""" @@ -203,8 +219,6 @@ async def test_on_spans_start_sends_single_upsert_for_batch(self): ) items = processor.sgp_async_client.spans.upsert_batch.call_args.kwargs["items"] assert len(items) == n - # All spans should be tracked for the subsequent end call - assert len(processor._spans) == n async def test_on_spans_end_sends_single_upsert_for_batch(self): """Given N spans at once, on_spans_end should make ONE upsert_batch HTTP call.""" @@ -226,4 +240,3 @@ async def test_on_spans_end_sends_single_upsert_for_batch(self): ) items = processor.sgp_async_client.spans.upsert_batch.call_args.kwargs["items"] assert len(items) == n - assert len(processor._spans) == 0 diff --git a/tests/test_models.py b/tests/test_models.py index 1aef5c605..55ce30876 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,7 +1,8 @@ import json -from typing import TYPE_CHECKING, Any, Dict, List, Union, Optional, cast +from typing import TYPE_CHECKING, Any, Dict, List, Union, Iterable, Optional, cast from datetime import datetime, timezone -from typing_extensions import Literal, Annotated, TypeAliasType +from collections import deque +from typing_extensions import Literal, Annotated, TypedDict, TypeAliasType import pytest import pydantic @@ -9,7 +10,7 @@ from agentex._utils import PropertyInfo from agentex._compat import PYDANTIC_V1, parse_obj, model_dump, model_json -from agentex._models import DISCRIMINATOR_CACHE, BaseModel, construct_type +from agentex._models import DISCRIMINATOR_CACHE, BaseModel, EagerIterable, construct_type class BasicModel(BaseModel): @@ -961,3 +962,56 @@ def __getattr__(self, attr: str) -> Item: ... assert model.a.prop == 1 assert isinstance(model.a, Item) assert model.other == "foo" + + +# NOTE: Workaround for Pydantic Iterable behavior. +# Iterable fields are replaced with a ValidatorIterator and may be consumed +# during serialization, which can cause subsequent dumps to return empty data. +# See: https://github.com/pydantic/pydantic/issues/9541 +@pytest.mark.parametrize( + "data, expected_validated", + [ + ([1, 2, 3], [1, 2, 3]), + ((1, 2, 3), (1, 2, 3)), + (set([1, 2, 3]), set([1, 2, 3])), + (iter([1, 2, 3]), [1, 2, 3]), + ([], []), + ((x for x in [1, 2, 3]), [1, 2, 3]), + (map(lambda x: x, [1, 2, 3]), [1, 2, 3]), + (frozenset([1, 2, 3]), frozenset([1, 2, 3])), + (deque([1, 2, 3]), deque([1, 2, 3])), + ], + ids=["list", "tuple", "set", "iterator", "empty", "generator", "map", "frozenset", "deque"], +) +@pytest.mark.skipif(PYDANTIC_V1, reason="this is only supported in pydantic v2") +def test_iterable_construction(data: Iterable[int], expected_validated: Iterable[int]) -> None: + class TypeWithIterable(TypedDict): + items: EagerIterable[int] + + class Model(BaseModel): + data: TypeWithIterable + + m = Model.model_validate({"data": {"items": data}}) + assert m.data["items"] == expected_validated + + # Verify repeated dumps don't lose data (the original bug) + assert m.model_dump()["data"]["items"] == list(expected_validated) + assert m.model_dump()["data"]["items"] == list(expected_validated) + + +@pytest.mark.skipif(PYDANTIC_V1, reason="this is only supported in pydantic v2") +def test_iterable_construction_str_falls_back_to_list() -> None: + # str is iterable (over chars), but str(list_of_chars) produces the list's repr + # rather than reconstructing a string from items. We special-case str to fall + # back to list instead of attempting reconstruction. + class TypeWithIterable(TypedDict): + items: EagerIterable[str] + + class Model(BaseModel): + data: TypeWithIterable + + m = Model.model_validate({"data": {"items": "hello"}}) + + # falls back to list of chars rather than calling str(["h", "e", "l", "l", "o"]) + assert m.data["items"] == ["h", "e", "l", "l", "o"] + assert m.model_dump()["data"]["items"] == ["h", "e", "l", "l", "o"]