Skip to content

Feature/python action events #907

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 15 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions nemoguardrails/actions/action_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@


class ActionDispatcher:
"""Manages the execution and life time of local actions."""

def __init__(
self,
load_all_actions: bool = True,
Expand All @@ -51,7 +53,8 @@ def __init__(
"""
log.info("Initializing action dispatcher")

self._registered_actions = {}
# Dictionary with all registered actions
self._registered_actions: dict = {}

if load_all_actions:
# TODO: check for better way to find actions dir path or use constants.py
Expand Down Expand Up @@ -87,7 +90,7 @@ def __init__(
for import_path in import_paths:
self.load_actions_from_path(Path(import_path.strip()))

log.info(f"Registered Actions :: {sorted(self._registered_actions.keys())}")
log.info("Registered Actions :: %s", sorted(self._registered_actions.keys()))
log.info("Action dispatcher initialized")

@property
Expand Down Expand Up @@ -181,7 +184,7 @@ def get_action(self, name: str) -> callable:

async def execute_action(
self, action_name: str, params: Dict[str, Any]
) -> Tuple[Union[str, Dict[str, Any]], str]:
) -> Tuple[Optional[Union[str, Dict[str, Any]]], str]:
"""Execute a registered action.

Args:
Expand All @@ -195,7 +198,7 @@ async def execute_action(
action_name = self._normalize_action_name(action_name)

if action_name in self._registered_actions:
log.info(f"Executing registered action: {action_name}")
log.info("Executing registered action: %s", action_name)
fn = self._registered_actions.get(action_name, None)

# Actions that are registered as classes are initialized lazy, when
Expand All @@ -214,7 +217,7 @@ async def execute_action(
result = await result
else:
log.warning(
f"Synchronous action `{action_name}` has been called."
"Synchronous action `%s` has been called.", action_name
)

elif isinstance(fn, Chain):
Expand Down Expand Up @@ -256,15 +259,12 @@ async def execute_action(
filtered_params = {
k: v
for k, v in params.items()
if k not in ["state", "events", "llm"]
if k not in ["state", "events", "llm", "event_handler"]
}
log.warning(
"Error while execution '%s' with parameters '%s': %s",
action_name,
filtered_params,
e,
msg = (
f"Exception while execution '{action_name}' with parameters '{filtered_params}'",
)
log.exception(e)
raise Exception(f"{msg}: {e}") from e

return None, "failed"

Expand Down
4 changes: 3 additions & 1 deletion nemoguardrails/cli/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,9 @@ async def _process_input_events():
chat_state.input_events = []
else:
chat_state.waiting_user_input = True
await enable_input.wait()
# NOTE: We should never disable the user input since we can have
# async Python actions running in parallel
# await enable_input.wait()

user_message: str = await chat_state.session.prompt_async(
HTML("<prompt>\n> </prompt>"),
Expand Down
3 changes: 2 additions & 1 deletion nemoguardrails/colang/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import logging
from abc import abstractmethod
from typing import Any, Callable, List, Optional, Tuple
Expand Down Expand Up @@ -47,7 +48,7 @@ def __init__(self, config: RailsConfig, verbose: bool = False):

# A set of watchers that are notified every time an event is processed.
# Used mainly for reporting the progress to the CLI.
self.watchers = []
self.watchers: List = []

# The maximum number of events to be processed in a processing loop
self.max_events = 500
Expand Down
14 changes: 11 additions & 3 deletions nemoguardrails/colang/v1_0/runtime/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from langchain.chains.base import Chain

from nemoguardrails.actions.actions import ActionResult
from nemoguardrails.actions.llm.utils import LLMCallException
from nemoguardrails.colang import parse_colang_file
from nemoguardrails.colang.runtime import Runtime
from nemoguardrails.colang.v1_0.runtime.flows import (
Expand Down Expand Up @@ -360,9 +361,16 @@ async def _process_start_action(self, events: List[dict]) -> List[dict]:
kwargs["llm"] = self.registered_action_params[f"{action_name}_llm"]

log.info("Executing action :: %s", action_name)
result, status = await self.action_dispatcher.execute_action(
action_name, kwargs
)
try:
result, status = await self.action_dispatcher.execute_action(
action_name, kwargs
)
except LLMCallException as e:
raise e
except Exception as e:
result = None
status = "failed"
log.exception(e)

# If the action execution failed, we return a hardcoded message
if status == "failed":
Expand Down
68 changes: 63 additions & 5 deletions nemoguardrails/colang/v2_x/runtime/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,18 @@
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Callable, Deque, Dict, List, Optional, Tuple, Union
from typing import (
Any,
Callable,
ClassVar,
Deque,
Dict,
List,
Optional,
Sequence,
Tuple,
Union,
)

from dataclasses_json import dataclass_json

Expand All @@ -33,7 +44,7 @@
FlowReturnMemberDef,
)
from nemoguardrails.colang.v2_x.runtime.errors import ColangSyntaxError
from nemoguardrails.utils import new_readable_uuid, new_uuid
from nemoguardrails.utils import new_event_dict, new_readable_uuid, new_uuid

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -108,6 +119,15 @@ def from_umim_event(cls, event: dict) -> Event:
)
return new_event

def to_umim_event(self, event_source_uid: Optional[str] = None) -> Dict[str, Any]:
"""Return a umim event dictionary."""
new_event_args = dict(self.arguments)
new_event_args.setdefault(
"source_uid",
event_source_uid if event_source_uid else "NeMoGuardrails-Colang-2.x",
)
return new_event_dict(self.name, **new_event_args)

# Expose all event parameters as attributes of the event
def __getattr__(self, name):
if (
Expand Down Expand Up @@ -150,6 +170,20 @@ def from_umim_event(cls, event: dict) -> ActionEvent:
new_event.action_uid = event["action_uid"]
return new_event

def to_umim_event(self, event_source_uid: Optional[str] = None) -> Dict[str, Any]:
"""Return a umim event dictionary."""
new_event_args = dict(self.arguments)
new_event_args.setdefault(
"source_uid",
event_source_uid if event_source_uid else "NeMoGuardrails-Colang-2.x",
)
if self.action_uid and "action_uid" not in new_event_args:
return new_event_dict(
self.name, action_uid=self.action_uid, **new_event_args
)
else:
return new_event_dict(self.name, **new_event_args)


class ActionStatus(Enum):
"""The status of an action."""
Expand All @@ -176,19 +210,43 @@ class Action:
"Stop": "stop_event",
}

# List of umim specific parameters
_umim_parameters: ClassVar[List[str]] = [
"type",
"uid",
"event_created_at",
"source_uid",
"action_uid",
"action_info_modality",
"action_info_modality_policy",
"action_finished_at",
]

@classmethod
def from_event(cls, event: ActionEvent) -> Optional[Action]:
"""Returns the action if event name conforms with UMIM convention."""
assert event.action_uid is not None
for name in cls._event_name_map:
if name in event.name:
action = Action(event.name.replace(name, ""), {})
action_name: str
if name == "Updated":
index = event.name.find("Action") + 6
action_name = event.name[:index]
else:
action_name = event.name.replace(name, "")
action = Action(action_name, {})
action.uid = event.action_uid
action.status = (
ActionStatus.STARTED
if name != "Finished"
else ActionStatus.FINISHED
)
if name == "Start":
action.start_event_arguments = {
key: event.arguments[key]
for key in event.arguments
if key not in cls._umim_parameters
}
return action
return None

Expand Down Expand Up @@ -288,7 +346,7 @@ def start_event(self, _args: dict) -> ActionEvent:
def change_event(self, args: dict) -> ActionEvent:
"""Changes a parameter of a started action."""
return ActionEvent(
name=f"Change{self.name}", arguments=args["arguments"], action_uid=self.uid
name=f"Change{self.name}", arguments=args, action_uid=self.uid
)

def stop_event(self, _args: dict) -> ActionEvent:
Expand Down Expand Up @@ -355,7 +413,7 @@ class FlowConfig:
id: str

# The sequence of elements that compose the flow.
elements: List[ElementType]
elements: Sequence[ElementType]

# The flow parameters
parameters: List[FlowParamDef]
Expand Down
Loading
Loading