diff --git a/examples/observability/simple_calculator_observability/configs/config-weave.yml b/examples/observability/simple_calculator_observability/configs/config-weave.yml index 1f5058619..b55ebe6b3 100644 --- a/examples/observability/simple_calculator_observability/configs/config-weave.yml +++ b/examples/observability/simple_calculator_observability/configs/config-weave.yml @@ -20,6 +20,13 @@ general: weave: _type: weave project: "nat-demo" + front_end: + _type: fastapi + endpoints: + - path: /chat_feedback + method: POST + description: Set reaction feedback for an assistant message via Weave call ID + function_name: chat_feedback functions: calculator_multiply: @@ -32,6 +39,8 @@ functions: _type: current_datetime calculator_subtract: _type: calculator_subtract + chat_feedback: + _type: chat_feedback llms: nim_llm: diff --git a/packages/nvidia_nat_weave/src/nat/plugins/weave/weave_exporter.py b/packages/nvidia_nat_weave/src/nat/plugins/weave/weave_exporter.py index a1ce63166..1035f9084 100644 --- a/packages/nvidia_nat_weave/src/nat/plugins/weave/weave_exporter.py +++ b/packages/nvidia_nat_weave/src/nat/plugins/weave/weave_exporter.py @@ -17,6 +17,7 @@ from collections.abc import Generator from contextlib import contextmanager +from nat.builder.context import Context from nat.data_models.intermediate_step import IntermediateStep from nat.data_models.span import Span from nat.observability.exporter.base_exporter import IsolatedAttribute @@ -80,7 +81,18 @@ def _process_start_event(self, event: IntermediateStep): if span is None: logger.warning("No span found for event %s", event.UUID) return - self._create_weave_call(event, span) + call = self._create_weave_call(event, span) + + # capture the call ID for mapping reaction feedbacks to specific traces + if (event.payload.event_type == "FUNCTION_START" and event.payload.name == ""): + try: + # Store the workflow call ID in the context for later retrieval + context = Context.get() + context._context_state.trace_id.set(call.id) + logger.info("DEBUG: Captured workflow weave call ID: %s", call.id) + + except Exception as e: + logger.debug("Could not store workflow trace ID: %s", e) def _process_end_event(self, event: IntermediateStep): """Process the end event for a Weave call. diff --git a/src/nat/agent/react_agent/register.py b/src/nat/agent/react_agent/register.py index f09261d90..cd863ad42 100644 --- a/src/nat/agent/react_agent/register.py +++ b/src/nat/agent/react_agent/register.py @@ -20,6 +20,7 @@ from pydantic import PositiveInt from nat.builder.builder import Builder +from nat.builder.context import Context from nat.builder.framework_enum import LLMFrameworkEnum from nat.builder.function_info import FunctionInfo from nat.cli.register_workflow import register_function @@ -115,6 +116,10 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde normalize_tool_input_quotes=config.normalize_tool_input_quotes).build_graph() async def _response_fn(input_message: ChatRequest) -> ChatResponse: + # Get the trace ID for feedback tracking + context = Context.get() + trace_id = context.trace_id + try: # initialize the starting state with the user query messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in input_message.messages], @@ -135,14 +140,14 @@ async def _response_fn(input_message: ChatRequest) -> ChatResponse: # get and return the output from the state state = ReActGraphState(**state) output_message = state.messages[-1] - return ChatResponse.from_string(str(output_message.content)) + return ChatResponse.from_string(str(output_message.content), trace_id=trace_id) except Exception as ex: logger.exception("%s ReAct Agent failed with exception: %s", AGENT_LOG_PREFIX, ex) # here, we can implement custom error messages if config.verbose: - return ChatResponse.from_string(str(ex)) - return ChatResponse.from_string("I seem to be having a problem.") + return ChatResponse.from_string(str(ex), trace_id=trace_id) + return ChatResponse.from_string("I seem to be having a problem.", trace_id=trace_id) if (config.use_openai_api): yield FunctionInfo.from_fn(_response_fn, description=config.description) diff --git a/src/nat/builder/context.py b/src/nat/builder/context.py index 429b87675..b8a8ff3c3 100644 --- a/src/nat/builder/context.py +++ b/src/nat/builder/context.py @@ -67,6 +67,7 @@ class ContextState(metaclass=Singleton): def __init__(self): self.conversation_id: ContextVar[str | None] = ContextVar("conversation_id", default=None) self.user_message_id: ContextVar[str | None] = ContextVar("user_message_id", default=None) + self.trace_id: ContextVar[str | None] = ContextVar("trace_id", default=None) self.input_message: ContextVar[typing.Any] = ContextVar("input_message", default=None) self.user_manager: ContextVar[typing.Any] = ContextVar("user_manager", default=None) self.metadata: ContextVar[RequestAttributes] = ContextVar("request_attributes", default=RequestAttributes()) @@ -174,6 +175,19 @@ def user_message_id(self) -> str | None: """ return self._context_state.user_message_id.get() + @property + def trace_id(self) -> str | None: + """ + Retrieves the trace ID from the current context state. + + This can be used to identify traces across different tracing systems + (e.g., Weave call IDs, Phoenix Trace IDs, OpenTelemetry trace IDs, etc.). + + Returns: + str | None: The trace ID if available, None otherwise. + """ + return self._context_state.trace_id.get() + @contextmanager def push_active_function(self, function_name: str, diff --git a/src/nat/data_models/api_server.py b/src/nat/data_models/api_server.py index 680c61afd..0d1810f24 100644 --- a/src/nat/data_models/api_server.py +++ b/src/nat/data_models/api_server.py @@ -251,6 +251,7 @@ class ChatResponse(ResponseBaseModelOutput): usage: Usage | None = None system_fingerprint: str | None = None service_tier: typing.Literal["scale", "default"] | None = None + trace_id: str | None = None @field_serializer('created') def serialize_created(self, created: datetime.datetime) -> int: @@ -264,7 +265,8 @@ def from_string(data: str, object_: str | None = None, model: str | None = None, created: datetime.datetime | None = None, - usage: Usage | None = None) -> "ChatResponse": + usage: Usage | None = None, + trace_id: str | None = None) -> "ChatResponse": if id_ is None: id_ = str(uuid.uuid4()) @@ -280,7 +282,8 @@ def from_string(data: str, model=model, created=created, choices=[Choice(index=0, message=ChoiceMessage(content=data), finish_reason="stop")], - usage=usage) + usage=usage, + trace_id=trace_id) class ChatResponseChunk(ResponseBaseModelOutput): @@ -300,6 +303,7 @@ class ChatResponseChunk(ResponseBaseModelOutput): system_fingerprint: str | None = None service_tier: typing.Literal["scale", "default"] | None = None usage: Usage | None = None + trace_id: str | None = None @field_serializer('created') def serialize_created(self, created: datetime.datetime) -> int: @@ -312,7 +316,8 @@ def from_string(data: str, id_: str | None = None, created: datetime.datetime | None = None, model: str | None = None, - object_: str | None = None) -> "ChatResponseChunk": + object_: str | None = None, + trace_id: str | None = None) -> "ChatResponseChunk": if id_ is None: id_ = str(uuid.uuid4()) @@ -327,7 +332,8 @@ def from_string(data: str, choices=[Choice(index=0, message=ChoiceMessage(content=data), finish_reason="stop")], created=created, model=model, - object=object_) + object=object_, + trace_id=trace_id) @staticmethod def create_streaming_chunk(content: str, @@ -338,7 +344,8 @@ def create_streaming_chunk(content: str, role: str | None = None, finish_reason: str | None = None, usage: Usage | None = None, - system_fingerprint: str | None = None) -> "ChatResponseChunk": + system_fingerprint: str | None = None, + trace_id: str | None = None) -> "ChatResponseChunk": """Create an OpenAI-compatible streaming chunk""" if id_ is None: id_ = str(uuid.uuid4()) @@ -358,7 +365,8 @@ def create_streaming_chunk(content: str, model=model, object="chat.completion.chunk", usage=usage, - system_fingerprint=system_fingerprint) + system_fingerprint=system_fingerprint, + trace_id=trace_id) class ResponseIntermediateStep(ResponseBaseModelIntermediate): @@ -631,7 +639,11 @@ def _string_to_nat_chat_request(data: str) -> ChatRequest: # ======== ChatResponse Converters ======== def _nat_chat_response_to_string(data: ChatResponse) -> str: if data.choices and data.choices[0].message: - return data.choices[0].message.content or "" + content = data.choices[0].message.content or "" + # Include trace ID in the string if available, using a special format + if data.trace_id: + return f"{content}__TRACE_ID__:{data.trace_id}" + return content return "" @@ -656,7 +668,11 @@ def _string_to_nat_chat_response(data: str) -> ChatResponse: def _chat_response_to_chat_response_chunk(data: ChatResponse) -> ChatResponseChunk: # Preserve original message structure for backward compatibility - return ChatResponseChunk(id=data.id, choices=data.choices, created=data.created, model=data.model) + return ChatResponseChunk(id=data.id, + choices=data.choices, + created=data.created, + model=data.model, + trace_id=data.trace_id) GlobalTypeConverter.register_converter(_chat_response_to_chat_response_chunk) @@ -679,8 +695,18 @@ def _chat_response_chunk_to_string(data: ChatResponseChunk) -> str: def _string_to_nat_chat_response_chunk(data: str) -> ChatResponseChunk: '''Converts a string to an ChatResponseChunk object''' + # Check if the string contains embedded trace ID + trace_id = None + content = data + + if "__TRACE_ID__:" in data: + parts = data.split("__TRACE_ID__:") + if len(parts) == 2: + content = parts[0] + trace_id = parts[1] + # Build and return the response - return ChatResponseChunk.from_string(data) + return ChatResponseChunk.from_string(content, trace_id=trace_id) GlobalTypeConverter.register_converter(_string_to_nat_chat_response_chunk) diff --git a/src/nat/tool/chat_feedback.py b/src/nat/tool/chat_feedback.py new file mode 100644 index 000000000..582046fa5 --- /dev/null +++ b/src/nat/tool/chat_feedback.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nat.builder.builder import Builder +from nat.builder.function_info import FunctionInfo +from nat.cli.register_workflow import register_function +from nat.data_models.function import FunctionBaseConfig + + +class ChatFeedbackTool(FunctionBaseConfig, name="chat_feedback"): + """ + A tool that allows adding reactions/feedback to Weave calls. This tool retrieves a Weave call + by its ID and adds a reaction (like thumbs up/down) to provide feedback on the call's output. + The tool automatically configures the Weave project from the builder's telemetry exporters. + """ + pass + + +@register_function(config_type=ChatFeedbackTool) +async def chat_feedback(config: ChatFeedbackTool, builder: Builder): + + async def _add_chat_feedback(weave_call_id: str, reaction_type: str) -> str: + import weave + + # Get the weave project configuration from the builder's telemetry exporters + weave_project = None + + # Handle both ChildBuilder and WorkflowBuilder + workflow_builder = getattr(builder, '_workflow_builder', builder) + + if hasattr(workflow_builder, '_telemetry_exporters'): + for exporter_config in workflow_builder._telemetry_exporters.values(): + if hasattr(exporter_config.config, 'project'): + # Construct project name in the same format as the weave exporter + entity = getattr(exporter_config.config, 'entity', None) + project = exporter_config.config.project + weave_project = f"{entity}/{project}" if entity else project + break + + client = weave.init(weave_project) + call = client.get_call(weave_call_id) + call.feedback.add_reaction(reaction_type) + + return f"Added reaction '{reaction_type}' to call {weave_call_id}" + + yield FunctionInfo.from_fn( + _add_chat_feedback, + description="Adds a reaction/feedback to a Weave call using the provided call ID and reaction type.") diff --git a/src/nat/tool/register.py b/src/nat/tool/register.py index fa6887143..238a8c947 100644 --- a/src/nat/tool/register.py +++ b/src/nat/tool/register.py @@ -17,6 +17,7 @@ # Import any tools which need to be automatically registered here from . import chat_completion +from . import chat_feedback from . import datetime_tools from . import document_search from . import github_tools