diff --git a/nemoguardrails/actions/action_dispatcher.py b/nemoguardrails/actions/action_dispatcher.py index 67eef91cd..3fe066b4b 100644 --- a/nemoguardrails/actions/action_dispatcher.py +++ b/nemoguardrails/actions/action_dispatcher.py @@ -33,6 +33,8 @@ class ActionDispatcher: + """Manages the execution and life time of local actions.""" + def __init__( self, load_all_actions: bool = True, @@ -51,7 +53,8 @@ def __init__( """ log.info("Initializing action dispatcher") - self._registered_actions = {} + # Dictionary with all registered actions + self._registered_actions: dict = {} if load_all_actions: # TODO: check for better way to find actions dir path or use constants.py @@ -87,7 +90,7 @@ def __init__( for import_path in import_paths: self.load_actions_from_path(Path(import_path.strip())) - log.info(f"Registered Actions :: {sorted(self._registered_actions.keys())}") + log.info("Registered Actions :: %s", sorted(self._registered_actions.keys())) log.info("Action dispatcher initialized") @property @@ -181,7 +184,7 @@ def get_action(self, name: str) -> callable: async def execute_action( self, action_name: str, params: Dict[str, Any] - ) -> Tuple[Union[str, Dict[str, Any]], str]: + ) -> Tuple[Optional[Union[str, Dict[str, Any]]], str]: """Execute a registered action. Args: @@ -195,7 +198,7 @@ async def execute_action( action_name = self._normalize_action_name(action_name) if action_name in self._registered_actions: - log.info(f"Executing registered action: {action_name}") + log.info("Executing registered action: %s", action_name) fn = self._registered_actions.get(action_name, None) # Actions that are registered as classes are initialized lazy, when @@ -214,7 +217,7 @@ async def execute_action( result = await result else: log.warning( - f"Synchronous action `{action_name}` has been called." + "Synchronous action `%s` has been called.", action_name ) elif isinstance(fn, Chain): @@ -256,15 +259,12 @@ async def execute_action( filtered_params = { k: v for k, v in params.items() - if k not in ["state", "events", "llm"] + if k not in ["state", "events", "llm", "event_handler"] } - log.warning( - "Error while execution '%s' with parameters '%s': %s", - action_name, - filtered_params, - e, + msg = ( + f"Exception while execution '{action_name}' with parameters '{filtered_params}'", ) - log.exception(e) + raise Exception(f"{msg}: {e}") from e return None, "failed" diff --git a/nemoguardrails/cli/chat.py b/nemoguardrails/cli/chat.py index c48997319..dd9f2bf9b 100644 --- a/nemoguardrails/cli/chat.py +++ b/nemoguardrails/cli/chat.py @@ -498,7 +498,9 @@ async def _process_input_events(): chat_state.input_events = [] else: chat_state.waiting_user_input = True - await enable_input.wait() + # NOTE: We should never disable the user input since we can have + # async Python actions running in parallel + # await enable_input.wait() user_message: str = await chat_state.session.prompt_async( HTML("\n> "), diff --git a/nemoguardrails/colang/runtime.py b/nemoguardrails/colang/runtime.py index ba61eaaf5..98c5b2b93 100644 --- a/nemoguardrails/colang/runtime.py +++ b/nemoguardrails/colang/runtime.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import logging from abc import abstractmethod from typing import Any, Callable, List, Optional, Tuple @@ -47,7 +48,7 @@ def __init__(self, config: RailsConfig, verbose: bool = False): # A set of watchers that are notified every time an event is processed. # Used mainly for reporting the progress to the CLI. - self.watchers = [] + self.watchers: List = [] # The maximum number of events to be processed in a processing loop self.max_events = 500 diff --git a/nemoguardrails/colang/v1_0/runtime/runtime.py b/nemoguardrails/colang/v1_0/runtime/runtime.py index 56fa00efc..e3d426706 100644 --- a/nemoguardrails/colang/v1_0/runtime/runtime.py +++ b/nemoguardrails/colang/v1_0/runtime/runtime.py @@ -25,6 +25,7 @@ from langchain.chains.base import Chain from nemoguardrails.actions.actions import ActionResult +from nemoguardrails.actions.llm.utils import LLMCallException from nemoguardrails.colang import parse_colang_file from nemoguardrails.colang.runtime import Runtime from nemoguardrails.colang.v1_0.runtime.flows import ( @@ -360,9 +361,16 @@ async def _process_start_action(self, events: List[dict]) -> List[dict]: kwargs["llm"] = self.registered_action_params[f"{action_name}_llm"] log.info("Executing action :: %s", action_name) - result, status = await self.action_dispatcher.execute_action( - action_name, kwargs - ) + try: + result, status = await self.action_dispatcher.execute_action( + action_name, kwargs + ) + except LLMCallException as e: + raise e + except Exception as e: + result = None + status = "failed" + log.exception(e) # If the action execution failed, we return a hardcoded message if status == "failed": diff --git a/nemoguardrails/colang/v2_x/runtime/flows.py b/nemoguardrails/colang/v2_x/runtime/flows.py index 053f43e65..b79c0a53d 100644 --- a/nemoguardrails/colang/v2_x/runtime/flows.py +++ b/nemoguardrails/colang/v2_x/runtime/flows.py @@ -23,7 +23,18 @@ from dataclasses import dataclass, field from datetime import datetime from enum import Enum -from typing import Any, Callable, Deque, Dict, List, Optional, Tuple, Union +from typing import ( + Any, + Callable, + ClassVar, + Deque, + Dict, + List, + Optional, + Sequence, + Tuple, + Union, +) from dataclasses_json import dataclass_json @@ -33,7 +44,7 @@ FlowReturnMemberDef, ) from nemoguardrails.colang.v2_x.runtime.errors import ColangSyntaxError -from nemoguardrails.utils import new_readable_uuid, new_uuid +from nemoguardrails.utils import new_event_dict, new_readable_uuid, new_uuid log = logging.getLogger(__name__) @@ -108,6 +119,15 @@ def from_umim_event(cls, event: dict) -> Event: ) return new_event + def to_umim_event(self, event_source_uid: Optional[str] = None) -> Dict[str, Any]: + """Return a umim event dictionary.""" + new_event_args = dict(self.arguments) + new_event_args.setdefault( + "source_uid", + event_source_uid if event_source_uid else "NeMoGuardrails-Colang-2.x", + ) + return new_event_dict(self.name, **new_event_args) + # Expose all event parameters as attributes of the event def __getattr__(self, name): if ( @@ -150,6 +170,20 @@ def from_umim_event(cls, event: dict) -> ActionEvent: new_event.action_uid = event["action_uid"] return new_event + def to_umim_event(self, event_source_uid: Optional[str] = None) -> Dict[str, Any]: + """Return a umim event dictionary.""" + new_event_args = dict(self.arguments) + new_event_args.setdefault( + "source_uid", + event_source_uid if event_source_uid else "NeMoGuardrails-Colang-2.x", + ) + if self.action_uid and "action_uid" not in new_event_args: + return new_event_dict( + self.name, action_uid=self.action_uid, **new_event_args + ) + else: + return new_event_dict(self.name, **new_event_args) + class ActionStatus(Enum): """The status of an action.""" @@ -176,19 +210,43 @@ class Action: "Stop": "stop_event", } + # List of umim specific parameters + _umim_parameters: ClassVar[List[str]] = [ + "type", + "uid", + "event_created_at", + "source_uid", + "action_uid", + "action_info_modality", + "action_info_modality_policy", + "action_finished_at", + ] + @classmethod def from_event(cls, event: ActionEvent) -> Optional[Action]: """Returns the action if event name conforms with UMIM convention.""" assert event.action_uid is not None for name in cls._event_name_map: if name in event.name: - action = Action(event.name.replace(name, ""), {}) + action_name: str + if name == "Updated": + index = event.name.find("Action") + 6 + action_name = event.name[:index] + else: + action_name = event.name.replace(name, "") + action = Action(action_name, {}) action.uid = event.action_uid action.status = ( ActionStatus.STARTED if name != "Finished" else ActionStatus.FINISHED ) + if name == "Start": + action.start_event_arguments = { + key: event.arguments[key] + for key in event.arguments + if key not in cls._umim_parameters + } return action return None @@ -288,7 +346,7 @@ def start_event(self, _args: dict) -> ActionEvent: def change_event(self, args: dict) -> ActionEvent: """Changes a parameter of a started action.""" return ActionEvent( - name=f"Change{self.name}", arguments=args["arguments"], action_uid=self.uid + name=f"Change{self.name}", arguments=args, action_uid=self.uid ) def stop_event(self, _args: dict) -> ActionEvent: @@ -355,7 +413,7 @@ class FlowConfig: id: str # The sequence of elements that compose the flow. - elements: List[ElementType] + elements: Sequence[ElementType] # The flow parameters parameters: List[FlowParamDef] diff --git a/nemoguardrails/colang/v2_x/runtime/runtime.py b/nemoguardrails/colang/v2_x/runtime/runtime.py index 20044b8a6..3a3f07a88 100644 --- a/nemoguardrails/colang/v2_x/runtime/runtime.py +++ b/nemoguardrails/colang/v2_x/runtime/runtime.py @@ -16,7 +16,9 @@ import inspect import logging import re -from typing import Any, Dict, List, Optional, Tuple, Union +import time +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast from urllib.parse import urljoin import aiohttp @@ -32,7 +34,12 @@ ColangRuntimeError, ColangSyntaxError, ) -from nemoguardrails.colang.v2_x.runtime.flows import Event, FlowStatus +from nemoguardrails.colang.v2_x.runtime.flows import ( + Action, + ActionEvent, + Event, + FlowStatus, +) from nemoguardrails.colang.v2_x.runtime.statemachine import ( FlowConfig, InternalEvent, @@ -50,23 +57,173 @@ log = logging.getLogger(__name__) +class ActionEventHandler: + """Handles input and output events to Python actions.""" + + _lock = asyncio.Lock() + + def __init__( + self, + config: RailsConfig, + action: Action, + event_input_queue: asyncio.Queue[dict], + event_output_queue: asyncio.Queue[dict], + ): + # The LLMRails config + self._config = config + + # The relevant action + self._action = action + + # Action specific action event queue for event receiving + self._event_input_queue = event_input_queue + + # Shared async action event queue for event sending + self._event_output_queue = event_output_queue + + def send_action_updated_event( + self, event_name: str, args: Optional[dict] = None + ) -> None: + """ + Send an Action*Updated event. + + Args: + event_name (str): The name of the action event, e.g. `Attention` for AttentionUserActionUpdated + args (Optional[dict]): An optional dictionary with the event arguments + """ + + if args: + args = {"event_parameter_name": event_name, **args} + else: + args = {"event_parameter_name": event_name} + action_event = self._action.updated_event(args) + self._event_output_queue.put_nowait( + action_event.to_umim_event(self._config.event_source_uid) + ) + + async def wait_for_change_action_event( + self, timeout: Optional[float] = None + ) -> Optional[dict]: + """ + Wait for new Change*Action event. + + Args: + timeout (Optional[float]): The time to wait for the new event before it continues + """ + return await self.wait_for_event(self._action.change_event({}).name, timeout) + + def send_event(self, event_name: str, args: Optional[dict] = None) -> None: + """ + Send any event. + + Args: + event_name (str): The event name + args (Optional[dict]): An optional dictionary with the event arguments + """ + event = Event(event_name, args if args else {}) + self._event_output_queue.put_nowait( + event.to_umim_event(self._config.event_source_uid) + ) + + async def wait_for_event( + self, event_name: Optional[str] = None, timeout: Optional[float] = None + ) -> Optional[dict]: + """ + Wait for next new input event to process. + + Args: + event_name (Optional[str]): Optional event name to filter for, if None all events will be received + timeout (Optional[float]): The time to wait for new events before it continues + """ + keep_waiting = True + start_time = time.time() + while keep_waiting: + try: + # Check cumulative waiting time + if timeout and time.time() - start_time > timeout: + raise asyncio.TimeoutError() + # Wait for next event + event = await asyncio.wait_for(self._event_input_queue.get(), timeout) + if event_name is None or event["type"] == event_name: + return event + except asyncio.TimeoutError: + # Timeout occurred, stop consuming + keep_waiting = False + return None + + async def wait_for_events( + self, event_name: Optional[str] = None, timeout: Optional[float] = None + ) -> List[dict]: + """ + Wait for new input events to process. + + Args: + event_name (Optional[str]): Optional event name to filter for, if None all events will be received + timeout (Optional[float]): The time to wait for new events before it continues + """ + events: List[dict] = [] + keep_waiting = True + start_time = time.time() + while keep_waiting: + try: + # Check cumulative waiting time + if timeout and time.time() - start_time > timeout: + raise asyncio.TimeoutError() + # Wait for new events + event = await asyncio.wait_for(self._event_input_queue.get(), timeout) + # Gather all new events + while True: + if event_name is None or event["type"] == event_name: + events.append(event) + event = self._event_input_queue.get_nowait() + except asyncio.QueueEmpty: + self._event_input_queue.task_done() + keep_waiting = len(events) == 0 + except asyncio.TimeoutError: + # Timeout occurred, stop consuming + keep_waiting = False + return events + + +@dataclass +class LocalActionData: + """Structure to help organize action related data.""" + + # All active async action task + task: asyncio.Task + # The action's output event queue + input_event_queues: asyncio.Queue[dict] = field( + default_factory=lambda: asyncio.Queue() + ) + + +@dataclass +class LocalActionGroup: + """Structure to help organize all local actions related to a certain main flow.""" + + # Action uid ordered action data + action_data: Dict[str, LocalActionData] = field(default_factory=dict) + + # A single output event queue for all actions + output_event_queue: asyncio.Queue = field(default_factory=lambda: asyncio.Queue()) + + class RuntimeV2_x(Runtime): """Runtime for executing the guardrails.""" def __init__(self, config: RailsConfig, verbose: bool = False): super().__init__(config, verbose) - # Register local system actions - self.register_action(self._add_flows_action, "AddFlowsAction", False) - self.register_action(self._remove_flows_action, "RemoveFlowsAction", False) - - # Maps main_flow.uid to a dictionary of actions that are run locally, asynchronously. - # Dict[main_flow_uid, Dict[action_uid, action_data]] - self.async_actions: Dict[str, List] = {} + # Maps main_flow.uid to a list of action group data that contains all the locally running actions. + self.local_actions: Dict[str, LocalActionGroup] = {} # A way to disable async function execution. Useful for testing. self.disable_async_execution = False + # Register local system actions + self.register_action(self._add_flows_action, "AddFlowsAction", False) + self.register_action(self._remove_flows_action, "RemoveFlowsAction", False) + async def _add_flows_action(self, state: "State", **args: dict) -> List[str]: log.info("Start AddFlowsAction! %s", args) flow_content = args["config"] @@ -147,14 +304,19 @@ async def _remove_flows_action(self, state: "State", **args: dict) -> None: def _init_flow_configs(self) -> None: """Initializes the flow configs based on the config.""" - self.flow_configs = create_flow_configs_from_flow_list(self.config.flows) + self.flow_configs = create_flow_configs_from_flow_list( + cast(List[Flow], self.config.flows) + ) - async def generate_events(self, events: List[dict]) -> List[dict]: + async def generate_events( + self, events: List[dict], processing_log: Optional[List[dict]] = None + ) -> List[dict]: raise NotImplementedError("Stateless API not supported for Colang 2.x, yet.") @staticmethod def _internal_error_action_result(message: str) -> ActionResult: """Helper to construct an action result for an internal error.""" + # TODO: We should handle this as an ActionFinished(is_success=False) event and not generate custom other events return ActionResult( events=[ { @@ -173,24 +335,24 @@ def _internal_error_action_result(message: str) -> ActionResult: async def _process_start_action( self, - action_name: str, - action_params: dict, + action: Action, context: dict, - events: List[dict], state: "State", ) -> Tuple[Any, List[dict], dict]: """Starts the specified action, waits for it to finish and posts back the result.""" - fn = self.action_dispatcher.get_action(action_name) + return_value: Any = None + return_events: List[dict] = [] + context_updates: dict = {} + + fn = self.action_dispatcher.get_action(action.name) # TODO: check action is available in action server if fn is None: - result = self._internal_error_action_result( - f"Action '{action_name}' not found." - ) + raise ColangRuntimeError(f"Action '{action.name}' not found.") else: # We pass all the parameters that are passed explicitly to the action. - kwargs = {**action_params} + kwargs = {**action.start_event_arguments} action_meta = getattr(fn, "action_meta", {}) @@ -228,13 +390,28 @@ async def _process_start_action( and action_type != "chain" ): result, status = await self._get_action_resp( - action_meta, action_name, kwargs + action_meta, action.name, kwargs ) else: # We don't send these to the actions server; # TODO: determine if we should if "events" in parameters: - kwargs["events"] = events + kwargs["events"] = state.last_events + + if "event_handler" in parameters: + kwargs["event_handler"] = ActionEventHandler( + self.config, + action, + self.local_actions[state.main_flow_state.uid] + .action_data[action.uid] + .input_event_queues, + self.local_actions[ + state.main_flow_state.uid + ].output_event_queue, + ) + + if "action" in parameters: + kwargs["action"] = action if "context" in parameters: kwargs["context"] = context @@ -255,25 +432,31 @@ async def _process_start_action( if ( "llm" in kwargs - and f"{action_name}_llm" in self.registered_action_params + and f"{action.name}_llm" in self.registered_action_params ): - kwargs["llm"] = self.registered_action_params[f"{action_name}_llm"] + kwargs["llm"] = self.registered_action_params[f"{action.name}_llm"] - log.info("Running action :: %s", action_name) + log.info("Running action :: %s", action.name) result, status = await self.action_dispatcher.execute_action( - action_name, kwargs + action.name, kwargs ) # If the action execution failed, we return a hardcoded message if status == "failed": - # TODO: make this message configurable. - result = self._internal_error_action_result( - "I'm sorry, an internal error has occurred." + action_finished_event = self._get_action_finished_event( + self.config, + action, + status="failed", + is_success=False, + failure_reason="Local action finished with an exception!", ) + return_events.append(action_finished_event) + + # result = self._internal_error_action_result( + # "I'm sorry, an internal error has occurred." + # ) return_value = result - return_events: List[dict] = [] - context_updates: dict = {} if isinstance(result, ActionResult): return_value = result.return_value @@ -282,9 +465,6 @@ async def _process_start_action( if result.context_updates is not None: context_updates.update(result.context_updates) - # if return_events: - # next_steps.extend(return_events) - return return_value, return_events, context_updates async def _get_action_resp( @@ -317,10 +497,10 @@ async def _get_action_resp( f"Got status code {resp.status} while getting response from {action_name}" ) - resp = await resp.json() + json_resp = await resp.json() result, status = ( - resp.get("result", result), - resp.get("status", status), + json_resp.get("result", result), + json_resp.get("status", status), ) except Exception as e: log.info( @@ -337,20 +517,42 @@ async def _get_action_resp( return result, status @staticmethod - def _get_action_finished_event(result: dict, **kwargs) -> Dict[str, Any]: - """Helper to return the ActionFinished event from the result of running a local action.""" - return new_event_dict( - f"{result['action_name']}Finished", - action_uid=result["start_action_event"]["action_uid"], - action_name=result["action_name"], - status="success", - is_success=True, - return_value=result["return_value"], - events=result["new_events"], - **kwargs, - # is_system_action=action_meta.get("is_system_action", False), + def _get_action_finished_event( + rails_config: RailsConfig, + action: Action, + **kwargs, + ) -> Dict[str, Any]: + """Helper to augment the ActionFinished event with additional data.""" + if "return_value" not in kwargs: + kwargs["return_value"] = None + if "events" not in kwargs: + kwargs["events"] = [] + event = action.finished_event( + { + "action_name": action.name, + "status": "success", + "is_success": True, + **kwargs, + } ) + return event.to_umim_event(rails_config.event_source_uid) + + async def _get_async_action_events(self, main_flow_uid: str) -> List[dict]: + events = [] + while True: + try: + # Attempt to get an item from the queue without waiting + event = self.local_actions[ + main_flow_uid + ].output_event_queue.get_nowait() + + events.append(event) + except asyncio.QueueEmpty: + # Break the loop if the queue is empty + break + return events + async def _get_async_actions_finished_events( self, main_flow_uid: str ) -> Tuple[List[dict], int]: @@ -364,10 +566,13 @@ async def _get_async_actions_finished_events( The array of *ActionFinished events and the pending counter """ - pending_actions = self.async_actions.get(main_flow_uid, []) - if len(pending_actions) == 0: + local_action_group = self.local_actions[main_flow_uid] + if len(local_action_group.action_data) == 0: return [], 0 + pending_actions = [ + data.task for data in local_action_group.action_data.values() + ] done, pending = await asyncio.wait( pending_actions, return_when=asyncio.FIRST_COMPLETED, @@ -378,19 +583,36 @@ async def _get_async_actions_finished_events( action_finished_events = [] for finished_task in done: + action = finished_task.action # type: ignore try: result = finished_task.result() - except Exception: - log.warning( - "Local action finished with an exception!", - exc_info=True, + # We need to create the corresponding action finished event + action_finished_event = self._get_action_finished_event( + self.config, action, **result ) - - self.async_actions[main_flow_uid].remove(finished_task) - - # We need to create the corresponding action finished event - action_finished_event = self._get_action_finished_event(result) - action_finished_events.append(action_finished_event) + action_finished_events.append(action_finished_event) + except asyncio.CancelledError: + action_finished_event = self._get_action_finished_event( + self.config, + action, + status="failed", + is_success=False, + was_stopped=True, + failure_reason="stopped", + ) + action_finished_events.append(action_finished_event) + except Exception as e: + msg = "Local action finished with an exception!" + log.warning("%s %s", msg, e) + action_finished_event = self._get_action_finished_event( + self.config, + action, + status="failed", + is_success=False, + failure_reason=msg, + ) + action_finished_events.append(action_finished_event) + del self.local_actions[main_flow_uid].action_data[action.uid] return action_finished_events, len(pending) @@ -409,8 +631,8 @@ async def process_events( The events will be processed one by one, in the input order. If new events are generated as part of the processing, they will be appended to the input events. - By default, a processing cycle only waits for the local actions to finish, i.e, - if after processing all the input events, there are local actions in progress, the + By default, a processing cycle only waits for the non-async local actions to finish, i.e, + if after processing all the input events, there are non-async local actions in progress, the event processing will wait for them to finish. In blocking mode, the event processing will also wait for the local async actions. @@ -429,10 +651,19 @@ async def process_events( state. """ - output_events = [] - input_events: List[Union[dict, InternalEvent]] = events.copy() + output_events: List[Dict[str, Any]] = [] + input_events: List[Union[dict, InternalEvent]] = [] local_running_actions: List[asyncio.Task[dict]] = [] + def extend_input_events(events: Sequence[Union[dict, InternalEvent]]): + """Make sure to add all new input events to all local async action event queues.""" + input_events.extend(events) + for data in self.local_actions[main_flow_uid].action_data.values(): + for event in events: + if isinstance(event, dict) and event["type"] != "CheckLocalAsync": + data.input_event_queues.put_nowait(event) + + # Initialize empty state if state is None or state == {}: state = State( flow_states={}, flow_configs=self.flow_configs, rails_config=self.config @@ -454,6 +685,7 @@ async def process_events( input_event = InternalEvent(name="StartFlow", arguments={"flow_id": "main"}) input_events.insert(0, input_event) main_flow_state = state.flow_id_states["main"][-1] + self.local_actions[main_flow_state.uid] = LocalActionGroup() # Start all module level flows before main flow idx = 0 @@ -475,31 +707,46 @@ async def process_events( input_events.insert(0, input_event) idx += 1 + # Check if we have new async action events to add + new_events = await self._get_async_action_events(state.main_flow_state.uid) + extend_input_events(new_events) + output_events.extend(new_events) + # Check if we have new finished async local action events to add ( local_action_finished_events, - pending_local_async_action_counter, + pending_local_action_counter, ) = await self._get_async_actions_finished_events(main_flow_uid) - input_events.extend(local_action_finished_events) + extend_input_events(local_action_finished_events) + output_events.extend(local_action_finished_events) + local_action_finished_events = [] return_local_async_action_count = False - # While we have input events to process, or there are local running actions - # we continue the processing. + # Add all input events + extend_input_events(events) + + # While we have input events to process, or there are local + # (non-async) running actions we continue the processing. events_counter = 0 while input_events or local_running_actions: - new_outgoing_events = [] - for event in input_events: + while input_events: + event = input_events.pop(0) + events_counter += 1 if events_counter > self.max_events: log.critical( - f"Maximum number of events reached ({events_counter})!" + "Maximum number of events reached (%s)!", events_counter ) return output_events, state log.info("Processing event :: %s", event) for watcher in self.watchers: - watcher(event) + if ( + not isinstance(event, dict) + or event["type"] != "CheckLocalAsync" + ): + watcher(event) event_name = event["type"] if isinstance(event, dict) else event.name @@ -510,66 +757,33 @@ async def process_events( # Record the event that we're about to process state.last_events.append(event) - # Advance the state machine - new_event: Optional[Union[dict, Event]] = event - while new_event is not None: - try: - run_to_completion(state, new_event) - new_event = None - except Exception as e: - log.warning("Colang runtime error!", exc_info=True) - new_event = Event( - name="ColangError", - arguments={ - "type": str(type(e).__name__), - "error": str(e), - }, - ) - await asyncio.sleep(0.001) - - # If we have context updates after this event, we first add that. - # TODO: Check if this is still needed for e.g. stateless implementation - # if state.context_updates: - # output_events.append( - # new_event_dict("ContextUpdate", data=state.context_updates) - # ) - - for out_event in state.outgoing_events: - # We also record the out events in the recent history. - state.last_events.append(out_event) - - # We need to check if we need to run a locally registered action - start_action_match = re.match(r"Start(.*Action)", out_event["type"]) - if start_action_match: - action_name = start_action_match[1] + # Check if we need run a locally registered action + if isinstance(event, dict): + if re.match(r"Start(.*Action)", event["type"]): + action_event = ActionEvent.from_umim_event(event) + action = Action.from_event(action_event) + assert action # If it's an instant action, we finish it right away. - if instant_actions and action_name in instant_actions: - finished_event_data: dict = { - "action_name": action_name, - "start_action_event": out_event, - "return_value": None, - "new_events": [], - } - - # TODO: figure out a generic way of creating a compliant - # ...ActionFinished event - extra = {} - if action_name == "UtteranceBotAction": - extra["final_script"] = out_event["script"] + # TODO (schuellc): What is this needed for? + if instant_actions and action.name in instant_actions: + extra = {"action": action} + if action.name == "UtteranceBotAction": + extra["final_script"] = event["script"] action_finished_event = self._get_action_finished_event( - finished_event_data, **extra + self.config, **extra ) # We send the completion of the action as an output event # and continue processing it. + # TODO: Why do we need an output event for that? It should only be an new input event + extend_input_events([action_finished_event]) output_events.append(action_finished_event) - input_events.append(action_finished_event) - elif self.action_dispatcher.has_registered(action_name): + elif self.action_dispatcher.has_registered(action.name): # In this case we need to start the action locally - action_fn = self.action_dispatcher.get_action(action_name) + action_fn = self.action_dispatcher.get_action(action.name) execute_async = getattr(action_fn, "action_meta", {}).get( "execute_async", False ) @@ -577,12 +791,22 @@ async def process_events( # Start the local action local_action = asyncio.create_task( self._run_action( - action_name, - start_action_event=out_event, - events_history=state.last_events, + action, state=state, ) ) + # Attach related action to the task + local_action.action = action # type: ignore + + # Generate *ActionStarted event + action_started_event = action.started_event({}) + action_started_umim_event = ( + action_started_event.to_umim_event( + self.config.event_source_uid + ) + ) + extend_input_events([action_started_umim_event]) + output_events.append(action_started_umim_event) # If the function is not async, or async execution is disabled # we execute the actions as a local action. @@ -596,33 +820,67 @@ async def process_events( local_running_actions.append(local_action) else: main_flow_uid = state.main_flow_state.uid - if main_flow_uid not in self.async_actions: - self.async_actions[main_flow_uid] = [] - self.async_actions[main_flow_uid].append(local_action) - else: - output_events.append(out_event) - else: - output_events.append(out_event) - - # Check if we have new finished async local action events to add - ( - new_local_action_finished_events, - pending_local_async_action_counter, - ) = await self._get_async_actions_finished_events(main_flow_uid) - local_action_finished_events.extend(new_local_action_finished_events) - new_outgoing_events.extend(state.outgoing_events) + if main_flow_uid not in self.local_actions: + # TODO: This check should not be needed + self.local_actions[ + main_flow_uid + ] = LocalActionGroup() + self.local_actions[main_flow_uid].action_data.update( + {action.uid: LocalActionData(local_action)} + ) + elif re.match(r"Stop(.*Action)", event["type"]): + # Check if we need stop a locally running action + action_event = ActionEvent.from_umim_event(event) + action_uid = action_event.arguments.get("action_uid", None) + if action_uid: + data = self.local_actions[main_flow_uid].action_data.get( + action_uid + ) + if ( + data + and data.task.action.name # type: ignore + == action_event.name[4:] + ): + data.task.cancel() + + # Advance the state machine + new_event: Optional[Union[dict, Event]] = event + while new_event: + try: + run_to_completion(state, new_event) + new_event = None + except Exception as e: + log.warning("Colang runtime error!", exc_info=True) + new_event = Event( + name="ColangError", + arguments={ + "type": str(type(e).__name__), + "error": str(e), + }, + ) + # Give local async action the chance to process events + await asyncio.sleep(0.001) - input_events.clear() + # Add new async action events as new input events + new_events = await self._get_async_action_events( + state.main_flow_state.uid + ) + extend_input_events(new_events) + output_events.extend(new_events) - # If we have outgoing events we are also processing them as input events - if new_outgoing_events: - input_events.extend(new_outgoing_events) - continue + # Add new finished async local action events as new input events + ( + new_action_finished_events, + pending_local_action_counter, + ) = await self._get_async_actions_finished_events(main_flow_uid) + extend_input_events(new_action_finished_events) + output_events.extend(new_action_finished_events) - input_events.extend(local_action_finished_events) - local_action_finished_events = [] + # Add generated events as new input events + extend_input_events(state.outgoing_events) + output_events.extend(state.outgoing_events) - # If we have any local running actions, we need to wait for at least one + # If we have any non-async local running actions, we need to wait for at least one # of them to finish. if local_running_actions: log.info( @@ -639,7 +897,9 @@ async def process_events( result = finished_task.result() # We need to create the corresponding action finished event - action_finished_event = self._get_action_finished_event(result) + action_finished_event = self._get_action_finished_event( + self.config, finished_task.action, **result # type: ignore + ) input_events.append(action_finished_event) if return_local_async_action_count: @@ -650,55 +910,57 @@ async def process_events( ) output_events.append( new_event_dict( - "LocalAsyncCounter", counter=pending_local_async_action_counter + "LocalAsyncCounter", counter=pending_local_action_counter ) ) - # TODO: serialize the state to dict - # We cap the recent history to the last 500 state.last_events = state.last_events[-500:] - return output_events, state + if state.main_flow_state.status == FlowStatus.WAITING: + # Main flow is done, release related local action data + log.info("End of story!") + for item in self.local_actions[main_flow_uid].action_data.values(): + item.task.cancel() + self.local_actions[main_flow_uid].action_data.clear() + + # We currently filter out all events related local actions + # TODO: Consider if we should expose them all as umim events + final_output_events = [] + for event in output_events: + if isinstance(event, dict) and "action_uid" in event: + action_event = ActionEvent.from_umim_event(event) + action = Action.from_event(action_event) + if action and self.action_dispatcher.has_registered(action.name): + continue + final_output_events.append(event) + + return final_output_events, state async def _run_action( self, - action_name: str, - start_action_event: dict, - events_history: List[Union[dict, Event]], + action: Action, state: "State", ) -> dict: """Runs the locally registered action. Args - action_name: The name of the action to be executed. - start_action_event: The event that triggered the action. - events_history: The recent history of events that led to the action being triggered. + action: The action to be executed. + state: The state of the runtime. """ - # NOTE: To extract the actual parameters that should be passed to the local action, - # we ignore all the keys from "an empty event" of the same type. - ignore_keys = new_event_dict(start_action_event["type"]).keys() - action_params = { - k: v for k, v in start_action_event.items() if k not in ignore_keys - } - return_value, new_events, context_updates = await self._process_start_action( - action_name, - action_params=action_params, + action, context=state.context, - events=events_history, state=state, ) state.context.update(context_updates) return { - "action_name": action_name, "return_value": return_value, "new_events": new_events, "context_updates": context_updates, - "start_action_event": start_action_event, } diff --git a/nemoguardrails/colang/v2_x/runtime/statemachine.py b/nemoguardrails/colang/v2_x/runtime/statemachine.py index 430dfe043..9a0bdb19c 100644 --- a/nemoguardrails/colang/v2_x/runtime/statemachine.py +++ b/nemoguardrails/colang/v2_x/runtime/statemachine.py @@ -1836,7 +1836,9 @@ def _is_done_flow(flow_state: FlowState) -> bool: def _generate_umim_event(state: State, event: Event) -> Dict[str, Any]: - umim_event = create_umim_event(event, event.arguments, state.rails_config) + umim_event = event.to_umim_event( + state.rails_config.event_source_uid if state and state.rails_config else None + ) state.outgoing_events.append(umim_event) log.info("[bold violet]<- Action[/]: %s", event) @@ -2392,23 +2394,6 @@ def create_internal_event( return event -def create_umim_event( - event: Event, event_args: Dict[str, Any], config: Optional[RailsConfig] -) -> Dict[str, Any]: - """Returns an outgoing UMIM event for the provided action data""" - new_event_args = dict(event_args) - new_event_args.setdefault( - "source_uid", config.event_source_uid if config else "NeMoGuardrails-Colang-2.x" - ) - if isinstance(event, ActionEvent) and event.action_uid is not None: - if "action_uid" in new_event_args: - event.action_uid = new_event_args["action_uid"] - del new_event_args["action_uid"] - return new_event_dict(event.name, action_uid=event.action_uid, **new_event_args) - else: - return new_event_dict(event.name, **new_event_args) - - def _get_eval_context(state: State, flow_state: FlowState) -> dict: context = flow_state.context.copy() # Link global variables diff --git a/nemoguardrails/logging/verbose.py b/nemoguardrails/logging/verbose.py index a2f972238..747075cb2 100644 --- a/nemoguardrails/logging/verbose.py +++ b/nemoguardrails/logging/verbose.py @@ -163,6 +163,10 @@ def emit(self, record) -> None: msg += f"[dim]{title}[/]" console.print(msg, highlight=False, no_wrap=False) + elif record.levelno >= logging.WARNING: + current_time = datetime.now().strftime("%H:%M:%S.%f")[:-3] + msg = f"[dim]{current_time}[/] | [yellow bold]Warning[/] | " + msg + console.print(msg, highlight=False, no_wrap=False) def set_verbose( diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index b24fcb99c..e108d6edc 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -34,7 +34,7 @@ from nemoguardrails.colang import parse_colang_file from nemoguardrails.colang.v1_0.runtime.flows import compute_context from nemoguardrails.colang.v1_0.runtime.runtime import Runtime, RuntimeV1_0 -from nemoguardrails.colang.v2_x.runtime.flows import Action, State +from nemoguardrails.colang.v2_x.runtime.flows import Action, ActionEvent, State from nemoguardrails.colang.v2_x.runtime.runtime import RuntimeV2_x from nemoguardrails.colang.v2_x.runtime.serialization import ( json_to_state, @@ -732,24 +732,27 @@ async def generate_async( start_action_match = re.match(r"Start(.*Action)", event["type"]) if start_action_match: - action_name = start_action_match[1] - # TODO: is there an elegant way to extract just the arguments? - arguments = { - k: v - for k, v in event.items() - if k != "type" - and k != "uid" - and k != "event_created_at" - and k != "source_uid" - and k != "action_uid" - } - response_tool_calls.append( - { - "id": event["action_uid"], - "type": "function", - "function": {"name": action_name, "arguments": arguments}, - } - ) + action_event = ActionEvent.from_umim_event(event) + action = Action.from_event(action_event) + if action: + # TODO (schuellc): Check why we need this? + # Also it seems we need to exclude the following actions + if action.name in [ + "UtteranceBotAction", + "GestureBotAction", + "PostureBotAction", + ]: + continue + response_tool_calls.append( + { + "id": action.uid, + "type": "function", + "function": { + "name": action.name, + "arguments": action.start_event_arguments, + }, + } + ) elif event["type"] == "UtteranceBotActionFinished": responses.append(event["final_script"]) diff --git a/tests/test_configs/with_custom_async_action_events_v2_x/actions.py b/tests/test_configs/with_custom_async_action_events_v2_x/actions.py new file mode 100644 index 000000000..be83f7f43 --- /dev/null +++ b/tests/test_configs/with_custom_async_action_events_v2_x/actions.py @@ -0,0 +1,40 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio + +from nemoguardrails.actions import action +from nemoguardrails.colang.v2_x.runtime.runtime import ActionEventHandler + + +@action(name="CustomAsyncTest1Action", is_system_action=True, execute_async=True) +async def custom_async_test1(event_handler: ActionEventHandler): + for i in range(1, 3): + await asyncio.sleep(1) + event_handler.send_action_updated_event("Value", {"number": i}) + await asyncio.sleep(1) + event_handler.send_event("CustomEventA", {"value": "A"}) + events = await event_handler.wait_for_events("CustomEventB") + event_handler.send_event("CustomEventResponse", {"value": events[0]["value"]}) + await event_handler.wait_for_events("CustomEventC") + await asyncio.sleep(3) + # raise Exception("Python action exception!") + + +@action(name="CustomAsyncTest2Action", is_system_action=True, execute_async=True) +async def custom_async_test2(event_handler: ActionEventHandler): + await event_handler.wait_for_events("CustomEventResponse") + await asyncio.sleep(3) + event_handler.send_event("CustomEventC", {"value": "C"}) diff --git a/tests/test_configs/with_custom_async_action_events_v2_x/config.co b/tests/test_configs/with_custom_async_action_events_v2_x/config.co new file mode 100644 index 000000000..21b662125 --- /dev/null +++ b/tests/test_configs/with_custom_async_action_events_v2_x/config.co @@ -0,0 +1,16 @@ +flow main + match UtteranceUserAction.Finished(final_transcript="start") + start CustomAsyncTest1Action() as $action1_ref + start CustomAsyncTest2Action() + match $action1_ref.Started() + while True: + when $action1_ref.ValueUpdated() as $ref: + start UtteranceBotAction(script="Value: {$ref.number}") + or when CustomEventA() as $ref: + start UtteranceBotAction(script="Value: {$ref.value}") + send CustomEventB(value="B") + match CustomEventC(value="C") + start UtteranceBotAction(script="Check") + or when $action1_ref.Finished() as $event: + start UtteranceBotAction(script="End") + break diff --git a/tests/test_configs/with_custom_async_action_events_v2_x/config.yml b/tests/test_configs/with_custom_async_action_events_v2_x/config.yml new file mode 100644 index 000000000..ae6347623 --- /dev/null +++ b/tests/test_configs/with_custom_async_action_events_v2_x/config.yml @@ -0,0 +1 @@ +colang_version: "2.x" diff --git a/tests/test_configs/with_custom_async_action_lifetime_v2_x/actions.py b/tests/test_configs/with_custom_async_action_lifetime_v2_x/actions.py new file mode 100644 index 000000000..868592317 --- /dev/null +++ b/tests/test_configs/with_custom_async_action_lifetime_v2_x/actions.py @@ -0,0 +1,45 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio + +from nemoguardrails.actions import action +from nemoguardrails.colang.v2_x.runtime.runtime import ActionEventHandler + + +@action(name="Test1Action", is_system_action=True, execute_async=True) +async def test1(event_handler: ActionEventHandler): + event = None + value = None + while event is None: + event = await event_handler.wait_for_change_action_event() + if event: + value = event.get("volume", None) + if value: + break + else: + event = None + await asyncio.sleep(1) + event_handler.send_action_updated_event("Volume", {"value": value}) + + +@action(name="Test2Action", is_system_action=True, execute_async=True) +async def test2(event_handler: ActionEventHandler): + await event_handler.wait_for_events("NeverHappeningEver") + + +@action(name="Test3Action", is_system_action=True, execute_async=True) +async def test3(event_handler: ActionEventHandler): + raise Exception("Issue occurred!") diff --git a/tests/test_configs/with_custom_async_action_lifetime_v2_x/config.co b/tests/test_configs/with_custom_async_action_lifetime_v2_x/config.co new file mode 100644 index 000000000..5989e71ab --- /dev/null +++ b/tests/test_configs/with_custom_async_action_lifetime_v2_x/config.co @@ -0,0 +1,26 @@ +flow main + match UtteranceUserAction.Finished(final_transcript="start") + + # Normal action lifetime + start Test1Action() as $action1_ref + match $action1_ref.Started() + start UtteranceBotAction(script="Started") + send $action1_ref.Change(direction="top") + send $action1_ref.Change(volume=10) + match Test1ActionVolumeUpdated() as $ref + start UtteranceBotAction(script="Volume: {$ref.value}") + match $action1_ref.Finished() as $ref + start UtteranceBotAction(script="Result: {$ref.is_success}") + + # Action gets canceled + start Test2Action() as $action2_ref + match $action2_ref.Started() + send StopTest2Action(action_uid=$action2_ref.uid) + match $action2_ref.Finished() as $ref + start UtteranceBotAction(script="Result: {$ref.was_stopped}") + + # Action raises Exception + start Test3Action() as $action3_ref + match $action3_ref.Started() + match $action3_ref.Finished() as $ref + start UtteranceBotAction(script="Result: {$ref.failure_reason}") diff --git a/tests/test_configs/with_custom_async_action_lifetime_v2_x/config.yml b/tests/test_configs/with_custom_async_action_lifetime_v2_x/config.yml new file mode 100644 index 000000000..ae6347623 --- /dev/null +++ b/tests/test_configs/with_custom_async_action_lifetime_v2_x/config.yml @@ -0,0 +1 @@ +colang_version: "2.x" diff --git a/tests/test_embeddings_only_user_messages.py b/tests/test_embeddings_only_user_messages.py index 9dd428755..6dfc25ef3 100644 --- a/tests/test_embeddings_only_user_messages.py +++ b/tests/test_embeddings_only_user_messages.py @@ -191,7 +191,7 @@ def test_examples_included_in_prompts_2(colang_2_config): colang_2_config, llm_completions=[ " user express greeting", - ' bot respond to uknown intent "Hello is there anything else" ', + ' bot respond to unknown intent\nbot action: bot say "Hello is there anything else"', ], ) @@ -214,7 +214,7 @@ def test_no_llm_calls_embedding_only(colang_2_config): colang_2_config, llm_completions=[ " user express greeting", - ' bot respond to uknown intent "Hello is there anything else" ', + ' bot respond to unknown intent "Hello is there anything else" ', ], ) diff --git a/tests/test_retrieve_relevant_chunks.py b/tests/test_retrieve_relevant_chunks.py index 7d1044661..f4e110e2e 100644 --- a/tests/test_retrieve_relevant_chunks.py +++ b/tests/test_retrieve_relevant_chunks.py @@ -29,12 +29,12 @@ activate llm continuation flow user express greeting - user said "hello" - or user said "hi" - or user said "how are you" + user said "hello" + or user said "hi" + or user said "how are you" flow bot express greeting - bot say "Hey!" + bot say "Hey!" flow greeting user express greeting @@ -58,7 +58,7 @@ def test_relevant_chunk_inserted_in_prompt(): config, llm_completions=[ " user express greeting", - ' bot respond to aditional context\nbot action: "Hello is there anything else" ', + ' bot respond to additional context\nbot action: bot say "Hello is there anything else" ', ], ) @@ -85,7 +85,7 @@ def test_relevant_chunk_inserted_in_prompt_no_kb(): config, llm_completions=[ " user express greeting", - ' bot respond to aditional context\nbot action: "Hello is there anything else" ', + ' bot respond to aditional context\nbot action: bot say "Hello is there anything else" ', ], ) rails = chat.app diff --git a/tests/v2_x/chat.py b/tests/v2_x/chat.py index 7cdc91d15..584d2f5cb 100644 --- a/tests/v2_x/chat.py +++ b/tests/v2_x/chat.py @@ -19,7 +19,7 @@ from typing import Dict, List, Optional import nemoguardrails.rails.llm.llmrails -from nemoguardrails import LLMRails, RailsConfig +from nemoguardrails import LLMRails from nemoguardrails.cli.chat import extract_scene_text_content, parse_events_inputs from nemoguardrails.colang.v2_x.runtime.flows import State from nemoguardrails.utils import new_event_dict, new_uuid @@ -44,7 +44,7 @@ class ChatInterface: def __init__(self, rails_app: LLMRails): self.chat_state = ChatState() self.rails_app = rails_app - self.input_queue = asyncio.Queue() + self.input_queue: asyncio.Queue = asyncio.Queue() self.loop = asyncio.get_event_loop() asyncio.create_task(self.run()) @@ -322,6 +322,9 @@ async def run(self): self.chat_state.input_events = [] else: self.chat_state.waiting_user_input = True + # NOTE: We should never disable the user input since we can have + # async Python actions running in parallel + # TODO: Check if disabling causes race conditions await self.enable_input.wait() user_message = "" diff --git a/tests/v2_x/test_run_actions.py b/tests/v2_x/test_run_actions.py index 914c56d61..49c5d2ecf 100644 --- a/tests/v2_x/test_run_actions.py +++ b/tests/v2_x/test_run_actions.py @@ -14,11 +14,14 @@ # limitations under the License. import logging +import os from rich.logging import RichHandler from nemoguardrails import RailsConfig +from nemoguardrails.utils import get_or_create_event_loop from tests.utils import TestChat +from tests.v2_x.utils import compare_interaction_with_test_script FORMAT = "%(message)s" logging.basicConfig( @@ -28,8 +31,10 @@ handlers=[RichHandler(markup=True)], ) +CONFIGS_FOLDER = os.path.join(os.path.dirname(__file__), "../test_configs") -def test_1(): + +def test_basic_statement(): config = RailsConfig.from_content( colang_content=""" flow user express greeting @@ -62,7 +67,7 @@ async def fetch_name(): chat << "Hello world!" -def test_2(): +def test_short_return_statement(): config = RailsConfig.from_content( colang_content=""" flow user express greeting @@ -95,7 +100,7 @@ async def fetch_name(): chat << "John" -def test_3(): +def test_long_return_statement(): config = RailsConfig.from_content( colang_content=""" flow bot say $text @@ -129,5 +134,68 @@ async def fetch_dictionary(): chat << "I couldn't find any items matching your request!" +def test_custom_action(): + # This config just imports another one, to check that actions are correctly + # loaded. + config = RailsConfig.from_path( + os.path.join(CONFIGS_FOLDER, "with_custom_action_v2_x") + ) + + chat = TestChat( + config, + llm_completions=[], + ) + + chat >> "start" + chat << "8" + + +def test_custom_action_async_events(): + path = os.path.join(CONFIGS_FOLDER, "with_custom_async_action_events_v2_x") + test_script = """ + > start + Value: 1 + Value: 2 + Event: CustomEventA + Value: A + Event: CustomEventB + Event: CustomEventResponse + Event: CustomEventC + Check + End + Event: StopUtteranceBotAction + """ + + loop = get_or_create_event_loop() + result = loop.run_until_complete( + compare_interaction_with_test_script(test_script, 10.0, colang_path=path) + ) + + assert result is None, result + + +def test_custom_async_action_lifetime(): + path = os.path.join(CONFIGS_FOLDER, "with_custom_async_action_lifetime_v2_x") + test_script = """ + > start + Started + Volume: 10 + Result: True + Result: True + Result: Local action finished with an exception! + Event: StopUtteranceBotAction + Event: StopUtteranceBotAction + Event: StopUtteranceBotAction + Event: StopUtteranceBotAction + """ + + loop = get_or_create_event_loop() + result = loop.run_until_complete( + compare_interaction_with_test_script(test_script, 6.0, colang_path=path) + ) + + assert result is None, result + + if __name__ == "__main__": - test_3() + test_custom_async_action_lifetime() diff --git a/tests/v2_x/utils.py b/tests/v2_x/utils.py new file mode 100644 index 000000000..213536e40 --- /dev/null +++ b/tests/v2_x/utils.py @@ -0,0 +1,111 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +# TODO: This is a copy of docs/colang-2/examples/utils.py, we should unify +# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + + +import asyncio +from typing import Optional + +from nemoguardrails.rails.llm.config import RailsConfig +from nemoguardrails.rails.llm.llmrails import LLMRails +from tests.utils import FakeLLM +from tests.v2_x.chat import ChatInterface + +YAML_CONFIG = """ +colang_version: "2.x" +""" + + +async def run_chat_interface_based_on_test_script( + test_script: str, + wait_time_s: float, + colang: Optional[str] = None, + colang_path: Optional[str] = None, + llm_responses: Optional[list] = None, +) -> str: + rails_config: RailsConfig + if colang: + rails_config = RailsConfig.from_content( + colang_content=colang, + yaml_content=YAML_CONFIG, + ) + elif colang_path: + rails_config = RailsConfig.from_path(colang_path) + + interaction_log = [] + + if llm_responses: + llm = FakeLLM(responses=llm_responses) + rails_app = LLMRails(rails_config, verbose=True, llm=llm) + else: + rails_app = LLMRails(rails_config, verbose=True) + + chat = ChatInterface(rails_app) + + lines = test_script.split("\n") + for line in lines: + line = line.strip() + if line.startswith("#"): + continue + if line.startswith(">"): + interaction_log.append(line) + user_input = line.replace("> ", "") + print(f"sending '{user_input}' to process") + response = await chat.process(user_input, wait_time_s) + interaction_log.append(response) + + chat.should_terminate = True + await asyncio.sleep(0.5) + + return "\n".join(interaction_log) + + +def cleanup(content): + output = [] + lines = content.split("\n") + for line in lines: + if len(line.strip()) == 0: + continue + if line.strip() == ">": + continue + if line.startswith("#"): + continue + if "Starting the chat" in line: + continue + + output.append(line.strip()) + + return "\n".join(output) + + +async def compare_interaction_with_test_script( + test_script: str, + wait_time_s: float = 1.0, + colang: Optional[str] = None, + colang_path: Optional[str] = None, + llm_responses: Optional[list] = None, +) -> Optional[str]: + result = await run_chat_interface_based_on_test_script( + test_script, wait_time_s, colang, colang_path, llm_responses=llm_responses + ) + clean_test_script = cleanup(test_script) + clean_result = cleanup(result) + if clean_test_script == clean_result: + return None + + return f"\n----\n{clean_result}\n----\n\ndoes not match test script\n\n----\n{clean_test_script}\n----"