11import contextlib
2- import logging
32import os
3+ from contextvars import ContextVar
44from datetime import datetime
5- from typing import Any , Dict , Iterator , Optional , Union
5+ from typing import Any , Dict , Iterator , List , Optional , Union
66
7+ from haystack import logging
78from haystack .components .generators .openai_utils import _convert_message_to_openai_format
89from haystack .dataclasses import ChatMessage
910from haystack .tracing import Span , Tracer , tracer
3233]
3334_ALL_SUPPORTED_GENERATORS = _SUPPORTED_GENERATORS + _SUPPORTED_CHAT_GENERATORS
3435
36+ # These are the keys used by Haystack for traces and span.
37+ # We keep them here to avoid making typos when using them.
38+ _PIPELINE_RUN_KEY = "haystack.pipeline.run"
39+ _COMPONENT_NAME_KEY = "haystack.component.name"
40+ _COMPONENT_TYPE_KEY = "haystack.component.type"
41+ _COMPONENT_OUTPUT_KEY = "haystack.component.output"
42+
43+ # Context var used to keep track of tracing related info.
44+ # This mainly useful for parents spans.
45+ tracing_context_var : ContextVar [Dict [Any , Any ]] = ContextVar ("tracing_context" , default = {})
46+
3547
3648class LangfuseSpan (Span ):
3749 """
@@ -86,7 +98,7 @@ def set_content_tag(self, key: str, value: Any) -> None:
8698
8799 self ._data [key ] = value
88100
89- def raw_span (self ) -> Any :
101+ def raw_span (self ) -> "Union[langfuse.client.StatefulSpanClient, langfuse.client.StatefulTraceClient]" :
90102 """
91103 Return the underlying span instance.
92104
@@ -115,41 +127,57 @@ def __init__(self, tracer: "langfuse.Langfuse", name: str = "Haystack", public:
115127 and only accessible to the Langfuse account owner.
116128 """
117129 self ._tracer = tracer
118- self ._context : list [LangfuseSpan ] = []
130+ self ._context : List [LangfuseSpan ] = []
119131 self ._name = name
120132 self ._public = public
121133 self .enforce_flush = os .getenv (HAYSTACK_LANGFUSE_ENFORCE_FLUSH_ENV_VAR , "true" ).lower () == "true"
122134
123135 @contextlib .contextmanager
124- def trace (self , operation_name : str , tags : Optional [Dict [str , Any ]] = None ) -> Iterator [Span ]:
125- """
126- Start and manage a new trace span.
127- :param operation_name: The name of the operation.
128- :param tags: A dictionary of tags to attach to the span.
129- :return: A context manager yielding the span.
130- """
136+ def trace (
137+ self , operation_name : str , tags : Optional [Dict [str , Any ]] = None , parent_span : Optional [Span ] = None
138+ ) -> Iterator [Span ]:
131139 tags = tags or {}
132- span_name = tags .get ("haystack.component.name" , operation_name )
133-
134- if tags .get ("haystack.component.type" ) in _ALL_SUPPORTED_GENERATORS :
135- span = LangfuseSpan (self .current_span ().raw_span ().generation (name = span_name ))
140+ span_name = tags .get (_COMPONENT_NAME_KEY , operation_name )
141+
142+ # Create new span depending whether there's a parent span or not
143+ if not parent_span :
144+ if operation_name != _PIPELINE_RUN_KEY :
145+ logger .warning (
146+ "Creating a new trace without a parent span is not recommended for operation '{operation_name}'." ,
147+ operation_name = operation_name ,
148+ )
149+ # Create a new trace if no parent span is provided
150+ span = LangfuseSpan (
151+ self ._tracer .trace (
152+ name = self ._name ,
153+ public = self ._public ,
154+ id = tracing_context_var .get ().get ("trace_id" ),
155+ user_id = tracing_context_var .get ().get ("user_id" ),
156+ session_id = tracing_context_var .get ().get ("session_id" ),
157+ tags = tracing_context_var .get ().get ("tags" ),
158+ version = tracing_context_var .get ().get ("version" ),
159+ )
160+ )
161+ elif tags .get (_COMPONENT_TYPE_KEY ) in _ALL_SUPPORTED_GENERATORS :
162+ span = LangfuseSpan (parent_span .raw_span ().generation (name = span_name ))
136163 else :
137- span = LangfuseSpan (self . current_span () .raw_span ().span (name = span_name ))
164+ span = LangfuseSpan (parent_span .raw_span ().span (name = span_name ))
138165
139166 self ._context .append (span )
140167 span .set_tags (tags )
141168
142169 yield span
143170
144- if tags .get ("haystack.component.type" ) in _SUPPORTED_GENERATORS :
145- meta = span ._data .get ("haystack.component.output" , {}).get ("meta" )
171+ # Update span metadata based on component type
172+ if tags .get (_COMPONENT_TYPE_KEY ) in _SUPPORTED_GENERATORS :
173+ # Haystack returns one meta dict for each message, but the 'usage' value
174+ # is always the same, let's just pick the first item
175+ meta = span ._data .get (_COMPONENT_OUTPUT_KEY , {}).get ("meta" )
146176 if meta :
147- # Haystack returns one meta dict for each message, but the 'usage' value
148- # is always the same, let's just pick the first item
149177 m = meta [0 ]
150178 span ._span .update (usage = m .get ("usage" ) or None , model = m .get ("model" ))
151- elif tags .get ("haystack.component.type" ) in _SUPPORTED_CHAT_GENERATORS :
152- replies = span ._data .get ("haystack.component.output" , {}).get ("replies" )
179+ elif tags .get (_COMPONENT_TYPE_KEY ) in _SUPPORTED_CHAT_GENERATORS :
180+ replies = span ._data .get (_COMPONENT_OUTPUT_KEY , {}).get ("replies" )
153181 if replies :
154182 meta = replies [0 ].meta
155183 completion_start_time = meta .get ("completion_start_time" )
@@ -165,36 +193,24 @@ def trace(self, operation_name: str, tags: Optional[Dict[str, Any]] = None) -> I
165193 completion_start_time = completion_start_time ,
166194 )
167195
168- pipeline_input = tags .get ("haystack.pipeline.input_data" , None )
169- if pipeline_input :
170- span ._span .update (input = tags ["haystack.pipeline.input_data" ])
171- pipeline_output = tags .get ("haystack.pipeline.output_data" , None )
172- if pipeline_output :
173- span ._span .update (output = tags ["haystack.pipeline.output_data" ])
174-
175- span .raw_span ().end ()
196+ raw_span = span .raw_span ()
197+ if isinstance (raw_span , langfuse .client .StatefulSpanClient ):
198+ raw_span .end ()
176199 self ._context .pop ()
177200
178- if len (self ._context ) == 1 :
179- # The root span has to be a trace, which need to be removed from the context after the pipeline run
180- self ._context .pop ()
181-
182- if self .enforce_flush :
183- self .flush ()
201+ if self .enforce_flush :
202+ self .flush ()
184203
185204 def flush (self ):
186205 self ._tracer .flush ()
187206
188- def current_span (self ) -> Span :
207+ def current_span (self ) -> Optional [ Span ] :
189208 """
190- Return the currently active span.
209+ Return the current active span.
191210
192- :return: The currently active span.
211+ :return: The current span if available, else None .
193212 """
194- if not self ._context :
195- # The root span has to be a trace
196- self ._context .append (LangfuseSpan (self ._tracer .trace (name = self ._name , public = self ._public )))
197- return self ._context [- 1 ]
213+ return self ._context [- 1 ] if self ._context else None
198214
199215 def get_trace_url (self ) -> str :
200216 """
0 commit comments