3 changes: 2 additions & 1 deletion .gitignore
@@ -35,4 +35,5 @@ venv.bak/
/.vscode
/output
dist/
.coda/
.coda/
.DS_Store
6 changes: 6 additions & 0 deletions CHANGLOG.md
@@ -1,3 +1,9 @@
## [0.1.23] - 2026-01-07
### Added
- lcc supports child_of and state_span_ctx_key
- lcc supports multiple clients
- lcc supports getting trace_id and root_span_id

## [0.1.22] - 2026-01-06
### Added
- support span discard
132 changes: 98 additions & 34 deletions cozeloop/integration/langchain/trace_callback.py
@@ -3,6 +3,7 @@

from __future__ import annotations
import json
import threading
import time
import traceback
from typing import List, Dict, Union, Any, Optional, Callable, Protocol
@@ -24,8 +25,6 @@
from cozeloop.integration.langchain.trace_model.runtime import RuntimeInfo
from cozeloop.integration.langchain.util import calc_token_usage, get_prompt_tag

_trace_callback_client: Optional[Client] = None


class LoopTracer:
@classmethod
@@ -35,19 +34,28 @@ def get_callback_handler(
modify_name_fn: Optional[Callable[[str], str]] = None,
add_tags_fn: Optional[Callable[[str], Dict[str, Any]]] = None,
tags: Dict[str, Any] = None,
child_of: Optional[Span] = None,
state_span_ctx_key: str = None,
):
"""
Do not hold it for a long time; get a new callback_handler for each request.
modify_name_fn: modify name function, input is node name(if you use langgraph, like add_node(node_name, node_func), it is node name), output is span name.
add_tags_fn: add tags function, input is node name(if you use langgraph, like add_node(node_name, node_func), it is node name), output is tags dict.
client: cozeloop client instance. If not provided, the default client is used.
modify_name_fn: modify name function, input is node name(if you use langgraph, like add_node(node_name, node_func), it is node name), output is span name.
add_tags_fn: add tags function, input is node name(if you use langgraph, like add_node(node_name, node_func), it is node name), output is tags dict.
Its priority is higher than the tags parameter.
tags: default tags dict. Its priority is lower than the add_tags_fn parameter.
child_of: parent span of this callback_handler.
state_span_ctx_key: span context field name in state. If provided, you need to set this field in the state, and we will use it to store the span context there.
You can get it from the state to create inner spans in async nodes.
"""
global _trace_callback_client
if client:
_trace_callback_client = client
else:
_trace_callback_client = get_default_client()

return LoopTraceCallbackHandler(modify_name_fn, add_tags_fn, tags)
return LoopTraceCallbackHandler(
name_fn=modify_name_fn,
tags_fn=add_tags_fn,
tags=tags,
child_of=child_of,
client=client,
state_span_ctx_key=state_span_ctx_key,
)


class LoopTraceCallbackHandler(BaseCallbackHandler):
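
A minimal usage sketch of the updated get_callback_handler. Hedged: cozeloop.new_client() and all tag/key values below are illustrative assumptions, not taken from this change; only the parameter names come from the signature above.

    import cozeloop
    from cozeloop.integration.langchain.trace_callback import LoopTracer

    client = cozeloop.new_client()  # assumed client factory; substitute your own Client setup

    handler = LoopTracer.get_callback_handler(
        client=client,                            # explicit client (multi-client support)
        tags={"env": "staging"},                  # default tags, lowest priority
        add_tags_fn=lambda node: {"node": node},  # overrides `tags` on key conflicts
        state_span_ctx_key="loop_span_ctx",       # field the handler writes into graph state
    )
    # Per the docstring, create a fresh handler per request, then pass it to
    # LangChain/LangGraph, e.g. graph.invoke(inputs, config={"callbacks": [handler]})
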
@@ -56,13 +64,22 @@ def __init__(
name_fn: Optional[Callable[[str], str]] = None,
tags_fn: Optional[Callable[[str], Dict[str, Any]]] = None,
tags: Dict[str, Any] = None,
child_of: Optional[Span] = None,
client: Client = None,
state_span_ctx_key: str = None,
):
super().__init__()
self._space_id = _trace_callback_client.workspace_id
self._client = client if client else get_default_client()
self._space_id = self._client.workspace_id
self.run_map: Dict[str, Run] = {}
self.name_fn = name_fn
self.tags_fn = tags_fn
self._tags = tags if tags else {}
self.trace_id: Optional[str] = None
self.root_span_id: Optional[str] = None
self._id_lock = threading.Lock()
self._child_of = child_of
self._state_span_ctx_key = state_span_ctx_key

def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> Any:
span_tags = {}
@@ -73,14 +90,16 @@ def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs:
span_tags['input'] = ModelTraceInput([BaseMessage(type='', content=prompt) for prompt in prompts],
kwargs.get('invocation_params', {})).to_json()
except Exception as e:
flow_span.set_error(e)
span_tags['internal_error'] = repr(e)
span_tags['internal_error_trace'] = traceback.format_exc()
finally:
span_tags.update(_get_model_span_tags(**kwargs))
self._set_span_tags(flow_span, span_tags)
# Store some information from the start callback for use when the call ends.
self.run_map[str(kwargs['run_id'])].model_meta = ModelMeta(
message=[{'role': '', 'content': prompt} for prompt in prompts],
model_name=span_tags.get('model_name', ''))
return flow_span

def on_chat_model_start(self, serialized: Dict[str, Any], messages: List[List[BaseMessage]], **kwargs: Any) -> Any:
span_tags = {}
@@ -90,14 +109,16 @@ def on_chat_model_start(self, serialized: Dict[str, Any], messages: List[List[Ba
try:
span_tags['input'] = ModelTraceInput(messages, kwargs.get('invocation_params', {})).to_json()
except Exception as e:
flow_span.set_error(e)
span_tags['internal_error'] = repr(e)
span_tags['internal_error_trace'] = traceback.format_exc()
finally:
span_tags.update(_get_model_span_tags(**kwargs))
self._set_span_tags(flow_span, span_tags)
# Store some information from the start callback for use when the call ends.
self.run_map[str(kwargs['run_id'])].model_meta = (
ModelMeta(message=[{'role': message.type, 'content': message.content} for inner_messages in messages for
message in inner_messages], model_name=span_tags.get('model_name', '')))
return flow_span

async def on_llm_new_token(self, token: str, *, chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
**kwargs: Any) -> None:
@@ -119,10 +140,14 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
if tags:
self._set_span_tags(flow_span, tags, need_convert_tag_value=False)
except Exception as e:
flow_span.set_error(e)
span_tags = {"internal_error": repr(e), 'internal_error_trace': traceback.format_exc()}
self._set_span_tags(flow_span, span_tags, need_convert_tag_value=False)
# finish flow_span
self._end_flow_span(flow_span)

def on_llm_error(self, error: Exception, **kwargs: Any) -> Any:
self.on_chain_error(error, **kwargs)

def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> Any:
flow_span = None
try:
@@ -131,13 +156,27 @@ def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **k
self._on_prompt_start(flow_span, serialized, inputs, **kwargs)
else:
span_type = 'chain'
if kwargs['name'] == 'LangGraph': # LangGraph is Graph span_type,for trajectory evaluation aggregate to an agent
if kwargs['name'] == 'LangGraph':  # LangGraph uses the 'graph' span_type so trajectory evaluation can aggregate it into an agent
span_type = 'graph'
flow_span = self._new_flow_span(kwargs['name'], span_type, **kwargs)
flow_span.set_tags({'input': _convert_2_json(inputs)})
except Exception as e:
if flow_span is not None:
flow_span.set_error(e)
span_tags = {"internal_error": repr(e), 'internal_error_trace': traceback.format_exc()}
self._set_span_tags(flow_span, span_tags, need_convert_tag_value=False)
finally:
if flow_span is not None:
# set trace_id
with self._id_lock:
if hasattr(flow_span, 'context'):
self.trace_id = flow_span.context.trace_id
else:
self.trace_id = flow_span.trace_id
# set span_ctx in state
if self._state_span_ctx_key:
inputs[self._state_span_ctx_key] = flow_span
return flow_span

def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> Any:
flow_span = self.run_map[str(kwargs['run_id'])].span
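
A hedged sketch of the state_span_ctx_key flow above: on_chain_start stores the root flow span in the graph state under the configured key, so an async node can use it as child_of for inner spans. The node, key, and `client` names are illustrative; start_span/finish are the calls used elsewhere in this file.

    async def my_async_node(state: dict) -> dict:
        parent_span = state.get("loop_span_ctx")  # the key passed as state_span_ctx_key
        inner_span = client.start_span("inner_work", "custom", child_of=parent_span)
        try:
            ...  # the node's real work
        finally:
            inner_span.finish()
        return state
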
@@ -151,16 +190,17 @@ def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> An
else:
flow_span.set_tags({'output': _convert_2_json(outputs)})
except Exception as e:
flow_span.set_error(e)
if flow_span:
span_tags = {"internal_error": repr(e), 'internal_error_trace': traceback.format_exc()}
self._set_span_tags(flow_span, span_tags, need_convert_tag_value=False)
self._end_flow_span(flow_span)

def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any:
flow_span = self._get_flow_span(**kwargs)
if flow_span is None:
span_name = '_Exception' if isinstance(error, Exception) else '_KeyboardInterrupt'
flow_span = self._new_flow_span(span_name, 'chain_error', **kwargs)
flow_span.set_error(error)
flow_span.set_tags({'error_trace': traceback.format_exc()})
flow_span.set_tags({'error': repr(error), 'error_trace': traceback.format_exc()})
self._end_flow_span(flow_span)

def on_tool_start(
@@ -170,13 +210,15 @@ def on_tool_start(
span_name = serialized.get('name', 'unknown')
flow_span = self._new_flow_span(span_name, 'tool', **kwargs)
self._set_span_tags(flow_span, span_tags)
return flow_span

def on_tool_end(self, output: str, **kwargs: Any) -> Any:
flow_span = self._get_flow_span(**kwargs)
try:
flow_span.set_tags({'output': _convert_2_json(output)})
except Exception as e:
flow_span.set_error(e)
span_tags = {"internal_error": repr(e), 'internal_error_trace': traceback.format_exc()}
self._set_span_tags(flow_span, span_tags, need_convert_tag_value=False)
self._end_flow_span(flow_span)

def on_tool_error(
Expand All @@ -186,8 +228,8 @@ def on_tool_error(
if flow_span is None:
span_name = '_Exception' if isinstance(error, Exception) else '_KeyboardInterrupt'
flow_span = self._new_flow_span(span_name, 'tool_error', **kwargs)
flow_span.set_error(error)
flow_span.set_tags({'error_trace': traceback.format_exc()})
span_tags = {'error': repr(error), 'error_trace': traceback.format_exc()}
self._set_span_tags(flow_span, span_tags, need_convert_tag_value=False)
self._end_flow_span(flow_span)

def on_text(self, text: str, **kwargs: Any) -> Any:
@@ -200,7 +242,8 @@ def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
return

def _end_flow_span(self, span: Span):
span.finish()
if span:
span.finish()

def _get_model_tags(self, response: LLMResult, **kwargs: Any) -> Dict[str, Any]:
return self._get_model_token_tags(response, **kwargs)
@@ -224,20 +267,25 @@ def _get_model_token_tags(self, response: LLMResult, **kwargs: Any) -> Dict[str,
result['input_cached_tokens'] = input_cached_tokens
elif response.generations is not None and len(response.generations) > 0 and response.generations[0] is not None:
for i, generation in enumerate(response.generations[0]):
if isinstance(generation, ChatGeneration) and isinstance(generation.message,(AIMessageChunk, AIMessage)) and generation.message.usage_metadata:
if isinstance(generation, ChatGeneration) and isinstance(generation.message, (
AIMessageChunk, AIMessage)) and generation.message.usage_metadata:
is_get_from_langchain = True
result['input_tokens'] = generation.message.usage_metadata.get('input_tokens', 0)
result['output_tokens'] = generation.message.usage_metadata.get('output_tokens', 0)
result['tokens'] = result['input_tokens'] + result['output_tokens']
if generation.message.usage_metadata.get('output_token_details', {}):
reasoning_tokens = generation.message.usage_metadata.get('output_token_details', {}).get('reasoning', 0)
reasoning_tokens = generation.message.usage_metadata.get('output_token_details', {}).get(
'reasoning', 0)
if reasoning_tokens:
result['reasoning_tokens'] = reasoning_tokens
if generation.message.usage_metadata.get('input_token_details', {}):
input_read_cached_tokens = generation.message.usage_metadata.get('input_token_details', {}).get('cache_read', 0)
input_read_cached_tokens = generation.message.usage_metadata.get('input_token_details', {}).get(
'cache_read', 0)
if input_read_cached_tokens:
result['input_cached_tokens'] = input_read_cached_tokens
input_creation_cached_tokens = generation.message.usage_metadata.get('input_token_details', {}).get('cache_creation', 0)
input_creation_cached_tokens = generation.message.usage_metadata.get('input_token_details',
{}).get('cache_creation',
0)
if input_creation_cached_tokens:
result['input_creation_cached_tokens'] = input_creation_cached_tokens
if is_get_from_langchain:
Expand All @@ -259,7 +307,8 @@ def _get_model_token_tags(self, response: LLMResult, **kwargs: Any) -> Dict[str,
span_tags = {'error_info': repr(e), 'error_trace': traceback.format_exc()}
return span_tags

def _on_prompt_start(self, flow_span, serialized: Dict[str, Any], inputs: (Dict[str, Any], str), **kwargs: Any) -> None:
def _on_prompt_start(self, flow_span, serialized: Dict[str, Any], inputs: Union[Dict[str, Any], str],
**kwargs: Any) -> None:
# get inputs
params: List[Argument] = []
if isinstance(inputs, str):
@@ -309,8 +358,14 @@ def _new_flow_span(self, node_name: str, span_type: str, **kwargs: Any) -> Span:
span_name = node_name
# set parent span
parent_span: Span = None
is_root_span = False
if 'parent_run_id' in kwargs and kwargs['parent_run_id'] is not None and str(kwargs['parent_run_id']) in self.run_map:
parent_span = self.run_map[str(kwargs['parent_run_id'])].span
# only the root span uses child_of
if parent_span is None:
is_root_span = True
if self._child_of:
parent_span = self._child_of
# modify name
error_tag = {}
try:
Expand All @@ -321,15 +376,20 @@ def _new_flow_span(self, node_name: str, span_type: str, **kwargs: Any) -> Span:
except Exception as e:
error_tag = {'error_info': f'name_fn error {repr(e)}', 'error_trace': traceback.format_exc()}
# new span
flow_span = _trace_callback_client.start_span(span_name, span_type, child_of=parent_span)
flow_span = self._client.start_span(span_name, span_type, child_of=parent_span)
if is_root_span:
if hasattr(flow_span, 'context'):
self.root_span_id = flow_span.context.span_id
else:
self.root_span_id = flow_span.span_id
run_id = str(kwargs['run_id'])
self.run_map[run_id] = Run(run_id, flow_span, span_type)
# set runtime
flow_span.set_runtime(RuntimeInfo())
# set extra tags
flow_span.set_tags(self._tags) # global tags
flow_span.set_tags(self._tags) # global tags
try:
if self.tags_fn: # add tags fn
if self.tags_fn: # add tags fn
tags = self.tags_fn(node_name)
if isinstance(tags, dict):
flow_span.set_tags(tags)
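
A hedged sketch of reading the ids this handler records: trace_id is captured when the root chain starts (under self._id_lock) and root_span_id when the root span is created above. The `graph` and input values are illustrative assumptions.

    result = graph.invoke({"query": "hi"}, config={"callbacks": [handler]})
    print(handler.trace_id)       # trace id of this run
    print(handler.root_span_id)   # span id of the run's root span
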
@@ -365,7 +425,10 @@ def _set_extra_span_tags(self, flow_span: Span, tag_list: list, **kwargs: Any):
class Run:
def __init__(self, run_id: str, span: Span, span_type: str) -> None:
self.run_id = run_id # langchain run_id
self.span_id = span.span_id # loop span_id,the relationship between run_id and span_id is one-to-one mapping.
if hasattr(span, 'context'):
self.span_id = span.context.span_id
else:
self.span_id = span.span_id # loop span_id; run_id and span_id map one-to-one.
self.span = span
self.span_type = span_type
self.child_runs: List[Run] = []  # plain list; pydantic Field(default_factory=list) has no effect in a regular class
@@ -519,7 +582,8 @@ def _convert_inputs(inputs: Any) -> Any:
format_inputs['content'] = inputs.content
return format_inputs
if isinstance(inputs, BaseMessage):
message = Message(role=inputs.type, content=inputs.content, tool_calls=inputs.additional_kwargs.get('tool_calls', []))
message = Message(role=inputs.type, content=inputs.content,
tool_calls=inputs.additional_kwargs.get('tool_calls', []))
return message
if isinstance(inputs, ChatPromptValue):
return _convert_inputs(inputs.messages)