Skip to content

Commit 1e3ac82

Browse files
alex-stoicavblagojesilvanocerza
authored
Fixed TypeError in LangfuseTrace (#1184)
* Added parent_span functionality in trace method * solved PR comments * Readded "end()" for solving Latency issues * chore: fix ruff linting * Handle multiple runs * Fix indentation and span closing * Fix tests --------- Co-authored-by: Vladimir Blagojevic <[email protected]> Co-authored-by: Silvano Cerza <[email protected]>
1 parent 67e08d0 commit 1e3ac82

File tree

3 files changed

+77
-58
lines changed

3 files changed

+77
-58
lines changed

integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py

Lines changed: 59 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import contextlib
2-
import logging
32
import os
3+
from contextvars import ContextVar
44
from datetime import datetime
5-
from typing import Any, Dict, Iterator, Optional, Union
5+
from typing import Any, Dict, Iterator, List, Optional, Union
66

7+
from haystack import logging
78
from haystack.components.generators.openai_utils import _convert_message_to_openai_format
89
from haystack.dataclasses import ChatMessage
910
from haystack.tracing import Span, Tracer, tracer
@@ -32,6 +33,17 @@
3233
]
3334
_ALL_SUPPORTED_GENERATORS = _SUPPORTED_GENERATORS + _SUPPORTED_CHAT_GENERATORS
3435

36+
# These are the keys used by Haystack for traces and span.
37+
# We keep them here to avoid making typos when using them.
38+
_PIPELINE_RUN_KEY = "haystack.pipeline.run"
39+
_COMPONENT_NAME_KEY = "haystack.component.name"
40+
_COMPONENT_TYPE_KEY = "haystack.component.type"
41+
_COMPONENT_OUTPUT_KEY = "haystack.component.output"
42+
43+
# Context var used to keep track of tracing related info.
44+
# This mainly useful for parents spans.
45+
tracing_context_var: ContextVar[Dict[Any, Any]] = ContextVar("tracing_context", default={})
46+
3547

