From 18e9ddd20d681f6455c0e26b53e624defaf510f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20Sch=C3=BCller?= Date: Tue, 10 Dec 2024 19:57:02 +0100 Subject: [PATCH 01/14] Add Python action async action generation support --- nemoguardrails/actions/action_dispatcher.py | 35 +++- nemoguardrails/colang/runtime.py | 5 + nemoguardrails/colang/v2_x/runtime/flows.py | 59 +++++- nemoguardrails/colang/v2_x/runtime/runtime.py | 183 +++++++++++------- .../colang/v2_x/runtime/statemachine.py | 19 +- 5 files changed, 200 insertions(+), 101 deletions(-) diff --git a/nemoguardrails/actions/action_dispatcher.py b/nemoguardrails/actions/action_dispatcher.py index 67eef91cd..1a8a179d7 100644 --- a/nemoguardrails/actions/action_dispatcher.py +++ b/nemoguardrails/actions/action_dispatcher.py @@ -15,29 +15,50 @@ """Module for the calling proper action endpoints based on events received at action server endpoint""" +import asyncio import importlib.util import inspect import logging import os from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union, cast from langchain.chains.base import Chain from langchain_core.runnables import Runnable from nemoguardrails import utils from nemoguardrails.actions.llm.utils import LLMCallException +from nemoguardrails.colang.v2_x.runtime.flows import Action from nemoguardrails.logging.callbacks import logging_callbacks log = logging.getLogger(__name__) +class ActionEventGenerator: + """Generator to emit event from async Python actions.""" + + def __init__(self, action: Action, action_event_queue: asyncio.Queue[dict]): + # The relevant action + self._action = action + + # Contains reference to the async action event queue + self._action_events_queue = action_event_queue + + async def send_action_update_event(self, event_name: str, args: dict) -> None: + """Send a ActionUpdated event.""" + action_event = self._action.updated_event( + {"event_parameter_name": event_name, **args} + ) + await self._action_events_queue.put(action_event.to_umim_event()) + + class ActionDispatcher: def __init__( self, load_all_actions: bool = True, config_path: Optional[str] = None, import_paths: Optional[List[str]] = None, + action_event_queue: Optional[asyncio.Queue[dict]] = None, ): """ Initializes an actions dispatcher. @@ -51,7 +72,11 @@ def __init__( """ log.info("Initializing action dispatcher") - self._registered_actions = {} + # Dictionary with all registered actions + self._registered_actions: dict = {} + + # Contains generated events form async actions + self._async_action_events: Optional[asyncio.Queue[dict]] = action_event_queue if load_all_actions: # TODO: check for better way to find actions dir path or use constants.py @@ -87,7 +112,7 @@ def __init__( for import_path in import_paths: self.load_actions_from_path(Path(import_path.strip())) - log.info(f"Registered Actions :: {sorted(self._registered_actions.keys())}") + log.info("Registered Actions :: %s", sorted(self._registered_actions.keys())) log.info("Action dispatcher initialized") @property @@ -195,7 +220,7 @@ async def execute_action( action_name = self._normalize_action_name(action_name) if action_name in self._registered_actions: - log.info(f"Executing registered action: {action_name}") + log.info("Executing registered action: %s", action_name) fn = self._registered_actions.get(action_name, None) # Actions that are registered as classes are initialized lazy, when @@ -214,7 +239,7 @@ async def execute_action( result = await result else: log.warning( - f"Synchronous action `{action_name}` has been called." + "Synchronous action `%s` has been called.", action_name ) elif isinstance(fn, Chain): diff --git a/nemoguardrails/colang/runtime.py b/nemoguardrails/colang/runtime.py index ba61eaaf5..71db5915f 100644 --- a/nemoguardrails/colang/runtime.py +++ b/nemoguardrails/colang/runtime.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import logging from abc import abstractmethod from typing import Any, Callable, List, Optional, Tuple @@ -31,10 +32,14 @@ def __init__(self, config: RailsConfig, verbose: bool = False): self.config = config self.verbose = verbose + # Contains generated events form async Python actions that should be processed + self._async_action_events: asyncio.Queue[dict] = asyncio.Queue() + # Register the actions with the dispatcher. self.action_dispatcher = ActionDispatcher( config_path=config.config_path, import_paths=list(config.imported_paths.values()), + action_event_queue=self._async_action_events, ) # The list of additional parameters that can be passed to the actions. diff --git a/nemoguardrails/colang/v2_x/runtime/flows.py b/nemoguardrails/colang/v2_x/runtime/flows.py index 053f43e65..29f3121bf 100644 --- a/nemoguardrails/colang/v2_x/runtime/flows.py +++ b/nemoguardrails/colang/v2_x/runtime/flows.py @@ -18,12 +18,24 @@ from __future__ import annotations import logging +import os import time from collections import deque from dataclasses import dataclass, field from datetime import datetime from enum import Enum -from typing import Any, Callable, Deque, Dict, List, Optional, Tuple, Union +from typing import ( + Any, + Callable, + ClassVar, + Deque, + Dict, + List, + Optional, + Sequence, + Tuple, + Union, +) from dataclasses_json import dataclass_json @@ -33,7 +45,8 @@ FlowReturnMemberDef, ) from nemoguardrails.colang.v2_x.runtime.errors import ColangSyntaxError -from nemoguardrails.utils import new_readable_uuid, new_uuid +from nemoguardrails.rails.llm.config import RailsConfig +from nemoguardrails.utils import new_event_dict, new_readable_uuid, new_uuid log = logging.getLogger(__name__) @@ -108,6 +121,15 @@ def from_umim_event(cls, event: dict) -> Event: ) return new_event + def to_umim_event(self, config: Optional[RailsConfig] = None) -> Dict[str, Any]: + """Return a umim event dictionary.""" + new_event_args = dict(self.arguments) + new_event_args.setdefault( + "source_uid", + config.event_source_uid if config else "NeMoGuardrails-Colang-2.x", + ) + return new_event_dict(self.name, **new_event_args) + # Expose all event parameters as attributes of the event def __getattr__(self, name): if ( @@ -150,6 +172,19 @@ def from_umim_event(cls, event: dict) -> ActionEvent: new_event.action_uid = event["action_uid"] return new_event + def to_umim_event(self) -> Dict[str, Any]: + """Return a umim event dictionary.""" + new_event_args = dict(self.arguments) + new_event_args["source_uid"] = ( + os.getenv("SOURCE_ID", None) or "NeMoGuardrails-Colang-2.x" + ) + if self.action_uid: + return new_event_dict( + self.name, action_uid=self.action_uid, **new_event_args + ) + else: + return new_event_dict(self.name, **new_event_args) + class ActionStatus(Enum): """The status of an action.""" @@ -176,6 +211,18 @@ class Action: "Stop": "stop_event", } + # List of umim specific parameters + _umim_parameters: ClassVar[List[str]] = [ + "type", + "uid", + "event_created_at", + "source_uid", + "action_uid", + "action_info_modality", + "action_info_modality_policy", + "action_finished_at", + ] + @classmethod def from_event(cls, event: ActionEvent) -> Optional[Action]: """Returns the action if event name conforms with UMIM convention.""" @@ -189,6 +236,12 @@ def from_event(cls, event: ActionEvent) -> Optional[Action]: if name != "Finished" else ActionStatus.FINISHED ) + if name == "Start": + action.start_event_arguments = { + key: event.arguments[key] + for key in event.arguments + if key not in cls._umim_parameters + } return action return None @@ -355,7 +408,7 @@ class FlowConfig: id: str # The sequence of elements that compose the flow. - elements: List[ElementType] + elements: Sequence[ElementType] # The flow parameters parameters: List[FlowParamDef] diff --git a/nemoguardrails/colang/v2_x/runtime/runtime.py b/nemoguardrails/colang/v2_x/runtime/runtime.py index 20044b8a6..bccfc6c1a 100644 --- a/nemoguardrails/colang/v2_x/runtime/runtime.py +++ b/nemoguardrails/colang/v2_x/runtime/runtime.py @@ -16,13 +16,14 @@ import inspect import logging import re -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union, cast from urllib.parse import urljoin import aiohttp import langchain from langchain.chains.base import Chain +from nemoguardrails.actions.action_dispatcher import ActionEventGenerator from nemoguardrails.actions.actions import ActionResult from nemoguardrails.colang import parse_colang_file from nemoguardrails.colang.runtime import Runtime @@ -32,7 +33,12 @@ ColangRuntimeError, ColangSyntaxError, ) -from nemoguardrails.colang.v2_x.runtime.flows import Event, FlowStatus +from nemoguardrails.colang.v2_x.runtime.flows import ( + Action, + ActionEvent, + Event, + FlowStatus, +) from nemoguardrails.colang.v2_x.runtime.statemachine import ( FlowConfig, InternalEvent, @@ -147,9 +153,13 @@ async def _remove_flows_action(self, state: "State", **args: dict) -> None: def _init_flow_configs(self) -> None: """Initializes the flow configs based on the config.""" - self.flow_configs = create_flow_configs_from_flow_list(self.config.flows) + self.flow_configs = create_flow_configs_from_flow_list( + cast(List[Flow], self.config.flows) + ) - async def generate_events(self, events: List[dict]) -> List[dict]: + async def generate_events( + self, events: List[dict], processing_log: Optional[List[dict]] = None + ) -> List[dict]: raise NotImplementedError("Stateless API not supported for Colang 2.x, yet.") @staticmethod @@ -173,24 +183,22 @@ def _internal_error_action_result(message: str) -> ActionResult: async def _process_start_action( self, - action_name: str, - action_params: dict, + action: Action, context: dict, - events: List[dict], state: "State", ) -> Tuple[Any, List[dict], dict]: """Starts the specified action, waits for it to finish and posts back the result.""" - fn = self.action_dispatcher.get_action(action_name) + fn = self.action_dispatcher.get_action(action.name) # TODO: check action is available in action server if fn is None: result = self._internal_error_action_result( - f"Action '{action_name}' not found." + f"Action '{action.name}' not found." ) else: # We pass all the parameters that are passed explicitly to the action. - kwargs = {**action_params} + kwargs = {**action.start_event_arguments} action_meta = getattr(fn, "action_meta", {}) @@ -228,13 +236,21 @@ async def _process_start_action( and action_type != "chain" ): result, status = await self._get_action_resp( - action_meta, action_name, kwargs + action_meta, action.name, kwargs ) else: # We don't send these to the actions server; # TODO: determine if we should if "events" in parameters: - kwargs["events"] = events + kwargs["events"] = state.last_events + + if "event_generator" in parameters: + kwargs["event_generator"] = ActionEventGenerator( + action, self._async_action_events + ) + + if "action" in parameters: + kwargs["action"] = action if "context" in parameters: kwargs["context"] = context @@ -255,13 +271,13 @@ async def _process_start_action( if ( "llm" in kwargs - and f"{action_name}_llm" in self.registered_action_params + and f"{action.name}_llm" in self.registered_action_params ): - kwargs["llm"] = self.registered_action_params[f"{action_name}_llm"] + kwargs["llm"] = self.registered_action_params[f"{action.name}_llm"] - log.info("Running action :: %s", action_name) + log.info("Running action :: %s", action.name) result, status = await self.action_dispatcher.execute_action( - action_name, kwargs + action.name, kwargs ) # If the action execution failed, we return a hardcoded message @@ -271,7 +287,7 @@ async def _process_start_action( "I'm sorry, an internal error has occurred." ) - return_value = result + return_value: Any = result return_events: List[dict] = [] context_updates: dict = {} @@ -317,10 +333,10 @@ async def _get_action_resp( f"Got status code {resp.status} while getting response from {action_name}" ) - resp = await resp.json() + json_resp = await resp.json() result, status = ( - resp.get("result", result), - resp.get("status", status), + json_resp.get("result", result), + json_resp.get("status", status), ) except Exception as e: log.info( @@ -337,20 +353,42 @@ async def _get_action_resp( return result, status @staticmethod - def _get_action_finished_event(result: dict, **kwargs) -> Dict[str, Any]: - """Helper to return the ActionFinished event from the result of running a local action.""" - return new_event_dict( - f"{result['action_name']}Finished", - action_uid=result["start_action_event"]["action_uid"], - action_name=result["action_name"], - status="success", - is_success=True, - return_value=result["return_value"], - events=result["new_events"], - **kwargs, - # is_system_action=action_meta.get("is_system_action", False), + def _get_action_finished_event( + action: Action, + **kwargs, + ) -> Dict[str, Any]: + """Helper to augment the ActionFinished event with additional data.""" + if "return_value" not in kwargs: + kwargs["return_value"] = None + if "events" not in kwargs: + kwargs["events"] = [] + event = action.finished_event( + { + "action_name": action.name, + "status": "success", + "is_success": True, + **kwargs, + } ) + return event.to_umim_event() + + async def _get_async_action_events(self) -> List[dict]: + events = [] + while True: + try: + # Attempt to get an item from the queue without waiting + event = self._async_action_events.get_nowait() + + # assert isinstance(event, dict), "Python action events must be a dictionary!" + # {'type': 'CheckLocalAsync', 'uid': '486a5628-843b-4ed3-b8b3-315ef01c9a13', 'event_created_at': '2024-12-09T14:55:50.847199+00:00', 'source_uid': 'NeMoGuardrails'} + + events.append(event) + except asyncio.QueueEmpty: + # Break the loop if the queue is empty + break + return events + async def _get_async_actions_finished_events( self, main_flow_uid: str ) -> Tuple[List[dict], int]: @@ -389,7 +427,7 @@ async def _get_async_actions_finished_events( self.async_actions[main_flow_uid].remove(finished_task) # We need to create the corresponding action finished event - action_finished_event = self._get_action_finished_event(result) + action_finished_event = self._get_action_finished_event(**result) action_finished_events.append(action_finished_event) return action_finished_events, len(pending) @@ -429,10 +467,12 @@ async def process_events( state. """ - output_events = [] - input_events: List[Union[dict, InternalEvent]] = events.copy() + output_events: List[Dict[str, Any]] = [] + input_events: List[Union[dict, InternalEvent]] = [] local_running_actions: List[asyncio.Task[dict]] = [] + input_events.extend(events) + if state is None or state == {}: state = State( flow_states={}, flow_configs=self.flow_configs, rails_config=self.config @@ -475,6 +515,9 @@ async def process_events( input_events.insert(0, input_event) idx += 1 + # Check if we have new async action events to add + input_events.extend(await self._get_async_action_events()) + # Check if we have new finished async local action events to add ( local_action_finished_events, @@ -493,7 +536,7 @@ async def process_events( events_counter += 1 if events_counter > self.max_events: log.critical( - f"Maximum number of events reached ({events_counter})!" + "Maximum number of events reached (%s)!", events_counter ) return output_events, state @@ -541,35 +584,28 @@ async def process_events( # We need to check if we need to run a locally registered action start_action_match = re.match(r"Start(.*Action)", out_event["type"]) if start_action_match: - action_name = start_action_match[1] + action_event = ActionEvent.from_umim_event(out_event) + action = Action.from_event(action_event) + assert action # If it's an instant action, we finish it right away. - if instant_actions and action_name in instant_actions: - finished_event_data: dict = { - "action_name": action_name, - "start_action_event": out_event, - "return_value": None, - "new_events": [], - } - - # TODO: figure out a generic way of creating a compliant - # ...ActionFinished event - extra = {} - if action_name == "UtteranceBotAction": + if instant_actions and action.name in instant_actions: + extra = {"action": action} + if action.name == "UtteranceBotAction": extra["final_script"] = out_event["script"] action_finished_event = self._get_action_finished_event( - finished_event_data, **extra + **extra ) # We send the completion of the action as an output event # and continue processing it. output_events.append(action_finished_event) - input_events.append(action_finished_event) + new_outgoing_events.append(action_finished_event) - elif self.action_dispatcher.has_registered(action_name): + elif self.action_dispatcher.has_registered(action.name): # In this case we need to start the action locally - action_fn = self.action_dispatcher.get_action(action_name) + action_fn = self.action_dispatcher.get_action(action.name) execute_async = getattr(action_fn, "action_meta", {}).get( "execute_async", False ) @@ -577,13 +613,19 @@ async def process_events( # Start the local action local_action = asyncio.create_task( self._run_action( - action_name, - start_action_event=out_event, - events_history=state.last_events, + action, state=state, ) ) + # Generate *ActionStarted event + action_started_event = action.started_event({}) + action_started_umim_event = ( + action_started_event.to_umim_event() + ) + output_events.append(action_started_umim_event) + new_outgoing_events.append(action_started_umim_event) + # If the function is not async, or async execution is disabled # we execute the actions as a local action. # Also, if we're running this in blocking mode, we add all local @@ -604,6 +646,9 @@ async def process_events( else: output_events.append(out_event) + # Check if we have new async action events to add + new_outgoing_events.extend(await self._get_async_action_events()) + # Check if we have new finished async local action events to add ( new_local_action_finished_events, @@ -624,6 +669,7 @@ async def process_events( # If we have any local running actions, we need to wait for at least one # of them to finish. + # TODO: Check why we should wait for one to finish?! if local_running_actions: log.info( "Waiting for %d local actions to finish.", @@ -639,7 +685,7 @@ async def process_events( result = finished_task.result() # We need to create the corresponding action finished event - action_finished_event = self._get_action_finished_event(result) + action_finished_event = self._get_action_finished_event(**result) input_events.append(action_finished_event) if return_local_async_action_count: @@ -663,42 +709,29 @@ async def process_events( async def _run_action( self, - action_name: str, - start_action_event: dict, - events_history: List[Union[dict, Event]], + action: Action, state: "State", ) -> dict: """Runs the locally registered action. Args - action_name: The name of the action to be executed. - start_action_event: The event that triggered the action. - events_history: The recent history of events that led to the action being triggered. + action: The action to be executed. + state: The state of the runtime. """ - # NOTE: To extract the actual parameters that should be passed to the local action, - # we ignore all the keys from "an empty event" of the same type. - ignore_keys = new_event_dict(start_action_event["type"]).keys() - action_params = { - k: v for k, v in start_action_event.items() if k not in ignore_keys - } - return_value, new_events, context_updates = await self._process_start_action( - action_name, - action_params=action_params, + action, context=state.context, - events=events_history, state=state, ) state.context.update(context_updates) return { - "action_name": action_name, + "action": action, "return_value": return_value, "new_events": new_events, "context_updates": context_updates, - "start_action_event": start_action_event, } diff --git a/nemoguardrails/colang/v2_x/runtime/statemachine.py b/nemoguardrails/colang/v2_x/runtime/statemachine.py index b864fc086..d6f84c18e 100644 --- a/nemoguardrails/colang/v2_x/runtime/statemachine.py +++ b/nemoguardrails/colang/v2_x/runtime/statemachine.py @@ -1830,7 +1830,7 @@ def _is_done_flow(flow_state: FlowState) -> bool: def _generate_umim_event(state: State, event: Event) -> Dict[str, Any]: - umim_event = create_umim_event(event, event.arguments, state.rails_config) + umim_event = event.to_umim_event(state.rails_config) state.outgoing_events.append(umim_event) log.info("[bold violet]<- Action[/]: %s", event) @@ -2386,23 +2386,6 @@ def create_internal_event( return event -def create_umim_event( - event: Event, event_args: Dict[str, Any], config: Optional[RailsConfig] -) -> Dict[str, Any]: - """Returns an outgoing UMIM event for the provided action data""" - new_event_args = dict(event_args) - new_event_args.setdefault( - "source_uid", config.event_source_uid if config else "NeMoGuardrails-Colang-2.x" - ) - if isinstance(event, ActionEvent) and event.action_uid is not None: - if "action_uid" in new_event_args: - event.action_uid = new_event_args["action_uid"] - del new_event_args["action_uid"] - return new_event_dict(event.name, action_uid=event.action_uid, **new_event_args) - else: - return new_event_dict(event.name, **new_event_args) - - def _get_eval_context(state: State, flow_state: FlowState) -> dict: context = flow_state.context.copy() # Link global variables From 9164e646f7b7b0fbe1b7d9392176a7df5e0ac2b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20Sch=C3=BCller?= Date: Tue, 10 Dec 2024 19:57:56 +0100 Subject: [PATCH 02/14] Work on action test --- .../with_custom_action_v2_x/actions.py | 16 ++++++++++++++-- .../with_custom_action_v2_x/config.co | 7 ++++++- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/tests/test_configs/with_custom_action_v2_x/actions.py b/tests/test_configs/with_custom_action_v2_x/actions.py index 092599c18..ae04f9cec 100644 --- a/tests/test_configs/with_custom_action_v2_x/actions.py +++ b/tests/test_configs/with_custom_action_v2_x/actions.py @@ -13,9 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio +from typing import Optional + from nemoguardrails.actions import action +from nemoguardrails.actions.action_dispatcher import ActionEventGenerator -@action(name="CustomTestAction") -async def custom_test(context: dict, param: int): +@action(name="CustomTestAction", is_system_action=True, execute_async=True) +async def custom_test(context: dict, param: int, event_generator: ActionEventGenerator): + for i in range(1, 5): + await asyncio.sleep(5) + await event_generator.send_action_update_event("Test", {f"value {i}": 10}) return param + context["value"] + + +@action(name="CustomActionWithUpdateEventsAction") +async def custom_action_with_update_events(context: Optional[dict] = None, **kwargs): + return True diff --git a/tests/test_configs/with_custom_action_v2_x/config.co b/tests/test_configs/with_custom_action_v2_x/config.co index 290ee977e..b3fd1f9dc 100644 --- a/tests/test_configs/with_custom_action_v2_x/config.co +++ b/tests/test_configs/with_custom_action_v2_x/config.co @@ -2,5 +2,10 @@ flow main global $value $value = 3 match UtteranceUserAction.Finished(final_transcript="start") - $sum = await CustomTestAction(param=5) + start CustomTestAction(param=5) as $action_ref + print $action_ref.uid + match $action_ref.Finished() as $event + $sum = $event.return_value + print $sum + start UtteranceBotAction(script="{$sum}") From 7165c7332c468b5925890a9bc6c3b68aea13706b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20Sch=C3=BCller?= Date: Wed, 11 Dec 2024 09:48:51 +0100 Subject: [PATCH 03/14] Fix event_source_uid handling --- nemoguardrails/actions/action_dispatcher.py | 17 ++++++++++++++--- nemoguardrails/colang/v2_x/runtime/flows.py | 13 ++++++------- nemoguardrails/colang/v2_x/runtime/runtime.py | 19 +++++++++++++------ .../colang/v2_x/runtime/statemachine.py | 4 +++- 4 files changed, 36 insertions(+), 17 deletions(-) diff --git a/nemoguardrails/actions/action_dispatcher.py b/nemoguardrails/actions/action_dispatcher.py index 1a8a179d7..c259dafef 100644 --- a/nemoguardrails/actions/action_dispatcher.py +++ b/nemoguardrails/actions/action_dispatcher.py @@ -21,7 +21,7 @@ import logging import os from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Dict, List, Optional, Tuple, Union from langchain.chains.base import Chain from langchain_core.runnables import Runnable @@ -30,6 +30,7 @@ from nemoguardrails.actions.llm.utils import LLMCallException from nemoguardrails.colang.v2_x.runtime.flows import Action from nemoguardrails.logging.callbacks import logging_callbacks +from nemoguardrails.rails.llm.config import RailsConfig log = logging.getLogger(__name__) @@ -37,7 +38,15 @@ class ActionEventGenerator: """Generator to emit event from async Python actions.""" - def __init__(self, action: Action, action_event_queue: asyncio.Queue[dict]): + def __init__( + self, + config: RailsConfig, + action: Action, + action_event_queue: asyncio.Queue[dict], + ): + # The LLMRails config + self._config = config + # The relevant action self._action = action @@ -49,7 +58,9 @@ async def send_action_update_event(self, event_name: str, args: dict) -> None: action_event = self._action.updated_event( {"event_parameter_name": event_name, **args} ) - await self._action_events_queue.put(action_event.to_umim_event()) + await self._action_events_queue.put( + action_event.to_umim_event(self._config.event_source_uid) + ) class ActionDispatcher: diff --git a/nemoguardrails/colang/v2_x/runtime/flows.py b/nemoguardrails/colang/v2_x/runtime/flows.py index 29f3121bf..7a480859c 100644 --- a/nemoguardrails/colang/v2_x/runtime/flows.py +++ b/nemoguardrails/colang/v2_x/runtime/flows.py @@ -18,7 +18,6 @@ from __future__ import annotations import logging -import os import time from collections import deque from dataclasses import dataclass, field @@ -45,7 +44,6 @@ FlowReturnMemberDef, ) from nemoguardrails.colang.v2_x.runtime.errors import ColangSyntaxError -from nemoguardrails.rails.llm.config import RailsConfig from nemoguardrails.utils import new_event_dict, new_readable_uuid, new_uuid log = logging.getLogger(__name__) @@ -121,12 +119,12 @@ def from_umim_event(cls, event: dict) -> Event: ) return new_event - def to_umim_event(self, config: Optional[RailsConfig] = None) -> Dict[str, Any]: + def to_umim_event(self, event_source_uid: Optional[str] = None) -> Dict[str, Any]: """Return a umim event dictionary.""" new_event_args = dict(self.arguments) new_event_args.setdefault( "source_uid", - config.event_source_uid if config else "NeMoGuardrails-Colang-2.x", + event_source_uid if event_source_uid else "NeMoGuardrails-Colang-2.x", ) return new_event_dict(self.name, **new_event_args) @@ -172,11 +170,12 @@ def from_umim_event(cls, event: dict) -> ActionEvent: new_event.action_uid = event["action_uid"] return new_event - def to_umim_event(self) -> Dict[str, Any]: + def to_umim_event(self, event_source_uid: Optional[str] = None) -> Dict[str, Any]: """Return a umim event dictionary.""" new_event_args = dict(self.arguments) - new_event_args["source_uid"] = ( - os.getenv("SOURCE_ID", None) or "NeMoGuardrails-Colang-2.x" + new_event_args.setdefault( + "source_uid", + event_source_uid if event_source_uid else "NeMoGuardrails-Colang-2.x", ) if self.action_uid: return new_event_dict( diff --git a/nemoguardrails/colang/v2_x/runtime/runtime.py b/nemoguardrails/colang/v2_x/runtime/runtime.py index bccfc6c1a..e2a1d6894 100644 --- a/nemoguardrails/colang/v2_x/runtime/runtime.py +++ b/nemoguardrails/colang/v2_x/runtime/runtime.py @@ -246,7 +246,7 @@ async def _process_start_action( if "event_generator" in parameters: kwargs["event_generator"] = ActionEventGenerator( - action, self._async_action_events + self.config, action, self._async_action_events ) if "action" in parameters: @@ -354,6 +354,7 @@ async def _get_action_resp( @staticmethod def _get_action_finished_event( + rails_config: RailsConfig, action: Action, **kwargs, ) -> Dict[str, Any]: @@ -371,7 +372,7 @@ def _get_action_finished_event( } ) - return event.to_umim_event() + return event.to_umim_event(rails_config.event_source_uid) async def _get_async_action_events(self) -> List[dict]: events = [] @@ -427,7 +428,9 @@ async def _get_async_actions_finished_events( self.async_actions[main_flow_uid].remove(finished_task) # We need to create the corresponding action finished event - action_finished_event = self._get_action_finished_event(**result) + action_finished_event = self._get_action_finished_event( + self.config, **result + ) action_finished_events.append(action_finished_event) return action_finished_events, len(pending) @@ -595,7 +598,7 @@ async def process_events( extra["final_script"] = out_event["script"] action_finished_event = self._get_action_finished_event( - **extra + self.config, **extra ) # We send the completion of the action as an output event @@ -621,7 +624,9 @@ async def process_events( # Generate *ActionStarted event action_started_event = action.started_event({}) action_started_umim_event = ( - action_started_event.to_umim_event() + action_started_event.to_umim_event( + self.config.event_source_uid + ) ) output_events.append(action_started_umim_event) new_outgoing_events.append(action_started_umim_event) @@ -685,7 +690,9 @@ async def process_events( result = finished_task.result() # We need to create the corresponding action finished event - action_finished_event = self._get_action_finished_event(**result) + action_finished_event = self._get_action_finished_event( + self.config, **result + ) input_events.append(action_finished_event) if return_local_async_action_count: diff --git a/nemoguardrails/colang/v2_x/runtime/statemachine.py b/nemoguardrails/colang/v2_x/runtime/statemachine.py index d6f84c18e..a55f59e32 100644 --- a/nemoguardrails/colang/v2_x/runtime/statemachine.py +++ b/nemoguardrails/colang/v2_x/runtime/statemachine.py @@ -1830,7 +1830,9 @@ def _is_done_flow(flow_state: FlowState) -> bool: def _generate_umim_event(state: State, event: Event) -> Dict[str, Any]: - umim_event = event.to_umim_event(state.rails_config) + umim_event = event.to_umim_event( + state.rails_config.event_source_uid if state and state.rails_config else None + ) state.outgoing_events.append(umim_event) log.info("[bold violet]<- Action[/]: %s", event) From 580ffcd3530216391a227ba822c78b9b567585aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20Sch=C3=BCller?= Date: Wed, 11 Dec 2024 11:27:41 +0100 Subject: [PATCH 04/14] Simplify runtime loop --- nemoguardrails/colang/v2_x/runtime/runtime.py | 49 +++++++------------ 1 file changed, 18 insertions(+), 31 deletions(-) diff --git a/nemoguardrails/colang/v2_x/runtime/runtime.py b/nemoguardrails/colang/v2_x/runtime/runtime.py index e2a1d6894..c74577c6d 100644 --- a/nemoguardrails/colang/v2_x/runtime/runtime.py +++ b/nemoguardrails/colang/v2_x/runtime/runtime.py @@ -524,7 +524,7 @@ async def process_events( # Check if we have new finished async local action events to add ( local_action_finished_events, - pending_local_async_action_counter, + pending_local_action_counter, ) = await self._get_async_actions_finished_events(main_flow_uid) input_events.extend(local_action_finished_events) local_action_finished_events = [] @@ -534,8 +534,9 @@ async def process_events( # we continue the processing. events_counter = 0 while input_events or local_running_actions: - new_outgoing_events = [] - for event in input_events: + while input_events: + event = input_events.pop(0) + events_counter += 1 if events_counter > self.max_events: log.critical( @@ -558,7 +559,7 @@ async def process_events( # Advance the state machine new_event: Optional[Union[dict, Event]] = event - while new_event is not None: + while new_event: try: run_to_completion(state, new_event) new_event = None @@ -573,13 +574,6 @@ async def process_events( ) await asyncio.sleep(0.001) - # If we have context updates after this event, we first add that. - # TODO: Check if this is still needed for e.g. stateless implementation - # if state.context_updates: - # output_events.append( - # new_event_dict("ContextUpdate", data=state.context_updates) - # ) - for out_event in state.outgoing_events: # We also record the out events in the recent history. state.last_events.append(out_event) @@ -603,8 +597,9 @@ async def process_events( # We send the completion of the action as an output event # and continue processing it. + # TODO: Why do we need an output event for that? It should only be an new input event output_events.append(action_finished_event) - new_outgoing_events.append(action_finished_event) + input_events.append(action_finished_event) elif self.action_dispatcher.has_registered(action.name): # In this case we need to start the action locally @@ -628,8 +623,8 @@ async def process_events( self.config.event_source_uid ) ) - output_events.append(action_started_umim_event) - new_outgoing_events.append(action_started_umim_event) + # output_events.append(action_started_umim_event) + input_events.append(action_started_umim_event) # If the function is not async, or async execution is disabled # we execute the actions as a local action. @@ -651,26 +646,18 @@ async def process_events( else: output_events.append(out_event) - # Check if we have new async action events to add - new_outgoing_events.extend(await self._get_async_action_events()) + # Add new async action events as new input events + input_events.extend(await self._get_async_action_events()) - # Check if we have new finished async local action events to add + # Add new finished async local action events as new input events ( - new_local_action_finished_events, - pending_local_async_action_counter, + new_action_finished_events, + pending_local_action_counter, ) = await self._get_async_actions_finished_events(main_flow_uid) - local_action_finished_events.extend(new_local_action_finished_events) - new_outgoing_events.extend(state.outgoing_events) - - input_events.clear() - - # If we have outgoing events we are also processing them as input events - if new_outgoing_events: - input_events.extend(new_outgoing_events) - continue + input_events.extend(new_action_finished_events) - input_events.extend(local_action_finished_events) - local_action_finished_events = [] + # Add generated events as new input events + input_events.extend(state.outgoing_events) # If we have any local running actions, we need to wait for at least one # of them to finish. @@ -703,7 +690,7 @@ async def process_events( ) output_events.append( new_event_dict( - "LocalAsyncCounter", counter=pending_local_async_action_counter + "LocalAsyncCounter", counter=pending_local_action_counter ) ) From 183f1fc9e9d95eb1080d90ae2daceb0a0c323e33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20Sch=C3=BCller?= Date: Wed, 11 Dec 2024 12:27:02 +0100 Subject: [PATCH 05/14] Improve ActionEventGenerator --- nemoguardrails/actions/action_dispatcher.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/nemoguardrails/actions/action_dispatcher.py b/nemoguardrails/actions/action_dispatcher.py index c259dafef..c6638fad9 100644 --- a/nemoguardrails/actions/action_dispatcher.py +++ b/nemoguardrails/actions/action_dispatcher.py @@ -28,7 +28,7 @@ from nemoguardrails import utils from nemoguardrails.actions.llm.utils import LLMCallException -from nemoguardrails.colang.v2_x.runtime.flows import Action +from nemoguardrails.colang.v2_x.runtime.flows import Action, Event from nemoguardrails.logging.callbacks import logging_callbacks from nemoguardrails.rails.llm.config import RailsConfig @@ -53,15 +53,22 @@ def __init__( # Contains reference to the async action event queue self._action_events_queue = action_event_queue - async def send_action_update_event(self, event_name: str, args: dict) -> None: - """Send a ActionUpdated event.""" + def send_action_updated_event(self, event_name: str, args: dict) -> None: + """Send an Action*Updated event.""" action_event = self._action.updated_event( {"event_parameter_name": event_name, **args} ) - await self._action_events_queue.put( + self._action_events_queue.put_nowait( action_event.to_umim_event(self._config.event_source_uid) ) + def send_raw_event(self, event_name: str, args: dict) -> None: + """Send any event.""" + event = Event(event_name, args) + self._action_events_queue.put_nowait( + event.to_umim_event(self._config.event_source_uid) + ) + class ActionDispatcher: def __init__( From 58269f65240e6624c9e2731a2a0cfd1b6198906f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20Sch=C3=BCller?= Date: Wed, 11 Dec 2024 14:55:49 +0100 Subject: [PATCH 06/14] Revert back custom async action test config --- .../with_custom_action_v2_x/actions.py | 16 ++-------------- .../with_custom_action_v2_x/config.co | 7 +------ 2 files changed, 3 insertions(+), 20 deletions(-) diff --git a/tests/test_configs/with_custom_action_v2_x/actions.py b/tests/test_configs/with_custom_action_v2_x/actions.py index ae04f9cec..092599c18 100644 --- a/tests/test_configs/with_custom_action_v2_x/actions.py +++ b/tests/test_configs/with_custom_action_v2_x/actions.py @@ -13,21 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio -from typing import Optional - from nemoguardrails.actions import action -from nemoguardrails.actions.action_dispatcher import ActionEventGenerator -@action(name="CustomTestAction", is_system_action=True, execute_async=True) -async def custom_test(context: dict, param: int, event_generator: ActionEventGenerator): - for i in range(1, 5): - await asyncio.sleep(5) - await event_generator.send_action_update_event("Test", {f"value {i}": 10}) +@action(name="CustomTestAction") +async def custom_test(context: dict, param: int): return param + context["value"] - - -@action(name="CustomActionWithUpdateEventsAction") -async def custom_action_with_update_events(context: Optional[dict] = None, **kwargs): - return True diff --git a/tests/test_configs/with_custom_action_v2_x/config.co b/tests/test_configs/with_custom_action_v2_x/config.co index b3fd1f9dc..290ee977e 100644 --- a/tests/test_configs/with_custom_action_v2_x/config.co +++ b/tests/test_configs/with_custom_action_v2_x/config.co @@ -2,10 +2,5 @@ flow main global $value $value = 3 match UtteranceUserAction.Finished(final_transcript="start") - start CustomTestAction(param=5) as $action_ref - print $action_ref.uid - match $action_ref.Finished() as $event - $sum = $event.return_value - print $sum - + $sum = await CustomTestAction(param=5) start UtteranceBotAction(script="{$sum}") From 5d9fe06469b10e61d8b0cb010e5bd5e74f6b3905 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20Sch=C3=BCller?= Date: Wed, 11 Dec 2024 14:57:12 +0100 Subject: [PATCH 07/14] Add new unit test for custom action with async events --- .../with_custom_async_action_v2_x/actions.py | 26 ++++ .../with_custom_async_action_v2_x/config.co | 9 ++ .../with_custom_async_action_v2_x/config.yml | 1 + tests/v2_x/chat.py | 4 +- tests/v2_x/test_run_actions.py | 48 +++++++- tests/v2_x/utils.py | 111 ++++++++++++++++++ 6 files changed, 193 insertions(+), 6 deletions(-) create mode 100644 tests/test_configs/with_custom_async_action_v2_x/actions.py create mode 100644 tests/test_configs/with_custom_async_action_v2_x/config.co create mode 100644 tests/test_configs/with_custom_async_action_v2_x/config.yml create mode 100644 tests/v2_x/utils.py diff --git a/tests/test_configs/with_custom_async_action_v2_x/actions.py b/tests/test_configs/with_custom_async_action_v2_x/actions.py new file mode 100644 index 000000000..dd8b17ee0 --- /dev/null +++ b/tests/test_configs/with_custom_async_action_v2_x/actions.py @@ -0,0 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio + +from nemoguardrails.actions import action +from nemoguardrails.actions.action_dispatcher import ActionEventGenerator + + +@action(name="CustomAsyncTestAction", is_system_action=True, execute_async=True) +async def custom_async_test(event_generator: ActionEventGenerator): + for i in range(1, 5): + await asyncio.sleep(1) + event_generator.send_action_updated_event("Value", {"number": i}) diff --git a/tests/test_configs/with_custom_async_action_v2_x/config.co b/tests/test_configs/with_custom_async_action_v2_x/config.co new file mode 100644 index 000000000..7dbacc83b --- /dev/null +++ b/tests/test_configs/with_custom_async_action_v2_x/config.co @@ -0,0 +1,9 @@ +flow main + match UtteranceUserAction.Finished(final_transcript="start") + start CustomAsyncTestAction() as $action_ref + while True: + when $action_ref.ValueUpdated() as $ref: + await UtteranceBotAction(script="Value: {$ref.number}") + or when $action_ref.Finished() as $event: + await UtteranceBotAction(script="End") + break diff --git a/tests/test_configs/with_custom_async_action_v2_x/config.yml b/tests/test_configs/with_custom_async_action_v2_x/config.yml new file mode 100644 index 000000000..ae6347623 --- /dev/null +++ b/tests/test_configs/with_custom_async_action_v2_x/config.yml @@ -0,0 +1 @@ +colang_version: "2.x" diff --git a/tests/v2_x/chat.py b/tests/v2_x/chat.py index 7cdc91d15..4e52b4599 100644 --- a/tests/v2_x/chat.py +++ b/tests/v2_x/chat.py @@ -19,7 +19,7 @@ from typing import Dict, List, Optional import nemoguardrails.rails.llm.llmrails -from nemoguardrails import LLMRails, RailsConfig +from nemoguardrails import LLMRails from nemoguardrails.cli.chat import extract_scene_text_content, parse_events_inputs from nemoguardrails.colang.v2_x.runtime.flows import State from nemoguardrails.utils import new_event_dict, new_uuid @@ -44,7 +44,7 @@ class ChatInterface: def __init__(self, rails_app: LLMRails): self.chat_state = ChatState() self.rails_app = rails_app - self.input_queue = asyncio.Queue() + self.input_queue: asyncio.Queue = asyncio.Queue() self.loop = asyncio.get_event_loop() asyncio.create_task(self.run()) diff --git a/tests/v2_x/test_run_actions.py b/tests/v2_x/test_run_actions.py index 914c56d61..45315e8c0 100644 --- a/tests/v2_x/test_run_actions.py +++ b/tests/v2_x/test_run_actions.py @@ -14,11 +14,14 @@ # limitations under the License. import logging +import os from rich.logging import RichHandler from nemoguardrails import RailsConfig +from nemoguardrails.utils import get_or_create_event_loop from tests.utils import TestChat +from tests.v2_x.utils import compare_interaction_with_test_script FORMAT = "%(message)s" logging.basicConfig( @@ -28,8 +31,10 @@ handlers=[RichHandler(markup=True)], ) +CONFIGS_FOLDER = os.path.join(os.path.dirname(__file__), "../test_configs") -def test_1(): + +def test_basic_statement(): config = RailsConfig.from_content( colang_content=""" flow user express greeting @@ -62,7 +67,7 @@ async def fetch_name(): chat << "Hello world!" -def test_2(): +def test_short_return_statement(): config = RailsConfig.from_content( colang_content=""" flow user express greeting @@ -95,7 +100,7 @@ async def fetch_name(): chat << "John" -def test_3(): +def test_long_return_statement(): config = RailsConfig.from_content( colang_content=""" flow bot say $text @@ -129,5 +134,40 @@ async def fetch_dictionary(): chat << "I couldn't find any items matching your request!" +def test_custom_action(): + # This config just imports another one, to check that actions are correctly + # loaded. + config = RailsConfig.from_path( + os.path.join(CONFIGS_FOLDER, "with_custom_action_v2_x") + ) + + chat = TestChat( + config, + llm_completions=[], + ) + + chat >> "start" + chat << "8" + + +def test_custom_action_generating_async_events(): + path = os.path.join(CONFIGS_FOLDER, "with_custom_async_action_v2_x") + test_script = """ + > start + Value: 1 + Value: 2 + Value: 3 + Value: 4 + End + """ + + loop = get_or_create_event_loop() + result = loop.run_until_complete( + compare_interaction_with_test_script(test_script, 10.0, colang_path=path) + ) + + assert result is None, result + + if __name__ == "__main__": - test_3() + test_custom_action_generating_async_events() diff --git a/tests/v2_x/utils.py b/tests/v2_x/utils.py new file mode 100644 index 000000000..213536e40 --- /dev/null +++ b/tests/v2_x/utils.py @@ -0,0 +1,111 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +# TODO: This is a copy of docs/colang-2/examples/utils.py, we should unify +# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + + +import asyncio +from typing import Optional + +from nemoguardrails.rails.llm.config import RailsConfig +from nemoguardrails.rails.llm.llmrails import LLMRails +from tests.utils import FakeLLM +from tests.v2_x.chat import ChatInterface + +YAML_CONFIG = """ +colang_version: "2.x" +""" + + +async def run_chat_interface_based_on_test_script( + test_script: str, + wait_time_s: float, + colang: Optional[str] = None, + colang_path: Optional[str] = None, + llm_responses: Optional[list] = None, +) -> str: + rails_config: RailsConfig + if colang: + rails_config = RailsConfig.from_content( + colang_content=colang, + yaml_content=YAML_CONFIG, + ) + elif colang_path: + rails_config = RailsConfig.from_path(colang_path) + + interaction_log = [] + + if llm_responses: + llm = FakeLLM(responses=llm_responses) + rails_app = LLMRails(rails_config, verbose=True, llm=llm) + else: + rails_app = LLMRails(rails_config, verbose=True) + + chat = ChatInterface(rails_app) + + lines = test_script.split("\n") + for line in lines: + line = line.strip() + if line.startswith("#"): + continue + if line.startswith(">"): + interaction_log.append(line) + user_input = line.replace("> ", "") + print(f"sending '{user_input}' to process") + response = await chat.process(user_input, wait_time_s) + interaction_log.append(response) + + chat.should_terminate = True + await asyncio.sleep(0.5) + + return "\n".join(interaction_log) + + +def cleanup(content): + output = [] + lines = content.split("\n") + for line in lines: + if len(line.strip()) == 0: + continue + if line.strip() == ">": + continue + if line.startswith("#"): + continue + if "Starting the chat" in line: + continue + + output.append(line.strip()) + + return "\n".join(output) + + +async def compare_interaction_with_test_script( + test_script: str, + wait_time_s: float = 1.0, + colang: Optional[str] = None, + colang_path: Optional[str] = None, + llm_responses: Optional[list] = None, +) -> Optional[str]: + result = await run_chat_interface_based_on_test_script( + test_script, wait_time_s, colang, colang_path, llm_responses=llm_responses + ) + clean_test_script = cleanup(test_script) + clean_result = cleanup(result) + if clean_test_script == clean_result: + return None + + return f"\n----\n{clean_result}\n----\n\ndoes not match test script\n\n----\n{clean_test_script}\n----" From 7fb8c208e40fead2d9d3dd9e2de3b78a5eb86e6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20Sch=C3=BCller?= Date: Wed, 11 Dec 2024 15:34:19 +0100 Subject: [PATCH 08/14] Fix test --- tests/test_configs/with_custom_async_action_v2_x/config.co | 4 ++-- tests/v2_x/test_run_actions.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_configs/with_custom_async_action_v2_x/config.co b/tests/test_configs/with_custom_async_action_v2_x/config.co index 7dbacc83b..d95056037 100644 --- a/tests/test_configs/with_custom_async_action_v2_x/config.co +++ b/tests/test_configs/with_custom_async_action_v2_x/config.co @@ -3,7 +3,7 @@ flow main start CustomAsyncTestAction() as $action_ref while True: when $action_ref.ValueUpdated() as $ref: - await UtteranceBotAction(script="Value: {$ref.number}") + start UtteranceBotAction(script="Value: {$ref.number}") or when $action_ref.Finished() as $event: - await UtteranceBotAction(script="End") + start UtteranceBotAction(script="End") break diff --git a/tests/v2_x/test_run_actions.py b/tests/v2_x/test_run_actions.py index 45315e8c0..f7f79920f 100644 --- a/tests/v2_x/test_run_actions.py +++ b/tests/v2_x/test_run_actions.py @@ -159,6 +159,8 @@ def test_custom_action_generating_async_events(): Value: 3 Value: 4 End + Event: StopUtteranceBotAction + Event: StopUtteranceBotAction """ loop = get_or_create_event_loop() From 4545639d0021b80b4777e15d71c38b8344b298a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20Sch=C3=BCller?= Date: Wed, 11 Dec 2024 15:39:45 +0100 Subject: [PATCH 09/14] Extend test with raw event --- tests/test_configs/with_custom_async_action_v2_x/actions.py | 3 +++ tests/test_configs/with_custom_async_action_v2_x/config.co | 2 ++ tests/v2_x/test_run_actions.py | 2 +- 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/test_configs/with_custom_async_action_v2_x/actions.py b/tests/test_configs/with_custom_async_action_v2_x/actions.py index dd8b17ee0..52ceffb07 100644 --- a/tests/test_configs/with_custom_async_action_v2_x/actions.py +++ b/tests/test_configs/with_custom_async_action_v2_x/actions.py @@ -24,3 +24,6 @@ async def custom_async_test(event_generator: ActionEventGenerator): for i in range(1, 5): await asyncio.sleep(1) event_generator.send_action_updated_event("Value", {"number": i}) + await asyncio.sleep(1) + event_generator.send_raw_event("NewCustomUmimEvent", {"secret": "xyz"}) + await asyncio.sleep(1) diff --git a/tests/test_configs/with_custom_async_action_v2_x/config.co b/tests/test_configs/with_custom_async_action_v2_x/config.co index d95056037..354c5c563 100644 --- a/tests/test_configs/with_custom_async_action_v2_x/config.co +++ b/tests/test_configs/with_custom_async_action_v2_x/config.co @@ -4,6 +4,8 @@ flow main while True: when $action_ref.ValueUpdated() as $ref: start UtteranceBotAction(script="Value: {$ref.number}") + or when NewCustomUmimEvent() as $ref: + start UtteranceBotAction(script="Secret: {$ref.secret}") or when $action_ref.Finished() as $event: start UtteranceBotAction(script="End") break diff --git a/tests/v2_x/test_run_actions.py b/tests/v2_x/test_run_actions.py index f7f79920f..758f85f2a 100644 --- a/tests/v2_x/test_run_actions.py +++ b/tests/v2_x/test_run_actions.py @@ -158,9 +158,9 @@ def test_custom_action_generating_async_events(): Value: 2 Value: 3 Value: 4 + Secret: xyz End Event: StopUtteranceBotAction - Event: StopUtteranceBotAction """ loop = get_or_create_event_loop() From 65b71c3f10c6fb801c11a387d219bcdfc627d2b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20Sch=C3=BCller?= Date: Wed, 11 Dec 2024 15:54:05 +0100 Subject: [PATCH 10/14] Enable chat cli user input during async actions --- nemoguardrails/cli/chat.py | 4 +++- tests/v2_x/chat.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/nemoguardrails/cli/chat.py b/nemoguardrails/cli/chat.py index c48997319..dd9f2bf9b 100644 --- a/nemoguardrails/cli/chat.py +++ b/nemoguardrails/cli/chat.py @@ -498,7 +498,9 @@ async def _process_input_events(): chat_state.input_events = [] else: chat_state.waiting_user_input = True - await enable_input.wait() + # NOTE: We should never disable the user input since we can have + # async Python actions running in parallel + # await enable_input.wait() user_message: str = await chat_state.session.prompt_async( HTML("\n> "), diff --git a/tests/v2_x/chat.py b/tests/v2_x/chat.py index 4e52b4599..f46f91221 100644 --- a/tests/v2_x/chat.py +++ b/tests/v2_x/chat.py @@ -322,7 +322,9 @@ async def run(self): self.chat_state.input_events = [] else: self.chat_state.waiting_user_input = True - await self.enable_input.wait() + # NOTE: We should never disable the user input since we can have + # async Python actions running in parallel + # await enable_input.wait() user_message = "" if not self.input_queue.empty(): From d537507f45096d0e2e799bff6e1aa4f6cfc352d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20Sch=C3=BCller?= Date: Mon, 16 Dec 2024 14:00:55 +0100 Subject: [PATCH 11/14] Refactor event handling in runtime and many related things --- nemoguardrails/actions/action_dispatcher.py | 55 +-- nemoguardrails/colang/runtime.py | 6 +- nemoguardrails/colang/v1_0/runtime/runtime.py | 14 +- nemoguardrails/colang/v2_x/runtime/flows.py | 8 +- nemoguardrails/colang/v2_x/runtime/runtime.py | 384 +++++++++++++----- nemoguardrails/logging/verbose.py | 4 + nemoguardrails/rails/llm/llmrails.py | 41 +- .../actions.py | 42 ++ .../config.co | 16 + .../config.yml | 0 .../actions.py | 21 +- .../config.co | 21 + .../config.yml | 1 + .../with_custom_async_action_v2_x/config.co | 11 - tests/test_embeddings_only_user_messages.py | 4 +- tests/test_retrieve_relevant_chunks.py | 12 +- tests/v2_x/chat.py | 3 +- tests/v2_x/test_run_actions.py | 11 +- 18 files changed, 447 insertions(+), 207 deletions(-) create mode 100644 tests/test_configs/with_custom_async_action_events_v2_x/actions.py create mode 100644 tests/test_configs/with_custom_async_action_events_v2_x/config.co rename tests/test_configs/{with_custom_async_action_v2_x => with_custom_async_action_events_v2_x}/config.yml (100%) rename tests/test_configs/{with_custom_async_action_v2_x => with_custom_async_action_lifetime_v2_x}/actions.py (59%) create mode 100644 tests/test_configs/with_custom_async_action_lifetime_v2_x/config.co create mode 100644 tests/test_configs/with_custom_async_action_lifetime_v2_x/config.yml delete mode 100644 tests/test_configs/with_custom_async_action_v2_x/config.co diff --git a/nemoguardrails/actions/action_dispatcher.py b/nemoguardrails/actions/action_dispatcher.py index c6638fad9..5e4a09386 100644 --- a/nemoguardrails/actions/action_dispatcher.py +++ b/nemoguardrails/actions/action_dispatcher.py @@ -15,7 +15,6 @@ """Module for the calling proper action endpoints based on events received at action server endpoint""" -import asyncio import importlib.util import inspect import logging @@ -28,55 +27,17 @@ from nemoguardrails import utils from nemoguardrails.actions.llm.utils import LLMCallException -from nemoguardrails.colang.v2_x.runtime.flows import Action, Event from nemoguardrails.logging.callbacks import logging_callbacks -from nemoguardrails.rails.llm.config import RailsConfig log = logging.getLogger(__name__) -class ActionEventGenerator: - """Generator to emit event from async Python actions.""" - - def __init__( - self, - config: RailsConfig, - action: Action, - action_event_queue: asyncio.Queue[dict], - ): - # The LLMRails config - self._config = config - - # The relevant action - self._action = action - - # Contains reference to the async action event queue - self._action_events_queue = action_event_queue - - def send_action_updated_event(self, event_name: str, args: dict) -> None: - """Send an Action*Updated event.""" - action_event = self._action.updated_event( - {"event_parameter_name": event_name, **args} - ) - self._action_events_queue.put_nowait( - action_event.to_umim_event(self._config.event_source_uid) - ) - - def send_raw_event(self, event_name: str, args: dict) -> None: - """Send any event.""" - event = Event(event_name, args) - self._action_events_queue.put_nowait( - event.to_umim_event(self._config.event_source_uid) - ) - - class ActionDispatcher: def __init__( self, load_all_actions: bool = True, config_path: Optional[str] = None, import_paths: Optional[List[str]] = None, - action_event_queue: Optional[asyncio.Queue[dict]] = None, ): """ Initializes an actions dispatcher. @@ -93,9 +54,6 @@ def __init__( # Dictionary with all registered actions self._registered_actions: dict = {} - # Contains generated events form async actions - self._async_action_events: Optional[asyncio.Queue[dict]] = action_event_queue - if load_all_actions: # TODO: check for better way to find actions dir path or use constants.py current_file_path = Path(__file__).resolve() @@ -224,7 +182,7 @@ def get_action(self, name: str) -> callable: async def execute_action( self, action_name: str, params: Dict[str, Any] - ) -> Tuple[Union[str, Dict[str, Any]], str]: + ) -> Tuple[Optional[Union[str, Dict[str, Any]]], str]: """Execute a registered action. Args: @@ -299,15 +257,12 @@ async def execute_action( filtered_params = { k: v for k, v in params.items() - if k not in ["state", "events", "llm"] + if k not in ["state", "events", "llm", "event_handler"] } - log.warning( - "Error while execution '%s' with parameters '%s': %s", - action_name, - filtered_params, - e, + msg = ( + f"Exception while execution '{action_name}' with parameters '{filtered_params}'", ) - log.exception(e) + raise Exception(f"{msg}: {e}") from e return None, "failed" diff --git a/nemoguardrails/colang/runtime.py b/nemoguardrails/colang/runtime.py index 71db5915f..98c5b2b93 100644 --- a/nemoguardrails/colang/runtime.py +++ b/nemoguardrails/colang/runtime.py @@ -32,14 +32,10 @@ def __init__(self, config: RailsConfig, verbose: bool = False): self.config = config self.verbose = verbose - # Contains generated events form async Python actions that should be processed - self._async_action_events: asyncio.Queue[dict] = asyncio.Queue() - # Register the actions with the dispatcher. self.action_dispatcher = ActionDispatcher( config_path=config.config_path, import_paths=list(config.imported_paths.values()), - action_event_queue=self._async_action_events, ) # The list of additional parameters that can be passed to the actions. @@ -52,7 +48,7 @@ def __init__(self, config: RailsConfig, verbose: bool = False): # A set of watchers that are notified every time an event is processed. # Used mainly for reporting the progress to the CLI. - self.watchers = [] + self.watchers: List = [] # The maximum number of events to be processed in a processing loop self.max_events = 500 diff --git a/nemoguardrails/colang/v1_0/runtime/runtime.py b/nemoguardrails/colang/v1_0/runtime/runtime.py index 56fa00efc..e3d426706 100644 --- a/nemoguardrails/colang/v1_0/runtime/runtime.py +++ b/nemoguardrails/colang/v1_0/runtime/runtime.py @@ -25,6 +25,7 @@ from langchain.chains.base import Chain from nemoguardrails.actions.actions import ActionResult +from nemoguardrails.actions.llm.utils import LLMCallException from nemoguardrails.colang import parse_colang_file from nemoguardrails.colang.runtime import Runtime from nemoguardrails.colang.v1_0.runtime.flows import ( @@ -360,9 +361,16 @@ async def _process_start_action(self, events: List[dict]) -> List[dict]: kwargs["llm"] = self.registered_action_params[f"{action_name}_llm"] log.info("Executing action :: %s", action_name) - result, status = await self.action_dispatcher.execute_action( - action_name, kwargs - ) + try: + result, status = await self.action_dispatcher.execute_action( + action_name, kwargs + ) + except LLMCallException as e: + raise e + except Exception as e: + result = None + status = "failed" + log.exception(e) # If the action execution failed, we return a hardcoded message if status == "failed": diff --git a/nemoguardrails/colang/v2_x/runtime/flows.py b/nemoguardrails/colang/v2_x/runtime/flows.py index 7a480859c..88847eecc 100644 --- a/nemoguardrails/colang/v2_x/runtime/flows.py +++ b/nemoguardrails/colang/v2_x/runtime/flows.py @@ -228,7 +228,13 @@ def from_event(cls, event: ActionEvent) -> Optional[Action]: assert event.action_uid is not None for name in cls._event_name_map: if name in event.name: - action = Action(event.name.replace(name, ""), {}) + action_name: str + if name == "Updated": + index = event.name.find("Action") + 6 + action_name = event.name[:index] + else: + action_name = event.name.replace(name, "") + action = Action(action_name, {}) action.uid = event.action_uid action.status = ( ActionStatus.STARTED diff --git a/nemoguardrails/colang/v2_x/runtime/runtime.py b/nemoguardrails/colang/v2_x/runtime/runtime.py index c74577c6d..c43522432 100644 --- a/nemoguardrails/colang/v2_x/runtime/runtime.py +++ b/nemoguardrails/colang/v2_x/runtime/runtime.py @@ -16,14 +16,14 @@ import inspect import logging import re -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast from urllib.parse import urljoin import aiohttp import langchain from langchain.chains.base import Chain -from nemoguardrails.actions.action_dispatcher import ActionEventGenerator from nemoguardrails.actions.actions import ActionResult from nemoguardrails.colang import parse_colang_file from nemoguardrails.colang.runtime import Runtime @@ -56,23 +56,132 @@ log = logging.getLogger(__name__) +class ActionEventHandler: + """Handles input and output events to Python actions.""" + + _lock = asyncio.Lock() + + def __init__( + self, + config: RailsConfig, + action: Action, + event_input_queue: asyncio.Queue[dict], + event_output_queue: asyncio.Queue[dict], + ): + # The LLMRails config + self._config = config + + # The relevant action + self._action = action + + # Action specific action event queue for event receiving + self._event_input_queue = event_input_queue + + # Shared async action event queue for event sending + self._event_output_queue = event_output_queue + + def send_action_updated_event( + self, event_name: str, args: Optional[dict] = None + ) -> None: + """ + Send an Action*Updated event. + + Args: + event_name (str): The name of the action event + args (Optional[dict]): An optional dictionary with the event arguments + """ + + if args: + args = {"event_parameter_name": event_name, **args} + else: + args = {"event_parameter_name": event_name} + action_event = self._action.updated_event(args) + self._event_output_queue.put_nowait( + action_event.to_umim_event(self._config.event_source_uid) + ) + + def send_event(self, event_name: str, args: Optional[dict] = None) -> None: + """ + Send any event. + + Args: + event_name (str): The event name + args (Optional[dict]): An optional dictionary with the event arguments + """ + event = Event(event_name, args if args else {}) + self._event_output_queue.put_nowait( + event.to_umim_event(self._config.event_source_uid) + ) + + async def wait_for_events( + self, event_name: Optional[str] = None, timeout: Optional[float] = None + ) -> List[dict]: + """ + Waits for new input events to process. + + Args: + event_name (Optional[str]): Optional event name to filter for, if None all events will be received + timeout (Optional[float]): The time to wait for new events before it continues + """ + events: List[dict] = [] + keep_waiting = True + while keep_waiting: + try: + # Wait for new events + event = await asyncio.wait_for(self._event_input_queue.get(), timeout) + # Gather all new events + while True: + if event_name is None or event["type"] == event_name: + events.append(event) + event = self._event_input_queue.get_nowait() + except asyncio.QueueEmpty: + self._event_input_queue.task_done() + keep_waiting = len(events) == 0 + except asyncio.TimeoutError: + # Timeout occurred, stop consuming + keep_waiting = False + return events + + +@dataclass +class LocalActionData: + """Structure to help organize action related data.""" + + # All active async action task + task: asyncio.Task + # The action's output event queue + input_event_queues: asyncio.Queue[dict] = field( + default_factory=lambda: asyncio.Queue() + ) + + +@dataclass +class LocalActionGroup: + """Structure to help organize all local actions related to a certain main flow.""" + + # Action uid ordered action data + action_data: Dict[str, LocalActionData] = field(default_factory=dict) + + # A single output event queue for all actions + output_event_queue: asyncio.Queue = field(default_factory=lambda: asyncio.Queue()) + + class RuntimeV2_x(Runtime): """Runtime for executing the guardrails.""" def __init__(self, config: RailsConfig, verbose: bool = False): super().__init__(config, verbose) - # Register local system actions - self.register_action(self._add_flows_action, "AddFlowsAction", False) - self.register_action(self._remove_flows_action, "RemoveFlowsAction", False) - - # Maps main_flow.uid to a dictionary of actions that are run locally, asynchronously. - # Dict[main_flow_uid, Dict[action_uid, action_data]] - self.async_actions: Dict[str, List] = {} + # Maps main_flow.uid to a list of action group data that contains all the locally running actions. + self.local_actions: Dict[str, LocalActionGroup] = {} # A way to disable async function execution. Useful for testing. self.disable_async_execution = False + # Register local system actions + self.register_action(self._add_flows_action, "AddFlowsAction", False) + self.register_action(self._remove_flows_action, "RemoveFlowsAction", False) + async def _add_flows_action(self, state: "State", **args: dict) -> List[str]: log.info("Start AddFlowsAction! %s", args) flow_content = args["config"] @@ -165,6 +274,7 @@ async def generate_events( @staticmethod def _internal_error_action_result(message: str) -> ActionResult: """Helper to construct an action result for an internal error.""" + # TODO: We should handle this as an ActionFinished(is_success=False) event and not generate custom other events return ActionResult( events=[ { @@ -189,13 +299,15 @@ async def _process_start_action( ) -> Tuple[Any, List[dict], dict]: """Starts the specified action, waits for it to finish and posts back the result.""" + return_value: Any = None + return_events: List[dict] = [] + context_updates: dict = {} + fn = self.action_dispatcher.get_action(action.name) # TODO: check action is available in action server if fn is None: - result = self._internal_error_action_result( - f"Action '{action.name}' not found." - ) + raise ColangRuntimeError(f"Action '{action.name}' not found.") else: # We pass all the parameters that are passed explicitly to the action. kwargs = {**action.start_event_arguments} @@ -244,9 +356,16 @@ async def _process_start_action( if "events" in parameters: kwargs["events"] = state.last_events - if "event_generator" in parameters: - kwargs["event_generator"] = ActionEventGenerator( - self.config, action, self._async_action_events + if "event_handler" in parameters: + kwargs["event_handler"] = ActionEventHandler( + self.config, + action, + self.local_actions[state.main_flow_state.uid] + .action_data[action.uid] + .input_event_queues, + self.local_actions[ + state.main_flow_state.uid + ].output_event_queue, ) if "action" in parameters: @@ -282,14 +401,20 @@ async def _process_start_action( # If the action execution failed, we return a hardcoded message if status == "failed": - # TODO: make this message configurable. - result = self._internal_error_action_result( - "I'm sorry, an internal error has occurred." + action_finished_event = self._get_action_finished_event( + self.config, + action, + status="failed", + is_success=False, + failure_reason="Local action finished with an exception!", ) + return_events.append(action_finished_event) - return_value: Any = result - return_events: List[dict] = [] - context_updates: dict = {} + # result = self._internal_error_action_result( + # "I'm sorry, an internal error has occurred." + # ) + + return_value = result if isinstance(result, ActionResult): return_value = result.return_value @@ -298,9 +423,6 @@ async def _process_start_action( if result.context_updates is not None: context_updates.update(result.context_updates) - # if return_events: - # next_steps.extend(return_events) - return return_value, return_events, context_updates async def _get_action_resp( @@ -374,15 +496,14 @@ def _get_action_finished_event( return event.to_umim_event(rails_config.event_source_uid) - async def _get_async_action_events(self) -> List[dict]: + async def _get_async_action_events(self, main_flow_uid: str) -> List[dict]: events = [] while True: try: # Attempt to get an item from the queue without waiting - event = self._async_action_events.get_nowait() - - # assert isinstance(event, dict), "Python action events must be a dictionary!" - # {'type': 'CheckLocalAsync', 'uid': '486a5628-843b-4ed3-b8b3-315ef01c9a13', 'event_created_at': '2024-12-09T14:55:50.847199+00:00', 'source_uid': 'NeMoGuardrails'} + event = self.local_actions[ + main_flow_uid + ].output_event_queue.get_nowait() events.append(event) except asyncio.QueueEmpty: @@ -403,10 +524,13 @@ async def _get_async_actions_finished_events( The array of *ActionFinished events and the pending counter """ - pending_actions = self.async_actions.get(main_flow_uid, []) - if len(pending_actions) == 0: + local_action_group = self.local_actions[main_flow_uid] + if len(local_action_group.action_data) == 0: return [], 0 + pending_actions = [ + data.task for data in local_action_group.action_data.values() + ] done, pending = await asyncio.wait( pending_actions, return_when=asyncio.FIRST_COMPLETED, @@ -417,21 +541,36 @@ async def _get_async_actions_finished_events( action_finished_events = [] for finished_task in done: + action = finished_task.action # type: ignore try: result = finished_task.result() - except Exception: - log.warning( - "Local action finished with an exception!", - exc_info=True, + # We need to create the corresponding action finished event + action_finished_event = self._get_action_finished_event( + self.config, action, **result ) - - self.async_actions[main_flow_uid].remove(finished_task) - - # We need to create the corresponding action finished event - action_finished_event = self._get_action_finished_event( - self.config, **result - ) - action_finished_events.append(action_finished_event) + action_finished_events.append(action_finished_event) + except asyncio.CancelledError: + action_finished_event = self._get_action_finished_event( + self.config, + action, + status="failed", + is_success=False, + was_stopped=True, + failure_reason="stopped", + ) + action_finished_events.append(action_finished_event) + except Exception as e: + msg = "Local action finished with an exception!" + log.warning("%s %s", msg, e) + action_finished_event = self._get_action_finished_event( + self.config, + action, + status="failed", + is_success=False, + failure_reason=msg, + ) + action_finished_events.append(action_finished_event) + del self.local_actions[main_flow_uid].action_data[action.uid] return action_finished_events, len(pending) @@ -450,8 +589,8 @@ async def process_events( The events will be processed one by one, in the input order. If new events are generated as part of the processing, they will be appended to the input events. - By default, a processing cycle only waits for the local actions to finish, i.e, - if after processing all the input events, there are local actions in progress, the + By default, a processing cycle only waits for the non-async local actions to finish, i.e, + if after processing all the input events, there are non-async local actions in progress, the event processing will wait for them to finish. In blocking mode, the event processing will also wait for the local async actions. @@ -474,8 +613,15 @@ async def process_events( input_events: List[Union[dict, InternalEvent]] = [] local_running_actions: List[asyncio.Task[dict]] = [] - input_events.extend(events) + def extend_input_events(events: Sequence[Union[dict, InternalEvent]]): + """Make sure to add all new input events to all local async action event queues.""" + input_events.extend(events) + for data in self.local_actions[main_flow_uid].action_data.values(): + for event in events: + if isinstance(event, dict) and event["type"] != "CheckLocalAsync": + data.input_event_queues.put_nowait(event) + # Initialize empty state if state is None or state == {}: state = State( flow_states={}, flow_configs=self.flow_configs, rails_config=self.config @@ -497,6 +643,7 @@ async def process_events( input_event = InternalEvent(name="StartFlow", arguments={"flow_id": "main"}) input_events.insert(0, input_event) main_flow_state = state.flow_id_states["main"][-1] + self.local_actions[main_flow_state.uid] = LocalActionGroup() # Start all module level flows before main flow idx = 0 @@ -519,19 +666,26 @@ async def process_events( idx += 1 # Check if we have new async action events to add - input_events.extend(await self._get_async_action_events()) + new_events = await self._get_async_action_events(state.main_flow_state.uid) + extend_input_events(new_events) + output_events.extend(new_events) # Check if we have new finished async local action events to add ( local_action_finished_events, pending_local_action_counter, ) = await self._get_async_actions_finished_events(main_flow_uid) - input_events.extend(local_action_finished_events) + extend_input_events(local_action_finished_events) + output_events.extend(local_action_finished_events) + local_action_finished_events = [] return_local_async_action_count = False - # While we have input events to process, or there are local running actions - # we continue the processing. + # Add all input events + extend_input_events(events) + + # While we have input events to process, or there are local + # (non-async) running actions we continue the processing. events_counter = 0 while input_events or local_running_actions: while input_events: @@ -546,7 +700,11 @@ async def process_events( log.info("Processing event :: %s", event) for watcher in self.watchers: - watcher(event) + if ( + not isinstance(event, dict) + or event["type"] != "CheckLocalAsync" + ): + watcher(event) event_name = event["type"] if isinstance(event, dict) else event.name @@ -557,39 +715,19 @@ async def process_events( # Record the event that we're about to process state.last_events.append(event) - # Advance the state machine - new_event: Optional[Union[dict, Event]] = event - while new_event: - try: - run_to_completion(state, new_event) - new_event = None - except Exception as e: - log.warning("Colang runtime error!", exc_info=True) - new_event = Event( - name="ColangError", - arguments={ - "type": str(type(e).__name__), - "error": str(e), - }, - ) - await asyncio.sleep(0.001) - - for out_event in state.outgoing_events: - # We also record the out events in the recent history. - state.last_events.append(out_event) - - # We need to check if we need to run a locally registered action - start_action_match = re.match(r"Start(.*Action)", out_event["type"]) - if start_action_match: - action_event = ActionEvent.from_umim_event(out_event) + # Check if we need run a locally registered action + if isinstance(event, dict): + if re.match(r"Start(.*Action)", event["type"]): + action_event = ActionEvent.from_umim_event(event) action = Action.from_event(action_event) assert action # If it's an instant action, we finish it right away. + # TODO (schuellc): What is this needed for? if instant_actions and action.name in instant_actions: extra = {"action": action} if action.name == "UtteranceBotAction": - extra["final_script"] = out_event["script"] + extra["final_script"] = event["script"] action_finished_event = self._get_action_finished_event( self.config, **extra @@ -598,8 +736,8 @@ async def process_events( # We send the completion of the action as an output event # and continue processing it. # TODO: Why do we need an output event for that? It should only be an new input event + extend_input_events([action_finished_event]) output_events.append(action_finished_event) - input_events.append(action_finished_event) elif self.action_dispatcher.has_registered(action.name): # In this case we need to start the action locally @@ -615,6 +753,8 @@ async def process_events( state=state, ) ) + # Attach related action to the task + local_action.action = action # type: ignore # Generate *ActionStarted event action_started_event = action.started_event({}) @@ -623,8 +763,8 @@ async def process_events( self.config.event_source_uid ) ) - # output_events.append(action_started_umim_event) - input_events.append(action_started_umim_event) + extend_input_events([action_started_umim_event]) + output_events.append(action_started_umim_event) # If the function is not async, or async execution is disabled # we execute the actions as a local action. @@ -638,30 +778,68 @@ async def process_events( local_running_actions.append(local_action) else: main_flow_uid = state.main_flow_state.uid - if main_flow_uid not in self.async_actions: - self.async_actions[main_flow_uid] = [] - self.async_actions[main_flow_uid].append(local_action) - else: - output_events.append(out_event) - else: - output_events.append(out_event) + if main_flow_uid not in self.local_actions: + # TODO: This check should not be needed + self.local_actions[ + main_flow_uid + ] = LocalActionGroup() + self.local_actions[main_flow_uid].action_data.update( + {action.uid: LocalActionData(local_action)} + ) + elif re.match(r"Stop(.*Action)", event["type"]): + # Check if we need stop a locally running action + action_event = ActionEvent.from_umim_event(event) + action_uid = action_event.arguments.get("action_uid", None) + if action_uid: + data = self.local_actions[main_flow_uid].action_data.get( + action_uid + ) + if ( + data + and data.task.action.name # type: ignore + == action_event.name[4:] + ): + data.task.cancel() + + # Advance the state machine + new_event: Optional[Union[dict, Event]] = event + while new_event: + try: + run_to_completion(state, new_event) + new_event = None + except Exception as e: + log.warning("Colang runtime error!", exc_info=True) + new_event = Event( + name="ColangError", + arguments={ + "type": str(type(e).__name__), + "error": str(e), + }, + ) + # Give local async action the chance to process events + await asyncio.sleep(0.001) # Add new async action events as new input events - input_events.extend(await self._get_async_action_events()) + new_events = await self._get_async_action_events( + state.main_flow_state.uid + ) + extend_input_events(new_events) + output_events.extend(new_events) # Add new finished async local action events as new input events ( new_action_finished_events, pending_local_action_counter, ) = await self._get_async_actions_finished_events(main_flow_uid) - input_events.extend(new_action_finished_events) + extend_input_events(new_action_finished_events) + output_events.extend(new_action_finished_events) # Add generated events as new input events - input_events.extend(state.outgoing_events) + extend_input_events(state.outgoing_events) + output_events.extend(state.outgoing_events) - # If we have any local running actions, we need to wait for at least one + # If we have any non-async local running actions, we need to wait for at least one # of them to finish. - # TODO: Check why we should wait for one to finish?! if local_running_actions: log.info( "Waiting for %d local actions to finish.", @@ -678,7 +856,7 @@ async def process_events( # We need to create the corresponding action finished event action_finished_event = self._get_action_finished_event( - self.config, **result + self.config, finished_task.action, **result # type: ignore ) input_events.append(action_finished_event) @@ -694,12 +872,25 @@ async def process_events( ) ) - # TODO: serialize the state to dict - # We cap the recent history to the last 500 state.last_events = state.last_events[-500:] - return output_events, state + if state.main_flow_state.status == FlowStatus.FINISHED: + log.info("End of story!") + del self.local_actions[main_flow_uid] + + # We currently filter out all events related local actions + # TODO: Consider if we should expose them all as umim events + final_output_events = [] + for event in output_events: + if isinstance(event, dict) and "action_uid" in event: + action_event = ActionEvent.from_umim_event(event) + action = Action.from_event(action_event) + if action and self.action_dispatcher.has_registered(action.name): + continue + final_output_events.append(event) + + return final_output_events, state async def _run_action( self, @@ -722,7 +913,6 @@ async def _run_action( state.context.update(context_updates) return { - "action": action, "return_value": return_value, "new_events": new_events, "context_updates": context_updates, diff --git a/nemoguardrails/logging/verbose.py b/nemoguardrails/logging/verbose.py index a2f972238..747075cb2 100644 --- a/nemoguardrails/logging/verbose.py +++ b/nemoguardrails/logging/verbose.py @@ -163,6 +163,10 @@ def emit(self, record) -> None: msg += f"[dim]{title}[/]" console.print(msg, highlight=False, no_wrap=False) + elif record.levelno >= logging.WARNING: + current_time = datetime.now().strftime("%H:%M:%S.%f")[:-3] + msg = f"[dim]{current_time}[/] | [yellow bold]Warning[/] | " + msg + console.print(msg, highlight=False, no_wrap=False) def set_verbose( diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index b24fcb99c..e108d6edc 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -34,7 +34,7 @@ from nemoguardrails.colang import parse_colang_file from nemoguardrails.colang.v1_0.runtime.flows import compute_context from nemoguardrails.colang.v1_0.runtime.runtime import Runtime, RuntimeV1_0 -from nemoguardrails.colang.v2_x.runtime.flows import Action, State +from nemoguardrails.colang.v2_x.runtime.flows import Action, ActionEvent, State from nemoguardrails.colang.v2_x.runtime.runtime import RuntimeV2_x from nemoguardrails.colang.v2_x.runtime.serialization import ( json_to_state, @@ -732,24 +732,27 @@ async def generate_async( start_action_match = re.match(r"Start(.*Action)", event["type"]) if start_action_match: - action_name = start_action_match[1] - # TODO: is there an elegant way to extract just the arguments? - arguments = { - k: v - for k, v in event.items() - if k != "type" - and k != "uid" - and k != "event_created_at" - and k != "source_uid" - and k != "action_uid" - } - response_tool_calls.append( - { - "id": event["action_uid"], - "type": "function", - "function": {"name": action_name, "arguments": arguments}, - } - ) + action_event = ActionEvent.from_umim_event(event) + action = Action.from_event(action_event) + if action: + # TODO (schuellc): Check why we need this? + # Also it seems we need to exclude the following actions + if action.name in [ + "UtteranceBotAction", + "GestureBotAction", + "PostureBotAction", + ]: + continue + response_tool_calls.append( + { + "id": action.uid, + "type": "function", + "function": { + "name": action.name, + "arguments": action.start_event_arguments, + }, + } + ) elif event["type"] == "UtteranceBotActionFinished": responses.append(event["final_script"]) diff --git a/tests/test_configs/with_custom_async_action_events_v2_x/actions.py b/tests/test_configs/with_custom_async_action_events_v2_x/actions.py new file mode 100644 index 000000000..451b661db --- /dev/null +++ b/tests/test_configs/with_custom_async_action_events_v2_x/actions.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio + +from nemoguardrails.actions import action +from nemoguardrails.colang.v2_x.runtime.runtime import ActionEventHandler + + +@action(name="CustomAsyncTest1Action", is_system_action=True, execute_async=True) +async def custom_async_test1(event_handler: ActionEventHandler): + for i in range(1, 3): + await asyncio.sleep(1) + event_handler.send_action_updated_event("Value", {"number": i}) + await asyncio.sleep(1) + event_handler.send_event("CustomEventA", {"value": "A"}) + events = await event_handler.wait_for_events("CustomEventB") + print("-----------> CustomEventB") + event_handler.send_event("CustomEventResponse", {"value": events[0]["value"]}) + await event_handler.wait_for_events("CustomEventC") + await asyncio.sleep(3) + # raise Exception("Python action exception!") + + +@action(name="CustomAsyncTest2Action", is_system_action=True, execute_async=True) +async def custom_async_test2(event_handler: ActionEventHandler): + await event_handler.wait_for_events("CustomEventResponse") + print("-----------> CustomEventResponse") + await asyncio.sleep(3) + event_handler.send_event("CustomEventC", {"value": "C"}) diff --git a/tests/test_configs/with_custom_async_action_events_v2_x/config.co b/tests/test_configs/with_custom_async_action_events_v2_x/config.co new file mode 100644 index 000000000..21b662125 --- /dev/null +++ b/tests/test_configs/with_custom_async_action_events_v2_x/config.co @@ -0,0 +1,16 @@ +flow main + match UtteranceUserAction.Finished(final_transcript="start") + start CustomAsyncTest1Action() as $action1_ref + start CustomAsyncTest2Action() + match $action1_ref.Started() + while True: + when $action1_ref.ValueUpdated() as $ref: + start UtteranceBotAction(script="Value: {$ref.number}") + or when CustomEventA() as $ref: + start UtteranceBotAction(script="Value: {$ref.value}") + send CustomEventB(value="B") + match CustomEventC(value="C") + start UtteranceBotAction(script="Check") + or when $action1_ref.Finished() as $event: + start UtteranceBotAction(script="End") + break diff --git a/tests/test_configs/with_custom_async_action_v2_x/config.yml b/tests/test_configs/with_custom_async_action_events_v2_x/config.yml similarity index 100% rename from tests/test_configs/with_custom_async_action_v2_x/config.yml rename to tests/test_configs/with_custom_async_action_events_v2_x/config.yml diff --git a/tests/test_configs/with_custom_async_action_v2_x/actions.py b/tests/test_configs/with_custom_async_action_lifetime_v2_x/actions.py similarity index 59% rename from tests/test_configs/with_custom_async_action_v2_x/actions.py rename to tests/test_configs/with_custom_async_action_lifetime_v2_x/actions.py index 52ceffb07..eda7e1d90 100644 --- a/tests/test_configs/with_custom_async_action_v2_x/actions.py +++ b/tests/test_configs/with_custom_async_action_lifetime_v2_x/actions.py @@ -16,14 +16,19 @@ import asyncio from nemoguardrails.actions import action -from nemoguardrails.actions.action_dispatcher import ActionEventGenerator +from nemoguardrails.colang.v2_x.runtime.runtime import ActionEventHandler -@action(name="CustomAsyncTestAction", is_system_action=True, execute_async=True) -async def custom_async_test(event_generator: ActionEventGenerator): - for i in range(1, 5): - await asyncio.sleep(1) - event_generator.send_action_updated_event("Value", {"number": i}) - await asyncio.sleep(1) - event_generator.send_raw_event("NewCustomUmimEvent", {"secret": "xyz"}) +@action(name="Test1Action", is_system_action=True, execute_async=True) +async def test1(): await asyncio.sleep(1) + + +@action(name="Test2Action", is_system_action=True, execute_async=True) +async def test2(event_handler: ActionEventHandler): + await event_handler.wait_for_events("NeverHappeningEver") + + +@action(name="Test3Action", is_system_action=True, execute_async=True) +async def test3(event_handler: ActionEventHandler): + raise Exception("Issue occurred!") diff --git a/tests/test_configs/with_custom_async_action_lifetime_v2_x/config.co b/tests/test_configs/with_custom_async_action_lifetime_v2_x/config.co new file mode 100644 index 000000000..f4559baa7 --- /dev/null +++ b/tests/test_configs/with_custom_async_action_lifetime_v2_x/config.co @@ -0,0 +1,21 @@ +flow main + match UtteranceUserAction.Finished(final_transcript="start") + + # Normal action lifetime + start Test1Action() as $action1_ref + match $action1_ref.Started() + match $action1_ref.Finished() as $ref + start UtteranceBotAction(script="Result: {$ref.is_success}") + + # Action gets canceled + start Test2Action() as $action2_ref + match $action2_ref.Started() + send StopTest2Action(action_uid=$action2_ref.uid) + match $action2_ref.Finished() as $ref + start UtteranceBotAction(script="Result: {$ref.was_stopped}") + + # Action raises Exception + start Test3Action() as $action3_ref + match $action3_ref.Started() + match $action3_ref.Finished() as $ref + start UtteranceBotAction(script="Result: {$ref.failure_reason}") diff --git a/tests/test_configs/with_custom_async_action_lifetime_v2_x/config.yml b/tests/test_configs/with_custom_async_action_lifetime_v2_x/config.yml new file mode 100644 index 000000000..ae6347623 --- /dev/null +++ b/tests/test_configs/with_custom_async_action_lifetime_v2_x/config.yml @@ -0,0 +1 @@ +colang_version: "2.x" diff --git a/tests/test_configs/with_custom_async_action_v2_x/config.co b/tests/test_configs/with_custom_async_action_v2_x/config.co deleted file mode 100644 index 354c5c563..000000000 --- a/tests/test_configs/with_custom_async_action_v2_x/config.co +++ /dev/null @@ -1,11 +0,0 @@ -flow main - match UtteranceUserAction.Finished(final_transcript="start") - start CustomAsyncTestAction() as $action_ref - while True: - when $action_ref.ValueUpdated() as $ref: - start UtteranceBotAction(script="Value: {$ref.number}") - or when NewCustomUmimEvent() as $ref: - start UtteranceBotAction(script="Secret: {$ref.secret}") - or when $action_ref.Finished() as $event: - start UtteranceBotAction(script="End") - break diff --git a/tests/test_embeddings_only_user_messages.py b/tests/test_embeddings_only_user_messages.py index 9dd428755..6dfc25ef3 100644 --- a/tests/test_embeddings_only_user_messages.py +++ b/tests/test_embeddings_only_user_messages.py @@ -191,7 +191,7 @@ def test_examples_included_in_prompts_2(colang_2_config): colang_2_config, llm_completions=[ " user express greeting", - ' bot respond to uknown intent "Hello is there anything else" ', + ' bot respond to unknown intent\nbot action: bot say "Hello is there anything else"', ], ) @@ -214,7 +214,7 @@ def test_no_llm_calls_embedding_only(colang_2_config): colang_2_config, llm_completions=[ " user express greeting", - ' bot respond to uknown intent "Hello is there anything else" ', + ' bot respond to unknown intent "Hello is there anything else" ', ], ) diff --git a/tests/test_retrieve_relevant_chunks.py b/tests/test_retrieve_relevant_chunks.py index 7d1044661..f4e110e2e 100644 --- a/tests/test_retrieve_relevant_chunks.py +++ b/tests/test_retrieve_relevant_chunks.py @@ -29,12 +29,12 @@ activate llm continuation flow user express greeting - user said "hello" - or user said "hi" - or user said "how are you" + user said "hello" + or user said "hi" + or user said "how are you" flow bot express greeting - bot say "Hey!" + bot say "Hey!" flow greeting user express greeting @@ -58,7 +58,7 @@ def test_relevant_chunk_inserted_in_prompt(): config, llm_completions=[ " user express greeting", - ' bot respond to aditional context\nbot action: "Hello is there anything else" ', + ' bot respond to additional context\nbot action: bot say "Hello is there anything else" ', ], ) @@ -85,7 +85,7 @@ def test_relevant_chunk_inserted_in_prompt_no_kb(): config, llm_completions=[ " user express greeting", - ' bot respond to aditional context\nbot action: "Hello is there anything else" ', + ' bot respond to aditional context\nbot action: bot say "Hello is there anything else" ', ], ) rails = chat.app diff --git a/tests/v2_x/chat.py b/tests/v2_x/chat.py index f46f91221..584d2f5cb 100644 --- a/tests/v2_x/chat.py +++ b/tests/v2_x/chat.py @@ -324,7 +324,8 @@ async def run(self): self.chat_state.waiting_user_input = True # NOTE: We should never disable the user input since we can have # async Python actions running in parallel - # await enable_input.wait() + # TODO: Check if disabling causes race conditions + await self.enable_input.wait() user_message = "" if not self.input_queue.empty(): diff --git a/tests/v2_x/test_run_actions.py b/tests/v2_x/test_run_actions.py index 758f85f2a..47ac73ac2 100644 --- a/tests/v2_x/test_run_actions.py +++ b/tests/v2_x/test_run_actions.py @@ -151,14 +151,17 @@ def test_custom_action(): def test_custom_action_generating_async_events(): - path = os.path.join(CONFIGS_FOLDER, "with_custom_async_action_v2_x") + path = os.path.join(CONFIGS_FOLDER, "with_custom_async_action_events_v2_x") test_script = """ > start Value: 1 Value: 2 - Value: 3 - Value: 4 - Secret: xyz + Event: CustomEventA + Value: A + Event: CustomEventB + Event: CustomEventResponse + Event: CustomEventC + Check End Event: StopUtteranceBotAction """ From 436bc04052210407ea8a564d3de4e5bdea59a376 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20Sch=C3=BCller?= Date: Mon, 16 Dec 2024 16:33:05 +0100 Subject: [PATCH 12/14] Extend unit tests --- nemoguardrails/actions/action_dispatcher.py | 2 + nemoguardrails/colang/v2_x/runtime/flows.py | 2 +- nemoguardrails/colang/v2_x/runtime/runtime.py | 46 ++++++++++++++++++- .../actions.py | 2 - .../actions.py | 13 +++++- .../config.co | 5 ++ tests/v2_x/test_run_actions.py | 27 ++++++++++- 7 files changed, 89 insertions(+), 8 deletions(-) diff --git a/nemoguardrails/actions/action_dispatcher.py b/nemoguardrails/actions/action_dispatcher.py index 5e4a09386..3fe066b4b 100644 --- a/nemoguardrails/actions/action_dispatcher.py +++ b/nemoguardrails/actions/action_dispatcher.py @@ -33,6 +33,8 @@ class ActionDispatcher: + """Manages the execution and life time of local actions.""" + def __init__( self, load_all_actions: bool = True, diff --git a/nemoguardrails/colang/v2_x/runtime/flows.py b/nemoguardrails/colang/v2_x/runtime/flows.py index 88847eecc..5e970e292 100644 --- a/nemoguardrails/colang/v2_x/runtime/flows.py +++ b/nemoguardrails/colang/v2_x/runtime/flows.py @@ -346,7 +346,7 @@ def start_event(self, _args: dict) -> ActionEvent: def change_event(self, args: dict) -> ActionEvent: """Changes a parameter of a started action.""" return ActionEvent( - name=f"Change{self.name}", arguments=args["arguments"], action_uid=self.uid + name=f"Change{self.name}", arguments=args, action_uid=self.uid ) def stop_event(self, _args: dict) -> ActionEvent: diff --git a/nemoguardrails/colang/v2_x/runtime/runtime.py b/nemoguardrails/colang/v2_x/runtime/runtime.py index c43522432..a6fef27aa 100644 --- a/nemoguardrails/colang/v2_x/runtime/runtime.py +++ b/nemoguardrails/colang/v2_x/runtime/runtime.py @@ -16,6 +16,7 @@ import inspect import logging import re +import time from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast from urllib.parse import urljoin @@ -87,7 +88,7 @@ def send_action_updated_event( Send an Action*Updated event. Args: - event_name (str): The name of the action event + event_name (str): The name of the action event, e.g. `Attention` for AttentionUserActionUpdated args (Optional[dict]): An optional dictionary with the event arguments """ @@ -100,6 +101,17 @@ def send_action_updated_event( action_event.to_umim_event(self._config.event_source_uid) ) + async def wait_for_change_action_event( + self, timeout: Optional[float] = None + ) -> Optional[dict]: + """ + Wait for new Change*Action event. + + Args: + timeout (Optional[float]): The time to wait for the new event before it continues + """ + return await self.wait_for_event(self._action.change_event({}).name, timeout) + def send_event(self, event_name: str, args: Optional[dict] = None) -> None: """ Send any event. @@ -113,11 +125,37 @@ def send_event(self, event_name: str, args: Optional[dict] = None) -> None: event.to_umim_event(self._config.event_source_uid) ) + async def wait_for_event( + self, event_name: Optional[str] = None, timeout: Optional[float] = None + ) -> Optional[dict]: + """ + Wait for next new input event to process. + + Args: + event_name (Optional[str]): Optional event name to filter for, if None all events will be received + timeout (Optional[float]): The time to wait for new events before it continues + """ + keep_waiting = True + start_time = time.time() + while keep_waiting: + try: + # Check cumulative waiting time + if timeout and time.time() - start_time > timeout: + raise asyncio.TimeoutError() + # Wait for next event + event = await asyncio.wait_for(self._event_input_queue.get(), timeout) + if event_name is None or event["type"] == event_name: + return event + except asyncio.TimeoutError: + # Timeout occurred, stop consuming + keep_waiting = False + return None + async def wait_for_events( self, event_name: Optional[str] = None, timeout: Optional[float] = None ) -> List[dict]: """ - Waits for new input events to process. + Wait for new input events to process. Args: event_name (Optional[str]): Optional event name to filter for, if None all events will be received @@ -125,8 +163,12 @@ async def wait_for_events( """ events: List[dict] = [] keep_waiting = True + start_time = time.time() while keep_waiting: try: + # Check cumulative waiting time + if timeout and time.time() - start_time > timeout: + raise asyncio.TimeoutError() # Wait for new events event = await asyncio.wait_for(self._event_input_queue.get(), timeout) # Gather all new events diff --git a/tests/test_configs/with_custom_async_action_events_v2_x/actions.py b/tests/test_configs/with_custom_async_action_events_v2_x/actions.py index 451b661db..be83f7f43 100644 --- a/tests/test_configs/with_custom_async_action_events_v2_x/actions.py +++ b/tests/test_configs/with_custom_async_action_events_v2_x/actions.py @@ -27,7 +27,6 @@ async def custom_async_test1(event_handler: ActionEventHandler): await asyncio.sleep(1) event_handler.send_event("CustomEventA", {"value": "A"}) events = await event_handler.wait_for_events("CustomEventB") - print("-----------> CustomEventB") event_handler.send_event("CustomEventResponse", {"value": events[0]["value"]}) await event_handler.wait_for_events("CustomEventC") await asyncio.sleep(3) @@ -37,6 +36,5 @@ async def custom_async_test1(event_handler: ActionEventHandler): @action(name="CustomAsyncTest2Action", is_system_action=True, execute_async=True) async def custom_async_test2(event_handler: ActionEventHandler): await event_handler.wait_for_events("CustomEventResponse") - print("-----------> CustomEventResponse") await asyncio.sleep(3) event_handler.send_event("CustomEventC", {"value": "C"}) diff --git a/tests/test_configs/with_custom_async_action_lifetime_v2_x/actions.py b/tests/test_configs/with_custom_async_action_lifetime_v2_x/actions.py index eda7e1d90..868592317 100644 --- a/tests/test_configs/with_custom_async_action_lifetime_v2_x/actions.py +++ b/tests/test_configs/with_custom_async_action_lifetime_v2_x/actions.py @@ -20,8 +20,19 @@ @action(name="Test1Action", is_system_action=True, execute_async=True) -async def test1(): +async def test1(event_handler: ActionEventHandler): + event = None + value = None + while event is None: + event = await event_handler.wait_for_change_action_event() + if event: + value = event.get("volume", None) + if value: + break + else: + event = None await asyncio.sleep(1) + event_handler.send_action_updated_event("Volume", {"value": value}) @action(name="Test2Action", is_system_action=True, execute_async=True) diff --git a/tests/test_configs/with_custom_async_action_lifetime_v2_x/config.co b/tests/test_configs/with_custom_async_action_lifetime_v2_x/config.co index f4559baa7..5989e71ab 100644 --- a/tests/test_configs/with_custom_async_action_lifetime_v2_x/config.co +++ b/tests/test_configs/with_custom_async_action_lifetime_v2_x/config.co @@ -4,6 +4,11 @@ flow main # Normal action lifetime start Test1Action() as $action1_ref match $action1_ref.Started() + start UtteranceBotAction(script="Started") + send $action1_ref.Change(direction="top") + send $action1_ref.Change(volume=10) + match Test1ActionVolumeUpdated() as $ref + start UtteranceBotAction(script="Volume: {$ref.value}") match $action1_ref.Finished() as $ref start UtteranceBotAction(script="Result: {$ref.is_success}") diff --git a/tests/v2_x/test_run_actions.py b/tests/v2_x/test_run_actions.py index 47ac73ac2..49c5d2ecf 100644 --- a/tests/v2_x/test_run_actions.py +++ b/tests/v2_x/test_run_actions.py @@ -150,7 +150,7 @@ def test_custom_action(): chat << "8" -def test_custom_action_generating_async_events(): +def test_custom_action_async_events(): path = os.path.join(CONFIGS_FOLDER, "with_custom_async_action_events_v2_x") test_script = """ > start @@ -174,5 +174,28 @@ def test_custom_action_generating_async_events(): assert result is None, result +def test_custom_async_action_lifetime(): + path = os.path.join(CONFIGS_FOLDER, "with_custom_async_action_lifetime_v2_x") + test_script = """ + > start + Started + Volume: 10 + Result: True + Result: True + Result: Local action finished with an exception! + Event: StopUtteranceBotAction + Event: StopUtteranceBotAction + Event: StopUtteranceBotAction + Event: StopUtteranceBotAction + """ + + loop = get_or_create_event_loop() + result = loop.run_until_complete( + compare_interaction_with_test_script(test_script, 6.0, colang_path=path) + ) + + assert result is None, result + + if __name__ == "__main__": - test_custom_action_generating_async_events() + test_custom_async_action_lifetime() From c452cc028f90a26fe4d520016d3df92bd1e3b3c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20Sch=C3=BCller?= Date: Mon, 16 Dec 2024 16:56:19 +0100 Subject: [PATCH 13/14] Make sure local actions are released at end of main flow --- nemoguardrails/colang/v2_x/runtime/runtime.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/nemoguardrails/colang/v2_x/runtime/runtime.py b/nemoguardrails/colang/v2_x/runtime/runtime.py index a6fef27aa..3a3f07a88 100644 --- a/nemoguardrails/colang/v2_x/runtime/runtime.py +++ b/nemoguardrails/colang/v2_x/runtime/runtime.py @@ -917,9 +917,12 @@ def extend_input_events(events: Sequence[Union[dict, InternalEvent]]): # We cap the recent history to the last 500 state.last_events = state.last_events[-500:] - if state.main_flow_state.status == FlowStatus.FINISHED: + if state.main_flow_state.status == FlowStatus.WAITING: + # Main flow is done, release related local action data log.info("End of story!") - del self.local_actions[main_flow_uid] + for item in self.local_actions[main_flow_uid].action_data.values(): + item.task.cancel() + self.local_actions[main_flow_uid].action_data.clear() # We currently filter out all events related local actions # TODO: Consider if we should expose them all as umim events From 410f9980fb1183b05345340cb2b7c81a78ea2fd0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20Sch=C3=BCller?= Date: Wed, 18 Dec 2024 10:48:25 +0100 Subject: [PATCH 14/14] Fix issue with action events with provided action_uid --- nemoguardrails/colang/v2_x/runtime/flows.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemoguardrails/colang/v2_x/runtime/flows.py b/nemoguardrails/colang/v2_x/runtime/flows.py index 5e970e292..b79c0a53d 100644 --- a/nemoguardrails/colang/v2_x/runtime/flows.py +++ b/nemoguardrails/colang/v2_x/runtime/flows.py @@ -177,7 +177,7 @@ def to_umim_event(self, event_source_uid: Optional[str] = None) -> Dict[str, Any "source_uid", event_source_uid if event_source_uid else "NeMoGuardrails-Colang-2.x", ) - if self.action_uid: + if self.action_uid and "action_uid" not in new_event_args: return new_event_dict( self.name, action_uid=self.action_uid, **new_event_args )