diff --git a/.gitignore b/.gitignore
index 74c712b..ba9b389 100644
--- a/.gitignore
+++ b/.gitignore
@@ -35,4 +35,5 @@ venv.bak/
 /.vscode
 /output
 dist/
-.coda/
\ No newline at end of file
+.coda/
+.DS_Store
\ No newline at end of file
diff --git a/CHANGLOG.md b/CHANGLOG.md
index 73df7fd..db575b5 100644
--- a/CHANGLOG.md
+++ b/CHANGLOG.md
@@ -1,3 +1,9 @@
+## [0.1.23] - 2026-01-07
+### Added
+- lcc: support child_of and state_span_ctx_key
+- lcc: support multiple clients
+- lcc: support getting trace_id and root_span_id
+
 ## [0.1.22] - 2026-01-06
 ### Added
 - support span discard
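The 0.1.23 entries above map to the trace_callback.py patch below. A minimal usage sketch of the new surface, assuming the cozeloop client API that appears in this patch (`start_span(name, span_type, child_of=...)`, `workspace_id`, and the handler attributes are taken from the diff; `cozeloop.new_client` and its arguments are assumptions):

```python
# Sketch of the 0.1.23 additions. `cozeloop.new_client` is an assumed
# constructor; start_span(name, span_type, child_of=...) and the handler
# attributes follow the patch below.
import cozeloop
from cozeloop.integration.langchain.trace_callback import LoopTracer

client = cozeloop.new_client(api_token="...", workspace_id="...")  # hypothetical args

# Optional parent: the handler's root span is created as its child.
request_span = client.start_span("handle_request", "custom")

handler = LoopTracer.get_callback_handler(
    client=client,                       # multi-client support: pass an explicit client
    child_of=request_span,               # only the root callback span attaches here
    state_span_ctx_key="loop_span_ctx",  # root span is stored in graph state under this key
)

# graph.invoke(inputs, config={"callbacks": [handler]})  # one handler per request

# Populated once the root chain has started:
print(handler.trace_id, handler.root_span_id)
request_span.finish()
```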
""" - global _trace_callback_client - if client: - _trace_callback_client = client - else: - _trace_callback_client = get_default_client() - - return LoopTraceCallbackHandler(modify_name_fn, add_tags_fn, tags) + return LoopTraceCallbackHandler( + name_fn=modify_name_fn, + tags_fn=add_tags_fn, + tags=tags, + child_of=child_of, + client=client, + state_span_ctx_key=state_span_ctx_key, + ) class LoopTraceCallbackHandler(BaseCallbackHandler): @@ -56,13 +64,22 @@ def __init__( name_fn: Optional[Callable[[str], str]] = None, tags_fn: Optional[Callable[[str], Dict[str, Any]]] = None, tags: Dict[str, Any] = None, + child_of: Optional[Span] = None, + client: Client = None, + state_span_ctx_key: str = None, ): super().__init__() - self._space_id = _trace_callback_client.workspace_id + self._client = client if client else get_default_client() + self._space_id = self._client.workspace_id self.run_map: Dict[str, Run] = {} self.name_fn = name_fn self.tags_fn = tags_fn self._tags = tags if tags else {} + self.trace_id: Optional[str] = None + self.root_span_id: Optional[str] = None + self._id_lock = threading.Lock() + self._child_of = child_of + self._state_span_ctx_key = state_span_ctx_key def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> Any: span_tags = {} @@ -73,7 +90,8 @@ def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: span_tags['input'] = ModelTraceInput([BaseMessage(type='', content=prompt) for prompt in prompts], kwargs.get('invocation_params', {})).to_json() except Exception as e: - flow_span.set_error(e) + span_tags['internal_error'] = repr(e) + span_tags['internal_error_trace'] = traceback.format_exc() finally: span_tags.update(_get_model_span_tags(**kwargs)) self._set_span_tags(flow_span, span_tags) @@ -81,6 +99,7 @@ def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: self.run_map[str(kwargs['run_id'])].model_meta = ModelMeta( message=[{'role': '', 'content': prompt} for prompt in prompts], model_name=span_tags.get('model_name', '')) + return flow_span def on_chat_model_start(self, serialized: Dict[str, Any], messages: List[List[BaseMessage]], **kwargs: Any) -> Any: span_tags = {} @@ -90,7 +109,8 @@ def on_chat_model_start(self, serialized: Dict[str, Any], messages: List[List[Ba try: span_tags['input'] = ModelTraceInput(messages, kwargs.get('invocation_params', {})).to_json() except Exception as e: - flow_span.set_error(e) + span_tags['internal_error'] = repr(e) + span_tags['internal_error_trace'] = traceback.format_exc() finally: span_tags.update(_get_model_span_tags(**kwargs)) self._set_span_tags(flow_span, span_tags) @@ -98,6 +118,7 @@ def on_chat_model_start(self, serialized: Dict[str, Any], messages: List[List[Ba self.run_map[str(kwargs['run_id'])].model_meta = ( ModelMeta(message=[{'role': message.type, 'content': message.content} for inner_messages in messages for message in inner_messages], model_name=span_tags.get('model_name', ''))) + return flow_span async def on_llm_new_token(self, token: str, *, chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, **kwargs: Any) -> None: @@ -119,10 +140,14 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any: if tags: self._set_span_tags(flow_span, tags, need_convert_tag_value=False) except Exception as e: - flow_span.set_error(e) + span_tags = {"internal_error": repr(e), 'internal_error_trace': traceback.format_exc()} + self._set_span_tags(flow_span, span_tags, need_convert_tag_value=False) # finish flow_span 
         self._end_flow_span(flow_span)
 
+    def on_llm_error(self, error: Exception, **kwargs: Any) -> Any:
+        self.on_chain_error(error, **kwargs)
+
     def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> Any:
         flow_span = None
         try:
@@ -131,13 +156,27 @@ def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **k
                 self._on_prompt_start(flow_span, serialized, inputs, **kwargs)
             else:
                 span_type = 'chain'
-                if kwargs['name'] == 'LangGraph':  # LangGraph is Graph span_type,for trajectory evaluation aggregate to an agent
+                # LangGraph gets the 'graph' span_type so trajectory evaluation can aggregate it into an agent
+                if kwargs['name'] == 'LangGraph':
                     span_type = 'graph'
                 flow_span = self._new_flow_span(kwargs['name'], span_type, **kwargs)
                 flow_span.set_tags({'input': _convert_2_json(inputs)})
         except Exception as e:
             if flow_span is not None:
-                flow_span.set_error(e)
+                span_tags = {'internal_error': repr(e), 'internal_error_trace': traceback.format_exc()}
+                self._set_span_tags(flow_span, span_tags, need_convert_tag_value=False)
+        finally:
+            if flow_span is not None:
+                # record the trace_id
+                with self._id_lock:
+                    if hasattr(flow_span, 'context'):
+                        self.trace_id = flow_span.context.trace_id
+                    else:
+                        self.trace_id = flow_span.trace_id
+                # store the span context in the state
+                if self._state_span_ctx_key:
+                    inputs[self._state_span_ctx_key] = flow_span
+        return flow_span
 
     def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> Any:
         flow_span = self.run_map[str(kwargs['run_id'])].span
@@ -151,7 +190,9 @@ def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> An
             else:
                 flow_span.set_tags({'output': _convert_2_json(outputs)})
         except Exception as e:
-            flow_span.set_error(e)
+            if flow_span:
+                span_tags = {'internal_error': repr(e), 'internal_error_trace': traceback.format_exc()}
+                self._set_span_tags(flow_span, span_tags, need_convert_tag_value=False)
         self._end_flow_span(flow_span)
 
     def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any:
@@ -159,8 +200,7 @@ def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: A
         if flow_span is None:
             span_name = '_Exception' if isinstance(error, Exception) else '_KeyboardInterrupt'
             flow_span = self._new_flow_span(span_name, 'chain_error', **kwargs)
-        flow_span.set_error(error)
-        flow_span.set_tags({'error_trace': traceback.format_exc()})
+        flow_span.set_tags({'error': repr(error), 'error_trace': traceback.format_exc()})
         self._end_flow_span(flow_span)
 
     def on_tool_start(
@@ -170,13 +210,15 @@ def on_tool_start(
         span_name = serialized.get('name', 'unknown')
         flow_span = self._new_flow_span(span_name, 'tool', **kwargs)
         self._set_span_tags(flow_span, span_tags)
+        return flow_span
 
     def on_tool_end(self, output: str, **kwargs: Any) -> Any:
         flow_span = self._get_flow_span(**kwargs)
         try:
             flow_span.set_tags({'output': _convert_2_json(output)})
         except Exception as e:
-            flow_span.set_error(e)
+            span_tags = {'internal_error': repr(e), 'internal_error_trace': traceback.format_exc()}
+            self._set_span_tags(flow_span, span_tags, need_convert_tag_value=False)
         self._end_flow_span(flow_span)
 
     def on_tool_error(
@@ -186,8 +228,8 @@ def on_tool_error(
         if flow_span is None:
             span_name = '_Exception' if isinstance(error, Exception) else '_KeyboardInterrupt'
             flow_span = self._new_flow_span(span_name, 'tool_error', **kwargs)
-        flow_span.set_error(error)
-        flow_span.set_tags({'error_trace': traceback.format_exc()})
+        span_tags = {'error': repr(error), 'error_trace': traceback.format_exc()}
+        self._set_span_tags(flow_span, span_tags, need_convert_tag_value=False)
         self._end_flow_span(flow_span)
 
     def on_text(self, text: str, **kwargs: Any) -> Any:
@@ -200,7 +242,8 @@ def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
         return
 
     def _end_flow_span(self, span: Span):
-        span.finish()
+        if span:
+            span.finish()
 
     def _get_model_tags(self, response: LLMResult, **kwargs: Any) -> Dict[str, Any]:
         return self._get_model_token_tags(response, **kwargs)
@@ -224,20 +267,25 @@ def _get_model_token_tags(self, response: LLMResult, **kwargs: Any) -> Dict[str,
                 result['input_cached_tokens'] = input_cached_tokens
         elif response.generations is not None and len(response.generations) > 0 and response.generations[0] is not None:
             for i, generation in enumerate(response.generations[0]):
-                if isinstance(generation, ChatGeneration) and isinstance(generation.message,(AIMessageChunk, AIMessage)) and generation.message.usage_metadata:
+                if isinstance(generation, ChatGeneration) and isinstance(generation.message, (
+                        AIMessageChunk, AIMessage)) and generation.message.usage_metadata:
                     is_get_from_langchain = True
                     result['input_tokens'] = generation.message.usage_metadata.get('input_tokens', 0)
                     result['output_tokens'] = generation.message.usage_metadata.get('output_tokens', 0)
                     result['tokens'] = result['input_tokens'] + result['output_tokens']
                     if generation.message.usage_metadata.get('output_token_details', {}):
-                        reasoning_tokens = generation.message.usage_metadata.get('output_token_details', {}).get('reasoning', 0)
+                        reasoning_tokens = generation.message.usage_metadata.get('output_token_details', {}).get(
+                            'reasoning', 0)
                        if reasoning_tokens:
                            result['reasoning_tokens'] = reasoning_tokens
                     if generation.message.usage_metadata.get('input_token_details', {}):
-                        input_read_cached_tokens = generation.message.usage_metadata.get('input_token_details', {}).get('cache_read', 0)
+                        input_read_cached_tokens = generation.message.usage_metadata.get('input_token_details', {}).get(
+                            'cache_read', 0)
                         if input_read_cached_tokens:
                             result['input_cached_tokens'] = input_read_cached_tokens
-                        input_creation_cached_tokens = generation.message.usage_metadata.get('input_token_details', {}).get('cache_creation', 0)
+                        input_creation_cached_tokens = generation.message.usage_metadata.get('input_token_details',
+                                                                                             {}).get('cache_creation',
+                                                                                                     0)
                         if input_creation_cached_tokens:
                             result['input_creation_cached_tokens'] = input_creation_cached_tokens
         if is_get_from_langchain:
@@ -259,7 +307,8 @@ def _get_model_token_tags(self, response: LLMResult, **kwargs: Any) -> Dict[str,
             span_tags = {'error_info': repr(e), 'error_trace': traceback.format_exc()}
         return span_tags
 
-    def _on_prompt_start(self, flow_span, serialized: Dict[str, Any], inputs: (Dict[str, Any], str), **kwargs: Any) -> None:
+    def _on_prompt_start(self, flow_span, serialized: Dict[str, Any], inputs: (Dict[str, Any], str),
+                         **kwargs: Any) -> None:
         # get inputs
         params: List[Argument] = []
         if isinstance(inputs, str):
@@ -309,8 +358,14 @@ def _new_flow_span(self, node_name: str, span_type: str, **kwargs: Any) -> Span:
         span_name = node_name
         # set parent span
         parent_span: Span = None
+        is_root_span = False
         if 'parent_run_id' in kwargs and kwargs['parent_run_id'] is not None and str(kwargs['parent_run_id']) in self.run_map:
             parent_span = self.run_map[str(kwargs['parent_run_id'])].span
+        # only the root span uses child_of
+        if parent_span is None:
+            is_root_span = True
+            if self._child_of:
+                parent_span = self._child_of
         # modify name
         error_tag = {}
         try:
@@ -321,15 +376,20 @@ def _new_flow_span(self, node_name: str, span_type: str, **kwargs: Any) -> Span:
         except Exception as e:
             error_tag = {'error_info': f'name_fn error {repr(e)}', 'error_trace': traceback.format_exc()}
         # new span
-        flow_span = _trace_callback_client.start_span(span_name, span_type, child_of=parent_span)
+        flow_span = self._client.start_span(span_name, span_type, child_of=parent_span)
+        if is_root_span:
+            if hasattr(flow_span, 'context'):
+                self.root_span_id = flow_span.context.span_id
+            else:
+                self.root_span_id = flow_span.span_id
         run_id = str(kwargs['run_id'])
         self.run_map[run_id] = Run(run_id, flow_span, span_type)
         # set runtime
         flow_span.set_runtime(RuntimeInfo())
         # set extra tags
-        flow_span.set_tags(self._tags) # global tags
+        flow_span.set_tags(self._tags)  # global tags
         try:
-            if self.tags_fn: # add tags fn
+            if self.tags_fn:  # add tags fn
                 tags = self.tags_fn(node_name)
                 if isinstance(tags, dict):
                     flow_span.set_tags(tags)
@@ -365,7 +425,10 @@ def _set_extra_span_tags(self, flow_span: Span, tag_list: list, **kwargs: Any):
 class Run:
     def __init__(self, run_id: str, span: Span, span_type: str) -> None:
         self.run_id = run_id  # langchain run_id
-        self.span_id = span.span_id  # loop span_id,the relationship between run_id and span_id is one-to-one mapping.
+        if hasattr(span, 'context'):
+            self.span_id = span.context.span_id
+        else:
+            self.span_id = span.span_id  # loop span_id; run_id and span_id map one-to-one
         self.span = span
         self.span_type = span_type
         self.child_runs: List[Run] = Field(default_factory=list)
@@ -519,7 +582,8 @@ def _convert_inputs(inputs: Any) -> Any:
         format_inputs['content'] = inputs.content
         return format_inputs
     if isinstance(inputs, BaseMessage):
-        message = Message(role=inputs.type, content=inputs.content, tool_calls=inputs.additional_kwargs.get('tool_calls', []))
+        message = Message(role=inputs.type, content=inputs.content,
+                          tool_calls=inputs.additional_kwargs.get('tool_calls', []))
         return message
     if isinstance(inputs, ChatPromptValue):
         return _convert_inputs(inputs.messages)
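With state_span_ctx_key set, on_chain_start stores the root span in the graph state, so an async node can parent its own spans explicitly. A sketch under assumed names (the State schema, the node, and the `client` variable from the earlier example are illustrative):

```python
# Sketch: reading the span stored by the handler back out of LangGraph state.
# 'loop_span_ctx' must match the state_span_ctx_key passed to
# get_callback_handler; `client` is an existing cozeloop client (assumed).
from typing import Any, TypedDict

class State(TypedDict, total=False):
    loop_span_ctx: Any  # filled in by LoopTraceCallbackHandler at chain start

async def my_node(state: State) -> State:
    parent = state.get('loop_span_ctx')
    # Explicit child_of keeps the inner span under the trace even though the
    # async node has no ambient span context.
    span = client.start_span('inner_work', 'custom', child_of=parent)
    try:
        ...  # actual node work
    finally:
        span.finish()
    return state
```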
diff --git a/cozeloop/integration/langchain/trace_model/llm_model.py b/cozeloop/integration/langchain/trace_model/llm_model.py
index 5e18cf0..3e12a86 100644
--- a/cozeloop/integration/langchain/trace_model/llm_model.py
+++ b/cozeloop/integration/langchain/trace_model/llm_model.py
@@ -72,7 +72,6 @@ def __post_init__(self):
             else:
                 self.content = self.content.__str__()
         elif isinstance(self.content, dict):
-            is_part = False
             text = self.content.get('text', None)
             url = self.content.get('url', self.content.get('image_url', {}).get('url', None))
             if text is not None or url is not None:
@@ -96,6 +95,7 @@ class Choice:
 
 @dataclass
 class Choices:
+    id: Optional[str] = None
     choices: Optional[List[Choice]] = None
 
 
@@ -103,6 +103,7 @@ class Choices:
 class ModelTraceInputData:
     messages: Optional[List[Message]] = None
     tools: Optional[List[Tool]] = None
+    previous_response_id: Optional[str] = None
 
 
 @dataclass
@@ -178,8 +179,10 @@ def to_json(self):
                                 description=bind_function.get('description', ''),
                                 parameters=bind_function.get('parameters', {}))
                 tools.append(Tool(type=bind_function.get('type', ''), function=function))
+
+        pre_resp_id = self._invocation_params.get('previous_response_id', None)
         return json.dumps(
-            ModelTraceInputData(messages=self._messages, tools=tools),
+            ModelTraceInputData(messages=self._messages, tools=tools, previous_response_id=pre_resp_id),
             default=lambda o: dict((key, value) for key, value in o.__dict__.items() if value),
             sort_keys=False, ensure_ascii=False)
 
@@ -192,19 +195,30 @@ def __init__(self, generations: List[Union[ChatGeneration, Generation]]):
 
     def to_json(self):
         choices: List[Choice] = []
+        response_id = None
         for i, generation in enumerate(self.generations):
             choice: Choice = None
             if isinstance(generation, ChatGeneration):
                 message = convert_output_message(generation.message)
-                choice = Choice(index=i, message=message, finish_reason=generation.generation_info.get('finish_reason', ''))
+                if message and message.metadata:
+                    response_id = message.metadata.get('id', None)
+                choice = Choice(index=i, message=message)
+                if generation.generation_info:
+                    choice.finish_reason = generation.generation_info.get('finish_reason', '')
             elif isinstance(generation, Generation):
                 choice = Choice(index=i, message=Message(content=generation.text))
             choices.append(choice)
-        return json.dumps(
-            Choices(choices=choices),
-            default=lambda o: dict((key, value) for key, value in o.__dict__.items() if value or key=='index'),
-            sort_keys=False,
-            ensure_ascii=False)
+        res = ''
+        try:
+            res = json.dumps(
+                Choices(id=response_id, choices=choices),
+                default=lambda o: dict((key, value) for key, value in o.__dict__.items() if value or key == 'index'),
+                sort_keys=False,
+                ensure_ascii=False)
+        except Exception as e:
+            # best-effort serialization: log and fall through to return ''
+            logging.error(f"ModelTraceOutput.to_json failed, exception: {e}, choices: {choices}")
+        return res
 
 
 def convert_tool_calls_by_raw(tool_calls: list) -> List[ToolCall]:
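The previous_response_id and response-id plumbing above changes the serialized trace payloads. A rough illustration of the resulting shapes (field names follow the dataclasses in this file; the values are invented):

```python
# Illustrative only: approximate JSON emitted by ModelTraceInput.to_json and
# ModelTraceOutput.to_json after this change. All values are made up.
import json

trace_input = {
    'messages': [{'role': 'user', 'content': 'hi'}],
    'previous_response_id': 'resp_123',  # lifted from invocation_params when present
}
trace_output = {
    'id': 'resp_456',  # response id taken from the output message metadata
    'choices': [
        {'index': 0,
         'message': {'role': 'assistant', 'content': 'hello'},
         'finish_reason': 'stop'},
    ],
}
print(json.dumps(trace_input, ensure_ascii=False))
print(json.dumps(trace_output, ensure_ascii=False))
```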
diff --git a/cozeloop/integration/langchain/trace_model/runtime.py b/cozeloop/integration/langchain/trace_model/runtime.py
index 27ae6d8..282df7c 100644
--- a/cozeloop/integration/langchain/trace_model/runtime.py
+++ b/cozeloop/integration/langchain/trace_model/runtime.py
@@ -7,18 +7,24 @@
 
 from cozeloop.spec import tracespec
 
+# resolve package versions once at import time instead of on every span
+try:
+    LANGCHAIN_VERSION = metadata.version('langchain')
+except metadata.PackageNotFoundError:
+    LANGCHAIN_VERSION = ''
+try:
+    LANGCHAIN_CORE_VERSION = metadata.version('langchain-core')
+except metadata.PackageNotFoundError:
+    LANGCHAIN_CORE_VERSION = ''
+
 
 class RuntimeInfo(tracespec.Runtime):
     language: Optional[str] = tracespec.V_LANG_PYTHON
     library: Optional[str] = tracespec.V_LIB_LANGCHAIN
 
     def model_post_init(self, context: Any) -> None:
-        try:
-            langchain_version = metadata.version('langchain')
-        except metadata.PackageNotFoundError:
-            langchain_version = ''
-
-        self.library_version = langchain_version
+        self.library_version = LANGCHAIN_VERSION
+        self.extra = {'langchain_core_version': LANGCHAIN_CORE_VERSION}
 
     def to_json(self):
         return json.dumps(
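A minimal standalone sketch of the import-time version caching used in runtime.py (`metadata.version` and `PackageNotFoundError` are the real importlib.metadata API; the `_safe_version` helper is a hypothetical refactoring, not part of the patch):

```python
# Standalone sketch of the version-caching pattern from runtime.py.
from importlib import metadata

def _safe_version(package: str) -> str:
    """Return the installed version of `package`, or '' if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return ''

# Computed once at import time; RuntimeInfo then reads these constants
# instead of querying package metadata for every span it decorates.
LANGCHAIN_VERSION = _safe_version('langchain')
LANGCHAIN_CORE_VERSION = _safe_version('langchain-core')

if __name__ == '__main__':
    print(LANGCHAIN_VERSION or '<not installed>',
          LANGCHAIN_CORE_VERSION or '<not installed>')
```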