3648
class LangfuseSpan(Span):
3749
"""
@@ -86,7 +98,7 @@ def set_content_tag(self, key: str, value: Any) -> None:
8698

8799
self._data[key] = value
88100

89-
def raw_span(self) -> Any:
101+
def raw_span(self) -> "Union[langfuse.client.StatefulSpanClient, langfuse.client.StatefulTraceClient]":
90102
"""
91103
Return the underlying span instance.
92104
@@ -115,41 +127,57 @@ def __init__(self, tracer: "langfuse.Langfuse", name: str = "Haystack", public:
115127
and only accessible to the Langfuse account owner.
116128
"""
117129
self._tracer = tracer
118-
self._context: list[LangfuseSpan] = []
130+
self._context: List[LangfuseSpan] = []
119131
self._name = name
120132
self._public = public
121133
self.enforce_flush = os.getenv(HAYSTACK_LANGFUSE_ENFORCE_FLUSH_ENV_VAR, "true").lower() == "true"
122134

123135
@contextlib.contextmanager
124-
def trace(self, operation_name: str, tags: Optional[Dict[str, Any]] = None) -> Iterator[Span]:
125-
"""
126-
Start and manage a new trace span.
127-
:param operation_name: The name of the operation.
128-
:param tags: A dictionary of tags to attach to the span.
129-
:return: A context manager yielding the span.
130-
"""
136+
def trace(
137+
self, operation_name: str, tags: Optional[Dict[str, Any]] = None, parent_span: Optional[Span] = None
138+
) -> Iterator[Span]:
131139
tags = tags or {}
132-
span_name = tags.get("haystack.component.name", operation_name)
133-
134-
if tags.get("haystack.component.type") in _ALL_SUPPORTED_GENERATORS:
135-
span = LangfuseSpan(self.current_span().raw_span().generation(name=span_name))
140+
span_name = tags.get(_COMPONENT_NAME_KEY, operation_name)
141+
142+
# Create new span depending whether there's a parent span or not
143+
if not parent_span:
144+
if operation_name != _PIPELINE_RUN_KEY:
145+
logger.warning(
146+
"Creating a new trace without a parent span is not recommended for operation '{operation_name}'.",
147+
operation_name=operation_name,
148+
)
149+
# Create a new trace if no parent span is provided
150+
span = LangfuseSpan(
151+
self._tracer.trace(
152+
name=self._name,
153+
public=self._public,
154+
id=tracing_context_var.get().get("trace_id"),
155+
user_id=tracing_context_var.get().get("user_id"),
156+
session_id=tracing_context_var.get().get("session_id"),
157+
tags=tracing_context_var.get().get("tags"),
158+
version=tracing_context_var.get().get("version"),
159+
)
160+
)
161+
elif tags.get(_COMPONENT_TYPE_KEY) in _ALL_SUPPORTED_GENERATORS:
162+
span = LangfuseSpan(parent_span.raw_span().generation(name=span_name))
136163
else:
137-
span = LangfuseSpan(self.current_span().raw_span().span(name=span_name))
164+
span = LangfuseSpan(parent_span.raw_span().span(name=span_name))
138165

139166
self._context.append(span)
140167
span.set_tags(tags)
141168

142169
yield span
143170

144-
if tags.get("haystack.component.type") in _SUPPORTED_GENERATORS:
145-
meta = span._data.get("haystack.component.output", {}).get("meta")
171+
# Update span metadata based on component type
172+
if tags.get(_COMPONENT_TYPE_KEY) in _SUPPORTED_GENERATORS:
173+
# Haystack returns one meta dict for each message, but the 'usage' value
174+
# is always the same, let's just pick the first item
175+
meta = span._data.get(_COMPONENT_OUTPUT_KEY, {}).get("meta")
146176
if meta:
147-
# Haystack returns one meta dict for each message, but the 'usage' value
148-
# is always the same, let's just pick the first item
149177
m = meta[0]
150178
span._span.update(usage=m.get("usage") or None, model=m.get("model"))
151-
elif tags.get("haystack.component.type") in _SUPPORTED_CHAT_GENERATORS:
152-
replies = span._data.get("haystack.component.output", {}).get("replies")
179+
elif tags.get(_COMPONENT_TYPE_KEY) in _SUPPORTED_CHAT_GENERATORS:
180+
replies = span._data.get(_COMPONENT_OUTPUT_KEY, {}).get("replies")
153181
if replies:
154182
meta = replies[0].meta
155183
completion_start_time = meta.get("completion_start_time")
@@ -165,36 +193,24 @@ def trace(self, operation_name: str, tags: Optional[Dict[str, Any]] = None) -> I
165193
completion_start_time=completion_start_time,
166194
)
167195

168-
pipeline_input = tags.get("haystack.pipeline.input_data", None)
169-
if pipeline_input:
170-
span._span.update(input=tags["haystack.pipeline.input_data"])
171-
pipeline_output = tags.get("haystack.pipeline.output_data", None)
172-
if pipeline_output:
173-
span._span.update(output=tags["haystack.pipeline.output_data"])
174-
175-
span.raw_span().end()
196+
raw_span = span.raw_span()
197+
if isinstance(raw_span, langfuse.client.StatefulSpanClient):
198+
raw_span.end()
176199
self._context.pop()
177200

178-
if len(self._context) == 1:
179-
# The root span has to be a trace, which need to be removed from the context after the pipeline run
180-
self._context.pop()
181-
182-
if self.enforce_flush:
183-
self.flush()
201+
if self.enforce_flush:
202+
self.flush()
184203

185204
def flush(self):
186205
self._tracer.flush()
187206

188-
def current_span(self) -> Span:
207+
def current_span(self) -> Optional[Span]:
189208
"""
190-
Return the currently active span.
209+
Return the current active span.
191210
192-
:return: The currently active span.
211+
:return: The current span if available, else None.
193212
"""
194-
if not self._context:
195-
# The root span has to be a trace
196-
self._context.append(LangfuseSpan(self._tracer.trace(name=self._name, public=self._public)))
197-
return self._context[-1]
213+
return self._context[-1] if self._context else None
198214

199215
def get_trace_url(self) -> str:
200216
"""

integrations/langfuse/tests/test_tracer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def test_create_new_span(self):
6969
tracer = LangfuseTracer(tracer=mock_tracer, name="Haystack", public=False)
7070

7171
with tracer.trace("operation_name", tags={"tag1": "value1", "tag2": "value2"}) as span:
72-
assert len(tracer._context) == 2, "The trace span should have been added to the the root context span"
72+
assert len(tracer._context) == 1, "The trace span should have been added to the the root context span"
7373
assert span.raw_span().operation_name == "operation_name"
7474
assert span.raw_span().metadata == {"tag1": "value1", "tag2": "value2"}
7575

integrations/langfuse/tests/test_tracing.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -52,25 +52,28 @@ def test_tracing_integration(llm_class, env_var, expected_trace):
5252
assert "Berlin" in response["llm"]["replies"][0].content
5353
assert response["tracer"]["trace_url"]
5454

55-
# add a random delay between 1 and 3 seconds to make sure the trace is flushed
56-
# and that the trace is available in Langfuse when we fetch it below
57-
time.sleep(random.uniform(1, 3))
58-
59-
url = "https://cloud.langfuse.com/api/public/traces/"
6055
trace_url = response["tracer"]["trace_url"]
6156
uuid = os.path.basename(urlparse(trace_url).path)
57+
url = f"https://cloud.langfuse.com/api/public/traces/{uuid}"
6258

63-
try:
64-
response = requests.get(
65-
url + uuid, auth=HTTPBasicAuth(os.environ["LANGFUSE_PUBLIC_KEY"], os.environ["LANGFUSE_SECRET_KEY"])
59+
# Poll the Langfuse API a bit as the trace might not be ready right away
60+
attempts = 5
61+
delay = 1
62+
while attempts >= 0:
63+
res = requests.get(
64+
url, auth=HTTPBasicAuth(os.environ["LANGFUSE_PUBLIC_KEY"], os.environ["LANGFUSE_SECRET_KEY"])
6665
)
67-
assert response.status_code == 200, f"Failed to retrieve data from Langfuse API: {response.status_code}"
66+
if attempts > 0 and res.status_code != 200:
67+
attempts -= 1
68+
time.sleep(delay)
69+
delay *= 2
70+
continue
71+
assert res.status_code == 200, f"Failed to retrieve data from Langfuse API: {res.status_code}"
6872

6973
# check if the trace contains the expected LLM name
70-
assert expected_trace in str(response.content)
74+
assert expected_trace in str(res.content)
7175
# check if the trace contains the expected generation span
72-
assert "GENERATION" in str(response.content)
76+
assert "GENERATION" in str(res.content)
7377
# check if the trace contains the expected user_id
74-
assert "user_42" in str(response.content)
75-
except requests.exceptions.RequestException as e:
76-
pytest.fail(f"Failed to retrieve data from Langfuse API: {e}")
78+
assert "user_42" in str(res.content)
79+
break

0 commit comments

Comments
 (0)