diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 4a5268ed5..48d882884 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -23,7 +23,14 @@ repos:
         args:
           - --license-filepath
           - LICENSE.md
-
+  - repo: local
+    hooks:
+      - id: pyright
+        name: pyright
+        entry: poetry run pyright
+        language: system
+        types: [python]
+        pass_filenames: false
   # Deactivating this for now.
   # - repo: https://github.com/pycqa/pylint
   #   rev: v2.17.0
diff --git a/nemoguardrails/actions/llm/generation.py b/nemoguardrails/actions/llm/generation.py
index 74fa763c5..377b0bc5e 100644
--- a/nemoguardrails/actions/llm/generation.py
+++ b/nemoguardrails/actions/llm/generation.py
@@ -82,7 +82,7 @@ class LLMGenerationActions:
     def __init__(
         self,
         config: RailsConfig,
-        llm: Union[BaseLLM, BaseChatModel],
+        llm: Optional[Union[BaseLLM, BaseChatModel]],
         llm_task_manager: LLMTaskManager,
         get_embedding_search_provider_instance: Callable[
             [Optional[EmbeddingSearchProvider]], EmbeddingsIndex
diff --git a/nemoguardrails/context.py b/nemoguardrails/context.py
index 0659faafb..2e7d34b82 100644
--- a/nemoguardrails/context.py
+++ b/nemoguardrails/context.py
@@ -14,25 +14,45 @@
 # limitations under the License.
 
 import contextvars
-from typing import Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
-streaming_handler_var = contextvars.ContextVar("streaming_handler", default=None)
+from nemoguardrails.logging.explain import LLMCallInfo
+
+if TYPE_CHECKING:
+    from nemoguardrails.logging.explain import ExplainInfo
+    from nemoguardrails.logging.stats import LLMStats
+    from nemoguardrails.rails.llm.options import GenerationOptions
+    from nemoguardrails.streaming import StreamingHandler
+
+streaming_handler_var: contextvars.ContextVar[
+    Optional["StreamingHandler"]
+] = contextvars.ContextVar("streaming_handler", default=None)
 
 # The object that holds additional explanation information.
-explain_info_var = contextvars.ContextVar("explain_info", default=None)
+explain_info_var: contextvars.ContextVar[
+    Optional["ExplainInfo"]
+] = contextvars.ContextVar("explain_info", default=None)
 
 # The current LLM call.
-llm_call_info_var = contextvars.ContextVar("llm_call_info", default=None)
+llm_call_info_var: contextvars.ContextVar[
+    Optional[LLMCallInfo]
+] = contextvars.ContextVar("llm_call_info", default=None)
 
 # All the generation options applicable to the current context.
-generation_options_var = contextvars.ContextVar("generation_options", default=None)
+generation_options_var: contextvars.ContextVar[
+    Optional["GenerationOptions"]
+] = contextvars.ContextVar("generation_options", default=None)
 
 # The stats about the LLM calls.
-llm_stats_var = contextvars.ContextVar("llm_stats", default=None)
+llm_stats_var: contextvars.ContextVar[Optional["LLMStats"]] = contextvars.ContextVar(
+    "llm_stats", default=None
+)
 
 # The raw LLM request that comes from the user.
 # This is used in passthrough mode.
-raw_llm_request = contextvars.ContextVar("raw_llm_request", default=None)
+raw_llm_request: contextvars.ContextVar[
+    Optional[Union[str, List[Dict[str, Any]]]]
+] = contextvars.ContextVar("raw_llm_request", default=None)
 
 reasoning_trace_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
     "reasoning_trace", default=None
diff --git a/nemoguardrails/rails/llm/buffer.py b/nemoguardrails/rails/llm/buffer.py
index 30e48c4e3..fdbd5ba08 100644
--- a/nemoguardrails/rails/llm/buffer.py
+++ b/nemoguardrails/rails/llm/buffer.py
@@ -138,7 +138,8 @@ async def process_stream(
             ...         print(f"Processing: {context_formatted}")
             ...         print(f"User: {user_formatted}")
         """
-        ...
+        raise NotImplementedError
+        yield
 
     async def __call__(self, streaming_handler) -> AsyncGenerator[ChunkBatch, None]:
         """Callable interface that delegates to process_stream.
diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py
index eac54fd37..6c5073a78 100644
--- a/nemoguardrails/rails/llm/config.py
+++ b/nemoguardrails/rails/llm/config.py
@@ -487,7 +487,7 @@ class OutputRails(BaseModel):
         description="The names of all the flows that implement output rails.",
     )
 
-    streaming: Optional[OutputRailsStreamingConfig] = Field(
+    streaming: OutputRailsStreamingConfig = Field(
         default_factory=OutputRailsStreamingConfig,
         description="Configuration for streaming output rails.",
     )
@@ -1128,7 +1128,9 @@ def _load_path(
     # the first .railsignore file found from cwd down to its subdirectories
     railsignore_path = utils.get_railsignore_path(config_path)
-    ignore_patterns = utils.get_railsignore_patterns(railsignore_path)
+    ignore_patterns = (
+        utils.get_railsignore_patterns(railsignore_path) if railsignore_path else set()
+    )
 
     if os.path.isdir(config_path):
         for root, _, files in os.walk(config_path, followlinks=True):
@@ -1245,8 +1247,8 @@ def _parse_colang_files_recursively(
         current_file, current_path = colang_files[len(parsed_colang_files)]
 
         with open(current_path, "r", encoding="utf-8") as f:
+            content = f.read()
             try:
-                content = f.read()
                 _parsed_config = parse_colang_file(
                     current_file, content=content, version=colang_version
                 )
@@ -1748,7 +1750,7 @@ def streaming_supported(self):
         # if we have output rails streaming enabled
         # we keep it in case it was needed when we have
         # support per rails
-        if self.rails.output.streaming.enabled:
+        if self.rails.output.streaming and self.rails.output.streaming.enabled:
             return True
         return False
diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py
index fb1bcdf19..fe56bcf08 100644
--- a/nemoguardrails/rails/llm/llmrails.py
+++ b/nemoguardrails/rails/llm/llmrails.py
@@ -24,7 +24,18 @@
 import threading
 import time
 from functools import partial
-from typing import Any, AsyncIterator, Dict, List, Optional, Tuple, Type, Union, cast
+from typing import (
+    Any,
+    AsyncIterator,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+    cast,
+)
 
 from langchain_core.language_models import BaseChatModel
 from langchain_core.language_models.llms import BaseLLM
@@ -69,7 +80,11 @@
 from nemoguardrails.logging.verbose import set_verbose
 from nemoguardrails.patch_asyncio import check_sync_call_from_async_loop
 from nemoguardrails.rails.llm.buffer import get_buffer_strategy
-from nemoguardrails.rails.llm.config import EmbeddingSearchProvider, RailsConfig
+from nemoguardrails.rails.llm.config import (
+    EmbeddingSearchProvider,
+    OutputRailsStreamingConfig,
+    RailsConfig,
+)
 from nemoguardrails.rails.llm.options import (
     GenerationLog,
     GenerationOptions,
@@ -205,17 +220,18 @@ def __init__(
         # We check if the configuration or any of the imported ones have config.py modules.
         config_modules = []
-        for _path in list(self.config.imported_paths.values()) + [
-            self.config.config_path
-        ]:
+        for _path in list(
+            self.config.imported_paths.values() if self.config.imported_paths else []
+        ) + [self.config.config_path]:
             if _path:
                 filepath = os.path.join(_path, "config.py")
                 if os.path.exists(filepath):
                     filename = os.path.basename(filepath)
                     spec = importlib.util.spec_from_file_location(filename, filepath)
-                    config_module = importlib.util.module_from_spec(spec)
-                    spec.loader.exec_module(config_module)
-                    config_modules.append(config_module)
+                    if spec and spec.loader:
+                        config_module = importlib.util.module_from_spec(spec)
+                        spec.loader.exec_module(config_module)
+                        config_modules.append(config_module)
 
         # First, we initialize the runtime.
         if config.colang_version == "1.0":
@@ -393,8 +409,8 @@ def _configure_main_llm_streaming(
         if not self.config.streaming:
             return
 
-        if "streaming" in llm.model_fields:
-            llm.streaming = True
+        if hasattr(llm, "streaming"):
+            setattr(llm, "streaming", True)
             self.main_llm_supports_streaming = True
         else:
             self.main_llm_supports_streaming = False
@@ -760,6 +776,19 @@ async def generate_async(
             The completion (when a prompt is provided) or the next message.
 
         System messages are not yet supported."""
+        # convert options to gen_options of type GenerationOptions
+        gen_options: Optional[GenerationOptions] = None
+
+        if prompt is None and messages is None:
+            raise ValueError("Either prompt or messages must be provided.")
+
+        if prompt is not None and messages is not None:
+            raise ValueError("Only one of prompt or messages can be provided.")
+
+        if prompt is not None:
+            # Currently, we transform the prompt request into a single turn conversation
+            messages = [{"role": "user", "content": prompt}]
+
         # If a state object is specified, then we switch to "generation options" mode.
         # This is because we want the output to be a GenerationResponse which will contain
         # the output state.
@@ -769,14 +798,25 @@ async def generate_async(
             state = json_to_state(state["state"])
 
             if options is None:
-                options = GenerationOptions()
-
-            # We allow options to be specified both as a dict and as an object.
-            if options and isinstance(options, dict):
-                options = GenerationOptions(**options)
+                gen_options = GenerationOptions()
+            elif isinstance(options, dict):
+                gen_options = GenerationOptions(**options)
+            else:
+                gen_options = options
+        else:
+            # We allow options to be specified both as a dict and as an object.
+            if options and isinstance(options, dict):
+                gen_options = GenerationOptions(**options)
+            elif isinstance(options, GenerationOptions):
+                gen_options = options
+            elif options is None:
+                gen_options = None
+            else:
+                raise TypeError("options must be a dict or GenerationOptions")
 
         # Save the generation options in the current async context.
-        generation_options_var.set(options)
+        # At this point, gen_options is either None or GenerationOptions
+        generation_options_var.set(gen_options)
 
         if streaming_handler:
             streaming_handler_var.set(streaming_handler)
@@ -786,26 +826,25 @@
         # requests are made.
         self.explain_info = self._ensure_explain_info()
 
-        if prompt is not None:
-            # Currently, we transform the prompt request into a single turn conversation
-            messages = [{"role": "user", "content": prompt}]
-            raw_llm_request.set(prompt)
-        else:
-            raw_llm_request.set(messages)
+        raw_llm_request.set(messages)
 
         # If we have generation options, we also add them to the context
-        if options:
+        if gen_options:
             messages = [
-                {"role": "context", "content": {"generation_options": options.dict()}}
-            ] + messages
+                {
+                    "role": "context",
+                    "content": {"generation_options": gen_options.model_dump()},
+                }
+            ] + (messages or [])
 
         # If the last message is from the assistant, rather than the user, then
         # we move that to the `$bot_message` variable. This is to enable a more
         # convenient interface. (only when dialog rails are disabled)
         if (
-            messages[-1]["role"] == "assistant"
-            and options
-            and options.rails.dialog is False
+            messages
+            and messages[-1]["role"] == "assistant"
+            and gen_options
+            and gen_options.rails.dialog is False
         ):
             # We already have the first message with a context update, so we use that
             messages[0]["content"]["bot_message"] = messages[-1]["content"]
@@ -822,7 +861,7 @@ async def generate_async(
         processing_log = []
 
         # The array of events corresponding to the provided sequence of messages.
-        events = self._get_events_for_messages(messages, state)
+        events = self._get_events_for_messages(messages, state)  # type: ignore
 
         if self.config.colang_version == "1.0":
             # If we had a state object, we also need to prepend the events from the state.
@@ -846,10 +885,10 @@
                     # Push an error chunk instead of None.
                     error_message = str(e)
                     error_dict = extract_error_json(error_message)
-                    error_payload = json.dumps(error_dict)
+                    error_payload: str = json.dumps(error_dict)
                     await streaming_handler.push_chunk(error_payload)
                     # push a termination signal
-                    await streaming_handler.push_chunk(END_OF_STREAM)
+                    await streaming_handler.push_chunk(END_OF_STREAM)  # type: ignore
                     # Re-raise the exact exception
                     raise
             else:
@@ -920,7 +959,7 @@ async def generate_async(
                 response_events.append(event)
 
         if exception:
-            new_message = {"role": "exception", "content": exception}
+            new_message: dict = {"role": "exception", "content": exception}
 
         else:
             # Ensure all items in responses are strings
@@ -928,7 +967,7 @@
                 str(response) if not isinstance(response, str) else response
                 for response in responses
             ]
-            new_message = {"role": "assistant", "content": "\n".join(responses)}
+            new_message: dict = {"role": "assistant", "content": "\n".join(responses)}
             if response_tool_calls:
                 new_message["tool_calls"] = response_tool_calls
             if response_events:
@@ -941,7 +980,7 @@
         # If a state object is not used, then we use the implicit caching
         if state is None:
             # Save the new events in the history and update the cache
-            cache_key = get_history_cache_key(messages + [new_message])
+            cache_key = get_history_cache_key((messages) + [new_message])  # type: ignore
             self.events_history_cache[cache_key] = events
         else:
             output_state = {"events": events}
@@ -964,38 +1003,34 @@
             streaming_handler = streaming_handler_var.get()
             if streaming_handler:
                 # print("Closing the stream handler explicitly")
-                await streaming_handler.push_chunk(END_OF_STREAM)
+                await streaming_handler.push_chunk(END_OF_STREAM)  # type: ignore
 
         # IF tracing is enabled we need to set GenerationLog attrs
         original_log_options = None
         if self.config.tracing.enabled:
-            if options is None:
-                options = GenerationOptions()
+            if gen_options is None:
+                gen_options = GenerationOptions()
             else:
-                # create a copy of the options to avoid modifying the original
-                if isinstance(options, GenerationOptions):
-                    options = options.model_copy(deep=True)
-                else:
-                    # If options is a dict, convert it to GenerationOptions
-                    options = GenerationOptions(**options)
-            original_log_options = options.log.model_copy(deep=True)
+                # create a copy of the gen_options to avoid modifying the original
+                gen_options = gen_options.model_copy(deep=True)
+            original_log_options = gen_options.log.model_copy(deep=True)
 
             # enable log options
             # it is aggressive, but these are required for tracing
             if (
-                not options.log.activated_rails
-                or not options.log.llm_calls
-                or not options.log.internal_events
+                not gen_options.log.activated_rails
+                or not gen_options.log.llm_calls
+                or not gen_options.log.internal_events
             ):
-                options.log.activated_rails = True
-                options.log.llm_calls = True
-                options.log.internal_events = True
+                gen_options.log.activated_rails = True
+                gen_options.log.llm_calls = True
+                gen_options.log.internal_events = True
 
         tool_calls = extract_tool_calls_from_events(new_events)
         llm_metadata = get_and_clear_response_metadata_contextvar()
 
         # If we have generation options, we prepare a GenerationResponse instance.
-        if options:
+        if gen_options:
             # If a prompt was used, we only need to return the content of the message.
             if prompt:
                 res = GenerationResponse(response=new_message["content"])
@@ -1004,11 +1039,15 @@
 
             if reasoning_trace := get_and_clear_reasoning_trace_contextvar():
                 if prompt:
-                    res.response = reasoning_trace + res.response
+                    # For prompt mode, response should be a string
+                    if isinstance(res.response, str):
+                        res.response = reasoning_trace + res.response
                 else:
-                    res.response[0]["content"] = (
-                        reasoning_trace + res.response[0]["content"]
-                    )
+                    # For message mode, response should be a list
+                    if isinstance(res.response, list) and len(res.response) > 0:
+                        res.response[0]["content"] = (
+                            reasoning_trace + res.response[0]["content"]
+                        )
 
             if tool_calls:
                 res.tool_calls = tool_calls
@@ -1018,13 +1057,12 @@
 
             if self.config.colang_version == "1.0":
                 # If output variables are specified, we extract their values
-                if options.output_vars:
+                if gen_options and gen_options.output_vars:
                     context = compute_context(events)
-                    if isinstance(options.output_vars, list):
+                    output_vars = gen_options.output_vars
+                    if isinstance(output_vars, list):
                         # If we have only a selection of keys, we filter to only that.
-                        res.output_data = {
-                            k: context.get(k) for k in options.output_vars
-                        }
+                        res.output_data = {k: context.get(k) for k in output_vars}
                     else:
                         # Otherwise, we return the full context
                         res.output_data = context
@@ -1032,37 +1070,40 @@
                 _log = compute_generation_log(processing_log)
 
                 # Include information about activated rails and LLM calls if requested
-                if options.log.activated_rails or options.log.llm_calls:
+                log_options = gen_options.log if gen_options else None
+                if log_options and (
+                    log_options.activated_rails or log_options.llm_calls
+                ):
                     res.log = GenerationLog()
 
                     # We always include the stats
                     res.log.stats = _log.stats
 
-                    if options.log.activated_rails:
+                    if log_options.activated_rails:
                         res.log.activated_rails = _log.activated_rails
 
-                    if options.log.llm_calls:
+                    if log_options.llm_calls:
                         res.log.llm_calls = []
                         for activated_rail in _log.activated_rails:
                             for executed_action in activated_rail.executed_actions:
                                 res.log.llm_calls.extend(executed_action.llm_calls)
 
                 # Include internal events if requested
-                if options.log.internal_events:
+                if log_options and log_options.internal_events:
                     if res.log is None:
                         res.log = GenerationLog()
                     res.log.internal_events = new_events
 
                 # Include the Colang history if requested
-                if options.log.colang_history:
+                if log_options and log_options.colang_history:
                     if res.log is None:
                         res.log = GenerationLog()
                     res.log.colang_history = get_colang_history(events)
 
                 # Include the raw llm output if requested
-                if options.llm_output:
+                if gen_options and gen_options.llm_output:
                     # Currently, we include the output from the generation LLM calls.
                     for activated_rail in _log.activated_rails:
                         if activated_rail.type == "generation":
@@ -1070,22 +1111,23 @@
                             for llm_call in executed_action.llm_calls:
                                 res.llm_output = llm_call.raw_response
             else:
-                if options.output_vars:
+                if gen_options and gen_options.output_vars:
                     raise ValueError(
                         "The `output_vars` option is not supported for Colang 2.0 configurations."
                     )
 
-                if (
-                    options.log.activated_rails
-                    or options.log.llm_calls
-                    or options.log.internal_events
-                    or options.log.colang_history
+                log_options = gen_options.log if gen_options else None
+                if log_options and (
+                    log_options.activated_rails
+                    or log_options.llm_calls
+                    or log_options.internal_events
+                    or log_options.colang_history
                 ):
                     raise ValueError(
                         "The `log` option is not supported for Colang 2.0 configurations."
                     )
 
-                if options.llm_output:
+                if gen_options and gen_options.llm_output:
                     raise ValueError(
                         "The `llm_output` option is not supported for Colang 2.0 configurations."
                     )
@@ -1127,12 +1169,14 @@ async def generate_async(
             ):
                 res.log = None
             else:
-                if not original_log_options.internal_events:
-                    res.log.internal_events = []
-                if not original_log_options.activated_rails:
-                    res.log.activated_rails = []
-                if not original_log_options.llm_calls:
-                    res.log.llm_calls = []
+                # Ensure res.log exists before setting attributes
+                if res.log is not None:
+                    if not original_log_options.internal_events:
+                        res.log.internal_events = []
+                    if not original_log_options.activated_rails:
+                        res.log.activated_rails = []
+                    if not original_log_options.llm_calls:
+                        res.log.llm_calls = []
 
             return res
         else:
@@ -1161,9 +1205,13 @@ def stream_async(
 
         # if an external generator is provided, use it directly
         if generator:
-            if self.config.rails.output.streaming.enabled:
+            if (
+                self.config.rails.output.streaming
+                and self.config.rails.output.streaming.enabled
+            ):
                 return self._run_output_rails_in_streaming(
                     streaming_handler=generator,
+                    output_rails_streaming_config=self.config.rails.output.streaming,
                     messages=messages,
                     prompt=prompt,
                 )
@@ -1194,7 +1242,7 @@ async def _generation_task():
                 error_dict = extract_error_json(error_message)
                 error_payload = json.dumps(error_dict)
                 await streaming_handler.push_chunk(error_payload)
-                await streaming_handler.push_chunk(END_OF_STREAM)
+                await streaming_handler.push_chunk(END_OF_STREAM)  # type: ignore
 
         task = asyncio.create_task(_generation_task())
 
@@ -1212,10 +1260,14 @@ def task_done_callback(task):
         # when we have output rails we wrap the streaming handler
         # if len(self.config.rails.output.flows) > 0:
         #
-        if self.config.rails.output.streaming.enabled:
+        if (
+            self.config.rails.output.streaming
+            and self.config.rails.output.streaming.enabled
+        ):
             # returns an async generator
             return self._run_output_rails_in_streaming(
                 streaming_handler=streaming_handler,
+                output_rails_streaming_config=self.config.rails.output.streaming,
                 messages=messages,
                 prompt=prompt,
             )
@@ -1367,7 +1419,7 @@ def process_events(
             self.process_events_async(events, state, blocking)
         )
 
-    def register_action(self, action: callable, name: Optional[str] = None) -> Self:
+    def register_action(self, action: Callable, name: Optional[str] = None) -> Self:
         """Register a custom action for the rails configuration."""
         self.runtime.register_action(action, name)
         return self
@@ -1377,12 +1429,12 @@ def register_action_param(self, name: str, value: Any) -> Self:
         self.runtime.register_action_param(name, value)
         return self
 
-    def register_filter(self, filter_fn: callable, name: Optional[str] = None) -> Self:
+    def register_filter(self, filter_fn: Callable, name: Optional[str] = None) -> Self:
        """Register a custom filter for the rails configuration."""
         self.runtime.llm_task_manager.register_filter(filter_fn, name)
         return self
 
-    def register_output_parser(self, output_parser: callable, name: str) -> Self:
+    def register_output_parser(self, output_parser: Callable, name: str) -> Self:
         """Register a custom output parser for the rails configuration."""
         self.runtime.llm_task_manager.register_output_parser(output_parser, name)
         return self
@@ -1427,6 +1479,8 @@ def register_embedding_provider(
 
     def explain(self) -> ExplainInfo:
         """Helper function to return the latest ExplainInfo object."""
+        if self.explain_info is None:
+            self.explain_info = self._ensure_explain_info()
         return self.explain_info
 
     def __getstate__(self):
@@ -1442,6 +1496,7 @@ def __setstate__(self, state):
     async def _run_output_rails_in_streaming(
         self,
         streaming_handler: AsyncIterator[str],
+        output_rails_streaming_config: OutputRailsStreamingConfig,
         prompt: Optional[str] = None,
         messages: Optional[List[dict]] = None,
         stream_first: Optional[bool] = None,
@@ -1544,7 +1599,6 @@ def _prepare_params(
                 **action_params,
             }
 
-        output_rails_streaming_config = self.config.rails.output.streaming
         buffer_strategy = get_buffer_strategy(output_rails_streaming_config)
         output_rails_flows_id = self.config.rails.output.flows
         stream_first = stream_first or output_rails_streaming_config.stream_first
@@ -1619,9 +1673,10 @@ def _prepare_params(
                 pass
             else:
                 # if there are any stop events, content was blocked or internal error occurred
-                if result.events:
+                result_events = getattr(result, "events", None)
+                if result_events:
                     # extract the flow info from the first stop event
-                    stop_event = result.events[0]
+                    stop_event = result_events[0]
                     blocked_flow = stop_event.get("flow_id", "output rails")
                     error_type = stop_event.get("error_type")
diff --git a/nemoguardrails/rails/llm/options.py b/nemoguardrails/rails/llm/options.py
index dd9f87099..67bf9c76a 100644
--- a/nemoguardrails/rails/llm/options.py
+++ b/nemoguardrails/rails/llm/options.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-""" Generation options give more control over the generation and the result.
+"""Generation options give more control over the generation and the result.
 
 For example, to run only the input rails::
 
@@ -76,6 +76,7 @@
 #   {..., log: {"llm_calls": [...]}}
 """
+
 from typing import Any, Dict, List, Optional, Union
 
 from pydantic import BaseModel, Field, root_validator
@@ -156,7 +157,7 @@ class GenerationOptions(BaseModel):
         default=None,
         description="Additional parameters that should be used for the LLM call",
     )
-    llm_output: Optional[bool] = Field(
+    llm_output: bool = Field(
         default=False,
         description="Whether the response should also include any custom LLM output.",
     )
@@ -233,7 +234,7 @@ class ActivatedRail(BaseModel):
     )
     decisions: List[str] = Field(
         default_factory=list,
-        descriptino="A sequence of decisions made by the rail, e.g., 'bot refuse to respond', 'stop', 'continue'.",
+        description="A sequence of decisions made by the rail, e.g., 'bot refuse to respond', 'stop', 'continue'.",
     )
     executed_actions: List[ExecutedAction] = Field(
         default_factory=list, description="The list of actions executed by the rail."
@@ -327,7 +328,7 @@ def print_summary(self):
         duration = 0
 
         print(f"- Total time: {self.stats.total_duration:.2f}s")
-        if self.stats.input_rails_duration:
+        if self.stats.input_rails_duration and self.stats.total_duration:
             _pc = round(
                 100 * self.stats.input_rails_duration / self.stats.total_duration, 2
             )
@@ -335,7 +336,7 @@ def print_summary(self):
             duration += self.stats.input_rails_duration
             print(f" - [{self.stats.input_rails_duration:.2f}s][{_pc}%]: INPUT Rails")
 
-        if self.stats.dialog_rails_duration:
+        if self.stats.dialog_rails_duration and self.stats.total_duration:
             _pc = round(
                 100 * self.stats.dialog_rails_duration / self.stats.total_duration, 2
             )
@@ -345,7 +346,7 @@ def print_summary(self):
             print(
                 f" - [{self.stats.dialog_rails_duration:.2f}s][{_pc}%]: DIALOG Rails"
             )
-        if self.stats.generation_rails_duration:
+        if self.stats.generation_rails_duration and self.stats.total_duration:
             _pc = round(
                 100 * self.stats.generation_rails_duration / self.stats.total_duration,
                 2,
            )
@@ -356,7 +357,7 @@ def print_summary(self):
             print(
                 f" - [{self.stats.generation_rails_duration:.2f}s][{_pc}%]: GENERATION Rails"
             )
-        if self.stats.output_rails_duration:
+        if self.stats.output_rails_duration and self.stats.total_duration:
             _pc = round(
                 100 * self.stats.output_rails_duration / self.stats.total_duration, 2
             )
@@ -367,12 +368,12 @@ def print_summary(self):
                 f" - [{self.stats.output_rails_duration:.2f}s][{_pc}%]: OUTPUT Rails"
             )
 
-        processing_overhead = self.stats.total_duration - duration
+        processing_overhead = (self.stats.total_duration or 0) - duration
         if processing_overhead >= 0.01:
             _pc = round(100 - pc, 2)
             print(f" - [{processing_overhead:.2f}s][{_pc}%]: Processing overhead ")
 
-        if self.stats.llm_calls_count > 0:
+        if self.stats.llm_calls_count:
             print(
                 f"- {self.stats.llm_calls_count} LLM calls, "
                 f"{self.stats.llm_calls_duration:.2f}s total duration, "
@@ -391,7 +392,10 @@ def print_summary(self):
             for action in activated_rail.executed_actions:
                 llm_calls_count += len(action.llm_calls)
                 llm_calls_durations.extend(
-                    [f"{round(llm_call.duration, 2)}s" for llm_call in action.llm_calls]
+                    [
+                        f"{round(llm_call.duration or 0, 2)}s"
+                        for llm_call in action.llm_calls
+                    ]
                 )
             print(
                 f"- [{activated_rail.duration:.2f}s] {activated_rail.type.upper()} ({activated_rail.name}): "
@@ -431,4 +435,6 @@ class GenerationResponse(BaseModel):
 
 
 if __name__ == "__main__":
-    print(GenerationOptions(**{"rails": {"input": False}}))
+    print(
+        GenerationOptions(rails=GenerationRailsOptions(input=False))
+    )  # pragma: no cover (Can't run as script for test coverage)
diff --git a/nemoguardrails/utils.py b/nemoguardrails/utils.py
index a337a978f..bc27a6c74 100644
--- a/nemoguardrails/utils.py
+++ b/nemoguardrails/utils.py
@@ -375,7 +375,7 @@ def get_railsignore_patterns(railsignore_path: Path) -> Set[str]:
     return ignored_patterns
 
 
-def is_ignored_by_railsignore(filename: str, ignore_patterns: str) -> bool:
+def is_ignored_by_railsignore(filename: str, ignore_patterns: Set[str]) -> bool:
     """Verify if a filename should be ignored by a railsignore pattern"""
     ignore = False
 
diff --git a/poetry.lock b/poetry.lock
index eeced1c7e..b5eedf3d0 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -4298,6 +4298,26 @@ files = [
 [package.extras]
 dev = ["build", "flake8", "mypy", "pytest", "twine"]
 
+[[package]]
+name = "pyright"
+version = "1.1.405"
+description = "Command line wrapper for pyright"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "pyright-1.1.405-py3-none-any.whl", hash = "sha256:a2cb13700b5508ce8e5d4546034cb7ea4aedb60215c6c33f56cec7f53996035a"},
+    {file = "pyright-1.1.405.tar.gz", hash = "sha256:5c2a30e1037af27eb463a1cc0b9f6d65fec48478ccf092c1ac28385a15c55763"},
+]
+
+[package.dependencies]
+nodeenv = ">=1.6.0"
+typing-extensions = ">=4.1"
+
+[package.extras]
+all = ["nodejs-wheel-binaries", "twine (>=3.4.1)"]
+dev = ["twine (>=3.4.1)"]
+nodejs = ["nodejs-wheel-binaries"]
+
 [[package]]
 name = "pytest"
 version = "8.4.1"
@@ -6448,4 +6468,4 @@ tracing = ["aiofiles", "opentelemetry-api"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9,!=3.9.7,<3.14"
-content-hash = "6654d6115d5142024695ff1a736cc3d133842421b1282f5c3ba413b6a0250118"
+content-hash = "313705d475a9cb177efa633c193da9315388aa99832b9c5b429fafb5b3da44b0"
diff --git a/pyproject.toml b/pyproject.toml
index 418616a58..86aa7fcad 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -151,7 +151,11 @@ pytest-profiling = "^1.7.0"
 yara-python = "^4.5.1"
 opentelemetry-api = "^1.34.1"
 opentelemetry-sdk = "^1.34.1"
+pyright = "^1.1.405"
 
+# Directories in which to run Pyright type-checking
+[tool.pyright]
+include = ["nemoguardrails/rails/**"]
 
 [tool.poetry.group.docs]
 optional = true
diff --git a/tests/rails/llm/test_config.py b/tests/rails/llm/test_config.py
index 7b4a3cfe1..f79dbc0ad 100644
--- a/tests/rails/llm/test_config.py
+++ b/tests/rails/llm/test_config.py
@@ -13,16 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
+from unittest.mock import MagicMock
+
 import pytest
+from langchain.llms.base import BaseLLM
 from pydantic import ValidationError
 
-from nemoguardrails.rails.llm.config import (
-    Document,
-    Instruction,
-    Model,
-    RailsConfig,
-    TaskPrompt,
-)
+from nemoguardrails.rails.llm.config import Model, RailsConfig, TaskPrompt
+from nemoguardrails.rails.llm.llmrails import LLMRails
 
 
 def test_task_prompt_valid_content():
@@ -307,3 +306,76 @@ def test_rails_config_none_config_path():
     result2 = config3 + config4
 
     assert result2.config_path == ""
+
+
+def test_llm_rails_configure_streaming_with_attr():
+    """Check the LLM's `streaming` attribute is set when streaming is enabled in RailsConfig"""
+
+    mock_llm = MagicMock(spec=BaseLLM)
+    config = RailsConfig(
+        models=[],
+        streaming=True,
+    )
+
+    rails = LLMRails(config, llm=mock_llm)
+    setattr(mock_llm, "streaming", None)
+    rails._configure_main_llm_streaming(llm=mock_llm)
+
+    assert mock_llm.streaming
+
+
+def test_llm_rails_configure_streaming_without_attr(caplog):
+    """Check a warning is logged when the main LLM does not support streaming"""
+
+    mock_llm = MagicMock(spec=BaseLLM)
+    config = RailsConfig(
+        models=[],
+        streaming=True,
+    )
+
+    rails = LLMRails(config, llm=mock_llm)
+    rails._configure_main_llm_streaming(mock_llm)
+
+    assert caplog.messages[-1] == "Provided main LLM does not support streaming."
+
+
+def test_rails_config_streaming_supported_no_output_flows():
+    """Check the `streaming_supported` property doesn't depend on RailsConfig.streaming when there are no output flows"""
+
+    config = RailsConfig(
+        models=[],
+        streaming=False,
+    )
+    assert config.streaming_supported
+
+
+def test_rails_config_flows_streaming_supported_true():
+    """Create a RailsConfig and check the `streaming_supported` property is True when output rails streaming is enabled"""
+
+    rails = {
+        "output": {
+            "flows": ["content_safety_check_output"],
+            "streaming": {"enabled": True},
+        }
+    }
+    prompts = [{"task": "content safety check output", "content": "..."}]
+    rails_config = RailsConfig.model_validate(
+        {"models": [], "rails": rails, "prompts": prompts}
+    )
+    assert rails_config.streaming_supported
+
+
+def test_rails_config_flows_streaming_supported_false():
+    """Create a RailsConfig and check the `streaming_supported` property is False when output rails streaming is disabled"""
+
+    rails = {
+        "output": {
+            "flows": ["content_safety_check_output"],
+            "streaming": {"enabled": False},
+        }
+    }
+    prompts = [{"task": "content safety check output", "content": "..."}]
+    rails_config = RailsConfig.model_validate(
+        {"models": [], "rails": rails, "prompts": prompts}
+    )
+    assert not rails_config.streaming_supported
diff --git a/tests/test_generation_options.py b/tests/test_generation_options.py
index 06895aa87..a8aeff02b 100644
--- a/tests/test_generation_options.py
+++ b/tests/test_generation_options.py
@@ -18,7 +18,11 @@
 import pytest
 
 from nemoguardrails import LLMRails, RailsConfig
-from nemoguardrails.rails.llm.options import GenerationResponse
+from nemoguardrails.rails.llm.options import (
+    GenerationLog,
+    GenerationResponse,
+    GenerationStats,
+)
 from tests.utils import TestChat
 
 
@@ -313,3 +317,38 @@ def test_only_input_output_validation():
     assert res.response == [
         {"content": "I'm sorry, I can't respond to that.", "role": "assistant"}
     ]
+
+
+def test_generation_log_print_summary(capsys):
+    """Test printing rails stats with dummy data"""
+
+    stats = GenerationStats(
+        input_rails_duration=1.0,
+        dialog_rails_duration=2.0,
+        generation_rails_duration=3.0,
+        output_rails_duration=4.0,
+        total_duration=10.0,  # Sum of all previous rail durations
+        llm_calls_duration=8.0,  # Less than total duration
+        llm_calls_count=4,  # Input, dialog, generation and output calls
+        llm_calls_total_prompt_tokens=1000,
+        llm_calls_total_completion_tokens=2000,
+        llm_calls_total_tokens=3000,  # Sum of prompt and completion tokens
+    )
+
+    generation_log = GenerationLog(activated_rails=[], stats=stats)
+
+    generation_log.print_summary()
+    capture = capsys.readouterr()
+    capture_lines = capture.out.splitlines()
+
+    # Check the correct times were printed
+    assert capture_lines[1] == "# General stats"
+    assert capture_lines[3] == "- Total time: 10.00s"
+    assert capture_lines[4] == " - [1.00s][10.0%]: INPUT Rails"
+    assert capture_lines[5] == " - [2.00s][20.0%]: DIALOG Rails"
+    assert capture_lines[6] == " - [3.00s][30.0%]: GENERATION Rails"
+    assert capture_lines[7] == " - [4.00s][40.0%]: OUTPUT Rails"
+    assert (
+        capture_lines[8]
+        == "- 4 LLM calls, 8.00s total duration, 1000 total prompt tokens, 2000 total completion tokens, 3000 total tokens."
+    )
diff --git a/tests/test_llmrails.py b/tests/test_llmrails.py
index f97389284..9b8a2b300 100644
--- a/tests/test_llmrails.py
+++ b/tests/test_llmrails.py
@@ -15,11 +15,13 @@
 import os
 from typing import Any, Dict, List, Optional, Union
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
 
 import pytest
+from langchain_core.language_models import BaseChatModel
 
 from nemoguardrails import LLMRails, RailsConfig
+from nemoguardrails.logging.explain import ExplainInfo
 from nemoguardrails.rails.llm.config import Model
 from nemoguardrails.rails.llm.llmrails import get_action_details_from_flow_id
 from tests.utils import FakeLLM, clean_events, event_sequence_conforms
@@ -1170,3 +1172,18 @@ def dummy_parser(text):
     assert "chained_action" in rails.runtime.action_dispatcher.registered_actions
     assert "chained_param" in rails.runtime.registered_action_params
     assert rails.runtime.registered_action_params["chained_param"] == "param_value"
+
+
+def test_explain_calls_ensure_explain_info():
+    """Make sure if no `explain_info` attribute is present in LLMRails it's populated with
+    an empty ExplainInfo object"""
+
+    mock_llm = MagicMock(spec=BaseChatModel)
+    config = RailsConfig.from_content(config={"models": []})
+    rails = LLMRails(config=config, llm=mock_llm)
+    rails.generate(messages=[{"role": "user", "content": "Hi!"}])
+
+    rails.explain_info = None
+    info = rails.explain()
+    assert info == ExplainInfo()
+    assert rails.explain_info == ExplainInfo()
diff --git a/tests/test_retrieve_relevant_chunks.py b/tests/test_retrieve_relevant_chunks.py
index 7d1044661..72258ef48 100644
--- a/tests/test_retrieve_relevant_chunks.py
+++ b/tests/test_retrieve_relevant_chunks.py
@@ -12,15 +12,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from unittest.mock import MagicMock
+from unittest.mock import AsyncMock, MagicMock
 
 import pytest
+from langchain_core.language_models import BaseChatModel
 
 from nemoguardrails import LLMRails, RailsConfig
 from nemoguardrails.kb.kb import KnowledgeBase
 from tests.utils import TestChat
 
-config = RailsConfig.from_content(
+RAILS_CONFIG = RailsConfig.from_content(
     """
     import llm
     import core
@@ -55,7 +56,7 @@ def test_relevant_chunk_inserted_in_prompt():
     ]
 
     chat = TestChat(
-        config,
+        RAILS_CONFIG,
         llm_completions=[
             " user express greeting",
             ' bot respond to aditional context\nbot action: "Hello is there anything else" ',
@@ -70,19 +71,21 @@
         {"role": "user", "content": "Hi!"},
     ]
 
-    new_message = rails.generate(messages=messages)
+    before_llm_calls = len(rails.explain().llm_calls)
+    _ = rails.generate(messages=messages)
+    after_llm_calls = len(rails.explain().llm_calls)
+    llm_call_count = after_llm_calls - before_llm_calls
     info = rails.explain()
 
-    assert len(info.llm_calls) == 2
-    assert "Test Body" in info.llm_calls[1].prompt
-
-    assert "markdown" in info.llm_calls[1].prompt
-    assert "context" in info.llm_calls[1].prompt
+    assert llm_call_count == 2
+    assert "Test Body" in info.llm_calls[-1].prompt
+    assert "markdown" in info.llm_calls[-1].prompt
+    assert "context" in info.llm_calls[-1].prompt
 
 
 def test_relevant_chunk_inserted_in_prompt_no_kb():
     chat = TestChat(
-        config,
+        RAILS_CONFIG,
         llm_completions=[
             " user express greeting",
             ' bot respond to aditional context\nbot action: "Hello is there anything else" ',
@@ -92,8 +95,13 @@
     messages = [
         {"role": "user", "content": "Hi!"},
     ]
-    new_message = rails.generate(messages=messages)
+
+    before_llm_calls = len(rails.explain().llm_calls)
+    _ = rails.generate(messages=messages)
+    after_llm_calls = len(rails.explain().llm_calls)
+    llm_call_count = after_llm_calls - before_llm_calls
+
     info = rails.explain()
-    assert len(info.llm_calls) == 2
+    assert llm_call_count == 2
     assert "markdown" not in info.llm_calls[1].prompt
     assert "context" not in info.llm_calls[1].prompt