diff --git a/src/agentex/lib/core/tracing/processors/sgp_tracing_processor.py b/src/agentex/lib/core/tracing/processors/sgp_tracing_processor.py index 187dedcbc..3a1c96c1b 100644 --- a/src/agentex/lib/core/tracing/processors/sgp_tracing_processor.py +++ b/src/agentex/lib/core/tracing/processors/sgp_tracing_processor.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import override +from typing import cast, override import scale_gp_beta.lib.tracing as tracing from scale_gp_beta import SGPClient, AsyncSGPClient @@ -27,6 +27,39 @@ def _get_span_type(span: Span) -> str: return "STANDALONE" +def _add_source_to_span(span: Span, env_vars: EnvironmentVariables) -> None: + if span.data is None: + span.data = {} + if isinstance(span.data, dict): + span.data["__source__"] = "agentex" + if env_vars.ACP_TYPE is not None: + span.data["__acp_type__"] = env_vars.ACP_TYPE + if env_vars.AGENT_NAME is not None: + span.data["__agent_name__"] = env_vars.AGENT_NAME + if env_vars.AGENT_ID is not None: + span.data["__agent_id__"] = env_vars.AGENT_ID + + +def _build_sgp_span(span: Span, env_vars: EnvironmentVariables) -> SGPSpan: + """Build an SGPSpan from an agentex Span. Idempotent on span_id at the SGP backend.""" + _add_source_to_span(span, env_vars) + sgp_span = cast( + SGPSpan, + create_span( + name=span.name, + span_type=_get_span_type(span), + span_id=span.id, + parent_id=span.parent_id, + trace_id=span.trace_id, + input=span.input, + output=span.output, + metadata=span.data, + ), + ) + sgp_span.start_time = span.start_time.isoformat() # type: ignore[union-attr] + return sgp_span + + class SGPSyncTracingProcessor(SyncTracingProcessor): def __init__(self, config: SGPTracingProcessorConfig): disabled = config.sgp_api_key == "" or config.sgp_account_id == "" @@ -38,63 +71,27 @@ def __init__(self, config: SGPTracingProcessorConfig): ), disabled=disabled, ) - self._spans: dict[str, SGPSpan] = {} self.env_vars = EnvironmentVariables.refresh() - def _add_source_to_span(self, span: Span) -> None: - if span.data is None: - span.data = {} - if isinstance(span.data, dict): - span.data["__source__"] = "agentex" - if self.env_vars.ACP_TYPE is not None: - span.data["__acp_type__"] = self.env_vars.ACP_TYPE - if self.env_vars.AGENT_NAME is not None: - span.data["__agent_name__"] = self.env_vars.AGENT_NAME - if self.env_vars.AGENT_ID is not None: - span.data["__agent_id__"] = self.env_vars.AGENT_ID - @override def on_span_start(self, span: Span) -> None: - self._add_source_to_span(span) - - sgp_span = create_span( - name=span.name, - span_type=_get_span_type(span), - span_id=span.id, - parent_id=span.parent_id, - trace_id=span.trace_id, - input=span.input, - output=span.output, - metadata=span.data, - ) - sgp_span.start_time = span.start_time.isoformat() # type: ignore[union-attr] + sgp_span = _build_sgp_span(span, self.env_vars) sgp_span.flush(blocking=False) - self._spans[span.id] = sgp_span - @override def on_span_end(self, span: Span) -> None: - sgp_span = self._spans.pop(span.id, None) - if sgp_span is None: - logger.warning(f"Span {span.id} not found in stored spans, skipping span end") - return - - self._add_source_to_span(span) - sgp_span.output = span.output # type: ignore[assignment] - sgp_span.metadata = span.data # type: ignore[assignment] + sgp_span = _build_sgp_span(span, self.env_vars) sgp_span.end_time = span.end_time.isoformat() # type: ignore[union-attr] sgp_span.flush(blocking=False) @override def shutdown(self) -> None: - self._spans.clear() flush_queue() class SGPAsyncTracingProcessor(AsyncTracingProcessor): def __init__(self, config: SGPTracingProcessorConfig): self.disabled = config.sgp_api_key == "" or config.sgp_account_id == "" - self._spans: dict[str, SGPSpan] = {} import httpx # Disable keepalive so each HTTP call gets a fresh TCP connection, @@ -113,18 +110,6 @@ def __init__(self, config: SGPTracingProcessorConfig): ) self.env_vars = EnvironmentVariables.refresh() - def _add_source_to_span(self, span: Span) -> None: - if span.data is None: - span.data = {} - if isinstance(span.data, dict): - span.data["__source__"] = "agentex" - if self.env_vars.ACP_TYPE is not None: - span.data["__acp_type__"] = self.env_vars.ACP_TYPE - if self.env_vars.AGENT_NAME is not None: - span.data["__agent_name__"] = self.env_vars.AGENT_NAME - if self.env_vars.AGENT_ID is not None: - span.data["__agent_id__"] = self.env_vars.AGENT_ID - @override async def on_span_start(self, span: Span) -> None: await self.on_spans_start([span]) @@ -138,22 +123,7 @@ async def on_spans_start(self, spans: list[Span]) -> None: if not spans: return - sgp_spans: list[SGPSpan] = [] - for span in spans: - self._add_source_to_span(span) - sgp_span = create_span( - name=span.name, - span_type=_get_span_type(span), - span_id=span.id, - parent_id=span.parent_id, - trace_id=span.trace_id, - input=span.input, - output=span.output, - metadata=span.data, - ) - sgp_span.start_time = span.start_time.isoformat() # type: ignore[union-attr] - self._spans[span.id] = sgp_span - sgp_spans.append(sgp_span) + sgp_spans = [_build_sgp_span(span, self.env_vars) for span in spans] if self.disabled: logger.warning("SGP is disabled, skipping span upsert") @@ -167,29 +137,18 @@ async def on_spans_end(self, spans: list[Span]) -> None: if not spans: return - to_upsert: list[SGPSpan] = [] + sgp_spans: list[SGPSpan] = [] for span in spans: - sgp_span = self._spans.pop(span.id, None) - if sgp_span is None: - logger.warning(f"Span {span.id} not found in stored spans, skipping span end") - continue - - self._add_source_to_span(span) - sgp_span.input = span.input # type: ignore[assignment] - sgp_span.output = span.output # type: ignore[assignment] - sgp_span.metadata = span.data # type: ignore[assignment] + sgp_span = _build_sgp_span(span, self.env_vars) sgp_span.end_time = span.end_time.isoformat() # type: ignore[union-attr] - to_upsert.append(sgp_span) + sgp_spans.append(sgp_span) - if self.disabled or not to_upsert: + if self.disabled: return await self.sgp_async_client.spans.upsert_batch( # type: ignore[union-attr] - items=[s.to_request_params() for s in to_upsert] + items=[s.to_request_params() for s in sgp_spans] ) @override async def shutdown(self) -> None: - await self.sgp_async_client.spans.upsert_batch( # type: ignore[union-attr] - items=[sgp_span.to_request_params() for sgp_span in self._spans.values()] - ) - self._spans.clear() + pass diff --git a/tests/lib/core/tracing/processors/test_sgp_tracing_processor.py b/tests/lib/core/tracing/processors/test_sgp_tracing_processor.py index 50d615e0d..4614fe540 100644 --- a/tests/lib/core/tracing/processors/test_sgp_tracing_processor.py +++ b/tests/lib/core/tracing/processors/test_sgp_tracing_processor.py @@ -41,18 +41,16 @@ def _make_mock_sgp_span() -> MagicMock: # --------------------------------------------------------------------------- -class TestSGPSyncTracingProcessorMemoryLeak: +class TestSGPSyncTracingProcessor: @staticmethod def _make_processor(): mock_env = MagicMock() mock_env.refresh.return_value = MagicMock(ACP_TYPE=None, AGENT_NAME=None, AGENT_ID=None) mock_create_span = MagicMock(side_effect=lambda **kwargs: _make_mock_sgp_span()) - with patch(f"{MODULE}.EnvironmentVariables", mock_env), \ - patch(f"{MODULE}.SGPClient"), \ - patch(f"{MODULE}.tracing"), \ - patch(f"{MODULE}.flush_queue"), \ - patch(f"{MODULE}.create_span", mock_create_span): + with patch(f"{MODULE}.EnvironmentVariables", mock_env), patch(f"{MODULE}.SGPClient"), patch( + f"{MODULE}.tracing" + ), patch(f"{MODULE}.flush_queue"), patch(f"{MODULE}.create_span", mock_create_span): from agentex.lib.core.tracing.processors.sgp_tracing_processor import ( SGPSyncTracingProcessor, ) @@ -61,41 +59,50 @@ def _make_processor(): return processor, mock_create_span - def test_spans_not_leaked_after_completed_lifecycle(self): + def test_processor_holds_no_per_span_state(self): + """Stateless processor must not retain any per-span dict between lifecycle events.""" processor, _ = self._make_processor() + assert not hasattr(processor, "_spans") - with patch(f"{MODULE}.create_span", side_effect=lambda **kw: _make_mock_sgp_span()): + def test_span_lifecycle_produces_two_flushes(self): + """Each span produces one flush on start and one on end.""" + processor, _ = self._make_processor() + + with patch(f"{MODULE}.create_span", side_effect=lambda **kw: _make_mock_sgp_span()) as mock_cs: for _ in range(100): span = _make_span() processor.on_span_start(span) span.end_time = datetime.now(UTC) processor.on_span_end(span) - assert len(processor._spans) == 0, ( - f"Expected 0 spans after 100 complete lifecycles, got {len(processor._spans)} — memory leak!" - ) + # 100 spans × (1 start + 1 end) = 200 build calls. + assert mock_cs.call_count == 200 + + def test_span_end_without_prior_start_still_flushes(self): + """Cross-pod Temporal case: END activity lands on a pod that never saw START. - def test_spans_present_during_active_lifecycle(self): + Today this used to be a silent no-op. After the stateless refactor it + must still flush a complete span (start_time + end_time + payload). + """ processor, _ = self._make_processor() - with patch(f"{MODULE}.create_span", side_effect=lambda **kw: _make_mock_sgp_span()): - span = _make_span() - processor.on_span_start(span) - assert len(processor._spans) == 1, "Span should be tracked while active" + captured_spans: list[MagicMock] = [] + def capture_create_span(**kwargs): + sgp_span = _make_mock_sgp_span() + captured_spans.append(sgp_span) + return sgp_span + + with patch(f"{MODULE}.create_span", side_effect=capture_create_span): + span = _make_span() span.end_time = datetime.now(UTC) + # No on_span_start — END lands here for the first time. processor.on_span_end(span) - assert len(processor._spans) == 0, "Span should be removed after end" - def test_span_end_for_unknown_span_is_noop(self): - processor, _ = self._make_processor() - - span = _make_span() - # End a span that was never started — should not raise - span.end_time = datetime.now(UTC) - processor.on_span_end(span) - - assert len(processor._spans) == 0 + assert len(captured_spans) == 1 + assert captured_spans[0].flush.called + assert captured_spans[0].start_time is not None + assert captured_spans[0].end_time is not None # --------------------------------------------------------------------------- @@ -103,7 +110,7 @@ def test_span_end_for_unknown_span_is_noop(self): # --------------------------------------------------------------------------- -class TestSGPAsyncTracingProcessorMemoryLeak: +class TestSGPAsyncTracingProcessor: @staticmethod def _make_processor(): mock_env = MagicMock() @@ -113,9 +120,9 @@ def _make_processor(): mock_async_client = MagicMock() mock_async_client.spans.upsert_batch = AsyncMock() - with patch(f"{MODULE}.EnvironmentVariables", mock_env), \ - patch(f"{MODULE}.create_span", mock_create_span), \ - patch(f"{MODULE}.AsyncSGPClient", return_value=mock_async_client): + with patch(f"{MODULE}.EnvironmentVariables", mock_env), patch(f"{MODULE}.create_span", mock_create_span), patch( + f"{MODULE}.AsyncSGPClient", return_value=mock_async_client + ): from agentex.lib.core.tracing.processors.sgp_tracing_processor import ( SGPAsyncTracingProcessor, ) @@ -125,69 +132,78 @@ def _make_processor(): # Wire up the mock client after construction (constructor stores it) processor.sgp_async_client = mock_async_client - # Keep create_span mock active for on_span_start calls return processor, mock_create_span - async def test_spans_not_leaked_after_completed_lifecycle(self): + def test_processor_holds_no_per_span_state(self): + """Stateless processor must not retain any per-span dict between lifecycle events.""" processor, _ = self._make_processor() + assert not hasattr(processor, "_spans") - with patch(f"{MODULE}.create_span", side_effect=lambda **kw: _make_mock_sgp_span()): - for _ in range(100): - span = _make_span() - await processor.on_span_start(span) - span.end_time = datetime.now(UTC) - await processor.on_span_end(span) - - assert len(processor._spans) == 0, ( - f"Expected 0 spans after 100 complete lifecycles, got {len(processor._spans)} — memory leak!" - ) - - async def test_spans_present_during_active_lifecycle(self): + async def test_span_lifecycle_produces_two_upserts(self): + """Each span produces one upsert_batch call on start and one on end.""" processor, _ = self._make_processor() with patch(f"{MODULE}.create_span", side_effect=lambda **kw: _make_mock_sgp_span()): span = _make_span() await processor.on_span_start(span) - assert len(processor._spans) == 1, "Span should be tracked while active" - span.end_time = datetime.now(UTC) await processor.on_span_end(span) - assert len(processor._spans) == 0, "Span should be removed after end" - async def test_span_end_for_unknown_span_is_noop(self): + assert processor.sgp_async_client.spans.upsert_batch.call_count == 2 + + async def test_span_end_without_prior_start_still_upserts(self): + """Cross-pod Temporal case: END activity lands on a pod that never saw START. + + Today this used to be a silent no-op. After the stateless refactor it + must still upsert a complete span via upsert_batch. + """ processor, _ = self._make_processor() - span = _make_span() - span.end_time = datetime.now(UTC) - await processor.on_span_end(span) + with patch(f"{MODULE}.create_span", side_effect=lambda **kw: _make_mock_sgp_span()): + span = _make_span() + span.end_time = datetime.now(UTC) + # No on_span_start — END lands here for the first time. + await processor.on_span_end(span) - assert len(processor._spans) == 0 + assert processor.sgp_async_client.spans.upsert_batch.call_count == 1 + items = processor.sgp_async_client.spans.upsert_batch.call_args.kwargs["items"] + assert len(items) == 1 - async def test_sgp_span_input_updated_on_end(self): - """on_span_end should update sgp_span.input from the incoming span.""" + async def test_sgp_span_input_and_output_propagated_on_end(self): + """on_span_end should send the span's current input and output via upsert_batch.""" processor, _ = self._make_processor() - with patch(f"{MODULE}.create_span", side_effect=lambda **kw: _make_mock_sgp_span()): + captured: list[MagicMock] = [] + + def capture_create_span(**kwargs): + sgp_span = _make_mock_sgp_span() + captured.append(sgp_span) + return sgp_span + + mock_create_span = MagicMock(side_effect=capture_create_span) + with patch(f"{MODULE}.create_span", mock_create_span): span = _make_span() span.input = {"messages": [{"role": "user", "content": "hello"}]} await processor.on_span_start(span) - assert len(processor._spans) == 1 - - # Simulate modified input at end time - updated_input: dict[str, object] = {"messages": [ - {"role": "user", "content": "hello"}, - {"role": "assistant", "content": "hi"}, - ]} - span.input = updated_input - span.output = {"response": "hi"} - span.end_time = datetime.now(UTC) - await processor.on_span_end(span) - - # Span should be removed after end - assert len(processor._spans) == 0 - # The end upsert should have been called + span.input = { + "messages": [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hi"}, + ] + } + span.output = {"response": "hi"} + span.end_time = datetime.now(UTC) + await processor.on_span_end(span) + assert processor.sgp_async_client.spans.upsert_batch.call_count == 2 # start + end + # The end-time SGPSpan should have end_time populated. + end_span = captured[-1] + assert end_span.end_time is not None + # Verify the updated input/output reached create_span on the end call. + end_call_kwargs = mock_create_span.call_args_list[-1].kwargs + assert end_call_kwargs["input"]["messages"][-1]["role"] == "assistant" + assert end_call_kwargs["output"] == {"response": "hi"} async def test_on_spans_start_sends_single_upsert_for_batch(self): """Given N spans at once, on_spans_start should make ONE upsert_batch HTTP call.""" @@ -203,8 +219,6 @@ async def test_on_spans_start_sends_single_upsert_for_batch(self): ) items = processor.sgp_async_client.spans.upsert_batch.call_args.kwargs["items"] assert len(items) == n - # All spans should be tracked for the subsequent end call - assert len(processor._spans) == n async def test_on_spans_end_sends_single_upsert_for_batch(self): """Given N spans at once, on_spans_end should make ONE upsert_batch HTTP call.""" @@ -226,4 +240,3 @@ async def test_on_spans_end_sends_single_upsert_for_batch(self): ) items = processor.sgp_async_client.spans.upsert_batch.call_args.kwargs["items"] assert len(items) == n - assert len(processor._spans) == 0