Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 43 additions & 84 deletions src/agentex/lib/core/tracing/processors/sgp_tracing_processor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import override
from typing import cast, override

import scale_gp_beta.lib.tracing as tracing
from scale_gp_beta import SGPClient, AsyncSGPClient
Expand All @@ -27,6 +27,39 @@ def _get_span_type(span: Span) -> str:
return "STANDALONE"


def _add_source_to_span(span: Span, env_vars: EnvironmentVariables) -> None:
if span.data is None:
span.data = {}
if isinstance(span.data, dict):
span.data["__source__"] = "agentex"
if env_vars.ACP_TYPE is not None:
span.data["__acp_type__"] = env_vars.ACP_TYPE
if env_vars.AGENT_NAME is not None:
span.data["__agent_name__"] = env_vars.AGENT_NAME
if env_vars.AGENT_ID is not None:
span.data["__agent_id__"] = env_vars.AGENT_ID


def _build_sgp_span(span: Span, env_vars: EnvironmentVariables) -> SGPSpan:
"""Build an SGPSpan from an agentex Span. Idempotent on span_id at the SGP backend."""
_add_source_to_span(span, env_vars)
sgp_span = cast(
SGPSpan,
create_span(
name=span.name,
span_type=_get_span_type(span),
span_id=span.id,
parent_id=span.parent_id,
trace_id=span.trace_id,
input=span.input,
output=span.output,
metadata=span.data,
),
)
sgp_span.start_time = span.start_time.isoformat() # type: ignore[union-attr]
return sgp_span


class SGPSyncTracingProcessor(SyncTracingProcessor):
def __init__(self, config: SGPTracingProcessorConfig):
disabled = config.sgp_api_key == "" or config.sgp_account_id == ""
Expand All @@ -38,63 +71,27 @@ def __init__(self, config: SGPTracingProcessorConfig):
),
disabled=disabled,
)
self._spans: dict[str, SGPSpan] = {}
self.env_vars = EnvironmentVariables.refresh()

def _add_source_to_span(self, span: Span) -> None:
if span.data is None:
span.data = {}
if isinstance(span.data, dict):
span.data["__source__"] = "agentex"
if self.env_vars.ACP_TYPE is not None:
span.data["__acp_type__"] = self.env_vars.ACP_TYPE
if self.env_vars.AGENT_NAME is not None:
span.data["__agent_name__"] = self.env_vars.AGENT_NAME
if self.env_vars.AGENT_ID is not None:
span.data["__agent_id__"] = self.env_vars.AGENT_ID

@override
def on_span_start(self, span: Span) -> None:
self._add_source_to_span(span)

sgp_span = create_span(
name=span.name,
span_type=_get_span_type(span),
span_id=span.id,
parent_id=span.parent_id,
trace_id=span.trace_id,
input=span.input,
output=span.output,
metadata=span.data,
)
sgp_span.start_time = span.start_time.isoformat() # type: ignore[union-attr]
sgp_span = _build_sgp_span(span, self.env_vars)
sgp_span.flush(blocking=False)

self._spans[span.id] = sgp_span

@override
def on_span_end(self, span: Span) -> None:
sgp_span = self._spans.pop(span.id, None)
if sgp_span is None:
logger.warning(f"Span {span.id} not found in stored spans, skipping span end")
return

self._add_source_to_span(span)
sgp_span.output = span.output # type: ignore[assignment]
sgp_span.metadata = span.data # type: ignore[assignment]
sgp_span = _build_sgp_span(span, self.env_vars)
sgp_span.end_time = span.end_time.isoformat() # type: ignore[union-attr]
sgp_span.flush(blocking=False)

@override
def shutdown(self) -> None:
self._spans.clear()
flush_queue()


class SGPAsyncTracingProcessor(AsyncTracingProcessor):
def __init__(self, config: SGPTracingProcessorConfig):
self.disabled = config.sgp_api_key == "" or config.sgp_account_id == ""
self._spans: dict[str, SGPSpan] = {}
import httpx

# Disable keepalive so each HTTP call gets a fresh TCP connection,
Expand All @@ -113,18 +110,6 @@ def __init__(self, config: SGPTracingProcessorConfig):
)
self.env_vars = EnvironmentVariables.refresh()

def _add_source_to_span(self, span: Span) -> None:
if span.data is None:
span.data = {}
if isinstance(span.data, dict):
span.data["__source__"] = "agentex"
if self.env_vars.ACP_TYPE is not None:
span.data["__acp_type__"] = self.env_vars.ACP_TYPE
if self.env_vars.AGENT_NAME is not None:
span.data["__agent_name__"] = self.env_vars.AGENT_NAME
if self.env_vars.AGENT_ID is not None:
span.data["__agent_id__"] = self.env_vars.AGENT_ID

@override
async def on_span_start(self, span: Span) -> None:
await self.on_spans_start([span])
Expand All @@ -138,22 +123,7 @@ async def on_spans_start(self, spans: list[Span]) -> None:
if not spans:
return

sgp_spans: list[SGPSpan] = []
for span in spans:
self._add_source_to_span(span)
sgp_span = create_span(
name=span.name,
span_type=_get_span_type(span),
span_id=span.id,
parent_id=span.parent_id,
trace_id=span.trace_id,
input=span.input,
output=span.output,
metadata=span.data,
)
sgp_span.start_time = span.start_time.isoformat() # type: ignore[union-attr]
self._spans[span.id] = sgp_span
sgp_spans.append(sgp_span)
sgp_spans = [_build_sgp_span(span, self.env_vars) for span in spans]

if self.disabled:
logger.warning("SGP is disabled, skipping span upsert")
Expand All @@ -167,29 +137,18 @@ async def on_spans_end(self, spans: list[Span]) -> None:
if not spans:
return

to_upsert: list[SGPSpan] = []
sgp_spans: list[SGPSpan] = []
for span in spans:
sgp_span = self._spans.pop(span.id, None)
if sgp_span is None:
logger.warning(f"Span {span.id} not found in stored spans, skipping span end")
continue

self._add_source_to_span(span)
sgp_span.input = span.input # type: ignore[assignment]
sgp_span.output = span.output # type: ignore[assignment]
sgp_span.metadata = span.data # type: ignore[assignment]
sgp_span = _build_sgp_span(span, self.env_vars)
sgp_span.end_time = span.end_time.isoformat() # type: ignore[union-attr]
to_upsert.append(sgp_span)
sgp_spans.append(sgp_span)

if self.disabled or not to_upsert:
if self.disabled:
return
await self.sgp_async_client.spans.upsert_batch( # type: ignore[union-attr]
items=[s.to_request_params() for s in to_upsert]
items=[s.to_request_params() for s in sgp_spans]
)

@override
async def shutdown(self) -> None:
await self.sgp_async_client.spans.upsert_batch( # type: ignore[union-attr]
items=[sgp_span.to_request_params() for sgp_span in self._spans.values()]
)
self._spans.clear()
pass
Loading
Loading