Skip to content
54 changes: 53 additions & 1 deletion sentry_sdk/integrations/langchain.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextvars
import itertools
from collections import OrderedDict
from functools import wraps
Expand Down Expand Up @@ -72,6 +73,40 @@
}


# Contextvar to track agent names in a stack for re-entrant agent support
_agent_stack = contextvars.ContextVar("langchain_agent_stack", default=None)


def _push_agent(agent_name):
# type: (Optional[str]) -> None
"""Push an agent name onto the stack."""
stack = _agent_stack.get()
if stack is None:
stack = []
stack.append(agent_name)
_agent_stack.set(stack)


def _pop_agent():
# type: () -> Optional[str]
"""Pop an agent name from the stack and return it."""
stack = _agent_stack.get()
if stack and len(stack) > 0:
agent_name = stack.pop()
_agent_stack.set(stack)
return agent_name
return None


def _get_current_agent():
# type: () -> Optional[str]
"""Get the current agent name (top of stack) without removing it."""
stack = _agent_stack.get()
if stack and len(stack) > 0:
return stack[-1]
return None


class LangchainIntegration(Integration):
identifier = "langchain"
origin = f"auto.ai.{identifier}"
Expand Down Expand Up @@ -276,6 +311,10 @@ def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
elif "openai" in ai_type:
span.set_data(SPANDATA.GEN_AI_SYSTEM, "openai")

agent_name = _get_current_agent()
if agent_name:
span.set_data(SPANDATA.GEN_AI_AGENT_NAME, agent_name)

for key, attribute in DATA_FIELDS.items():
if key in all_params and all_params[key] is not None:
set_data_normalized(span, attribute, all_params[key], unpack=False)
Expand Down Expand Up @@ -428,6 +467,10 @@ def on_tool_start(self, serialized, input_str, *, run_id, **kwargs):
if tool_description is not None:
span.set_data(SPANDATA.GEN_AI_TOOL_DESCRIPTION, tool_description)

agent_name = _get_current_agent()
if agent_name:
span.set_data(SPANDATA.GEN_AI_AGENT_NAME, agent_name)

if should_send_default_pii() and self.include_prompts:
set_data_normalized(
span,
Expand Down Expand Up @@ -756,6 +799,7 @@ def new_invoke(self, *args, **kwargs):
name=f"invoke_agent {agent_name}" if agent_name else "invoke_agent",
origin=LangchainIntegration.origin,
) as span:
_push_agent(agent_name)
if agent_name:
span.set_data(SPANDATA.GEN_AI_AGENT_NAME, agent_name)

Expand Down Expand Up @@ -794,6 +838,8 @@ def new_invoke(self, *args, **kwargs):
):
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, output)

_pop_agent()

return result

return new_invoke
Expand All @@ -814,11 +860,13 @@ def new_stream(self, *args, **kwargs):

span = start_span_function(
op=OP.GEN_AI_INVOKE_AGENT,
name=f"invoke_agent {agent_name}".strip(),
name=f"invoke_agent {agent_name}" if agent_name else "invoke_agent",
origin=LangchainIntegration.origin,
)
span.__enter__()

_push_agent(agent_name)

if agent_name:
span.set_data(SPANDATA.GEN_AI_AGENT_NAME, agent_name)

Expand Down Expand Up @@ -868,6 +916,8 @@ def new_iterator():
):
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, output)

_pop_agent()

span.__exit__(None, None, None)

async def new_iterator_async():
Expand All @@ -887,6 +937,8 @@ async def new_iterator_async():
):
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, output)

_pop_agent()

span.__exit__(None, None, None)

if str(type(result)) == "<class 'async_generator'>":
Expand Down
Loading