diff --git a/nemoguardrails/actions/llm/generation.py b/nemoguardrails/actions/llm/generation.py index 276b83630..6d75a6e88 100644 --- a/nemoguardrails/actions/llm/generation.py +++ b/nemoguardrails/actions/llm/generation.py @@ -597,7 +597,7 @@ async def _search_flows_index(self, text, max_results): return final_results[0:max_results] @action(is_system_action=True) - async def generate_next_step(self, events: List[dict], llm: Optional[BaseLLM] = None): + async def generate_next_steps(self, events: List[dict], llm: Optional[BaseLLM] = None): """Generate the next step in the current conversation flow. Currently, only generates a next step after a user intent. diff --git a/nemoguardrails/llm/filters.py b/nemoguardrails/llm/filters.py index bd464e523..1c7fae0fa 100644 --- a/nemoguardrails/llm/filters.py +++ b/nemoguardrails/llm/filters.py @@ -81,7 +81,7 @@ def co_v2( "GenerateFlowFromInstructionsAction", "GenerateFlowFromNameAction", "generate_intent_steps_message", - "generate_next_step", + "generate_next_steps", "GenerateUserIntentAction", "GenerateValueAction", "GetLastUserMessageAction", diff --git a/nemoguardrails/rails/llm/llm_flows.co b/nemoguardrails/rails/llm/llm_flows.co index 1753a3727..5237e56d6 100644 --- a/nemoguardrails/rails/llm/llm_flows.co +++ b/nemoguardrails/rails/llm/llm_flows.co @@ -78,7 +78,7 @@ define flow generate next step priority 0.9 user ... - execute generate_next_step + execute generate_next_steps define parallel extension flow generate bot message diff --git a/tests/test_configs/with_actions_override/config.py b/tests/test_configs/with_actions_override/config.py index bca0a04f1..30cd4f9f3 100644 --- a/tests/test_configs/with_actions_override/config.py +++ b/tests/test_configs/with_actions_override/config.py @@ -25,7 +25,7 @@ async def generate_user_intent(): @action(is_system_action=True) -async def generate_next_step(): +async def generate_next_steps(): return ActionResult(events=[{"type": "BotIntent", "intent": "respond to question"}]) @@ -38,5 +38,5 @@ async def generate_bot_message(): def init(app: LLMRails): app.register_action(generate_user_intent) - app.register_action(generate_next_step) + app.register_action(generate_next_steps) app.register_action(generate_bot_message) diff --git a/tests/test_general_instructions.py b/tests/test_general_instructions.py index f8f469613..07867b313 100644 --- a/tests/test_general_instructions.py +++ b/tests/test_general_instructions.py @@ -158,7 +158,7 @@ async def test_search_flows_index_is_none(): @pytest.mark.asyncio -async def test_generate_next_step_empty_event_list(): +async def test_generate_next_steps_empty_event_list(): """Check if we try and search the flows index when None we get None back""" config = RailsConfig( @@ -175,7 +175,7 @@ async def test_generate_next_step_empty_event_list(): ) with pytest.raises(RuntimeError, match="No last user intent found from which to generate next step"): - _ = await actions.generate_next_step(events=[]) + _ = await actions.generate_next_steps(events=[]) # diff --git a/tests/test_task_specific_model.py b/tests/test_task_specific_model.py new file mode 100644 index 000000000..dc4e47732 --- /dev/null +++ b/tests/test_task_specific_model.py @@ -0,0 +1,92 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pytest + +from nemoguardrails import LLMRails, RailsConfig +from nemoguardrails.imports import check_optional_dependency +from nemoguardrails.rails.llm.options import GenerationResponse + +has_langchain_openai = check_optional_dependency("langchain_openai") + +has_openai_key = bool(os.getenv("OPENAI_API_KEY")) + +skip_if_no_openai = pytest.mark.skipif( + not (has_langchain_openai and has_openai_key), + reason="Requires langchain_openai and OPENAI_API_KEY environment variable", +) + + +@skip_if_no_openai +def test_task_specific_model_for_generate_user_intent_and_generate_next_steps(): + config = RailsConfig.from_content( + colang_content=""" + define user express greeting + "hi" + + define flow + user express greeting + bot express greeting + + define bot express greeting + "Hello! How can I assist you today?" + """, + yaml_content=""" + models: + - type: main + engine: openai + model: gpt-3.5-turbo-instruct + + - type: generate_user_intent + engine: openai + model: gpt-4o-mini + + - type: generate_next_steps + engine: openai + model: gpt-4o-mini + """, + ) + + rails = LLMRails(config) + + res = rails.generate( + messages=[{"role": "user", "content": "what can you do?"}], + options={"log": {"llm_calls": True}}, + ) + + assert isinstance(res, GenerationResponse) + assert res.log is not None + assert res.log.llm_calls is not None + assert len(res.log.llm_calls) > 0 + + task_specific_tasks = ["generate_user_intent", "generate_next_steps"] + + generate_user_intent_calls = [call for call in res.log.llm_calls if call.task == "generate_user_intent"] + assert len(generate_user_intent_calls) > 0 + for call in generate_user_intent_calls: + assert call.llm_model_name == "gpt-4o-mini" + assert call.llm_provider_name == "openai" + + generate_next_steps_calls = [call for call in res.log.llm_calls if call.task == "generate_next_steps"] + assert len(generate_next_steps_calls) > 0 + for call in generate_next_steps_calls: + assert call.llm_model_name == "gpt-4o-mini" + assert call.llm_provider_name == "openai" + + other_calls = [call for call in res.log.llm_calls if call.task not in task_specific_tasks] + for call in other_calls: + assert call.llm_model_name == "gpt-3.5-turbo-instruct" diff --git a/tests/tracing/spans/test_span_v2_otel_semantics.py b/tests/tracing/spans/test_span_v2_otel_semantics.py index 117b06074..21a5dfa2d 100644 --- a/tests/tracing/spans/test_span_v2_otel_semantics.py +++ b/tests/tracing/spans/test_span_v2_otel_semantics.py @@ -121,7 +121,7 @@ def test_llm_span_has_complete_attributes(self): rail = ActivatedRail( type="dialog", - name="generate_next_step", + name="generate_next_steps", started_at=1.0, finished_at=2.0, duration=1.0,