Skip to content

Commit 912cdea

Browse files
authored
chore(llm): deprecate nemoguardrails.llm.params module in favor of direct parameter passing (#1471)
1 parent aafd733 commit 912cdea

File tree

3 files changed

+93
-3
lines changed

3 files changed

+93
-3
lines changed

nemoguardrails/llm/params.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,48 @@
1717
Module for providing a context manager to temporarily adjust parameters of a language model.
1818
1919
Also allows registration of custom parameter managers for different language model types.
20+
21+
.. deprecated:: 0.17.0
22+
This module is deprecated and will be removed in version 0.19.0.
23+
Instead of using the context manager approach, pass parameters directly to `llm_call()`
24+
using the `llm_params` argument:
25+
26+
Old way (deprecated):
27+
from nemoguardrails.llm.params import llm_params
28+
with llm_params(llm, temperature=0.7):
29+
result = await llm_call(llm, prompt)
30+
31+
New way (recommended):
32+
result = await llm_call(llm, prompt, llm_params={"temperature": 0.7})
33+
34+
See: https://github.com/NVIDIA/NeMo-Guardrails/issues/1387
2035
"""
2136

2237
import logging
38+
import warnings
2339
from typing import Dict, Type
2440

2541
from langchain.base_language import BaseLanguageModel
2642

2743
log = logging.getLogger(__name__)
2844

45+
_DEPRECATION_MESSAGE = (
46+
"The nemoguardrails.llm.params module is deprecated and will be removed in version 0.19.0. "
47+
"Instead of using llm_params context manager, pass parameters directly to llm_call() "
48+
"using the llm_params argument. "
49+
"See: https://github.com/NVIDIA/NeMo-Guardrails/issues/1387"
50+
)
51+
2952

3053
class LLMParams:
31-
"""Context manager to temporarily modify the parameters of a language model."""
54+
"""Context manager to temporarily modify the parameters of a language model.
55+
56+
.. deprecated:: 0.17.0
57+
Use llm_call() with llm_params argument instead.
58+
"""
3259

3360
def __init__(self, llm: BaseLanguageModel, **kwargs):
61+
warnings.warn(_DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=2)
3462
self.llm = llm
3563
self.altered_params = kwargs
3664
self.original_params = {}
@@ -82,13 +110,22 @@ def __exit__(self, type, value, traceback):
82110

83111

84112
def register_param_manager(llm_type: Type[BaseLanguageModel], manager: Type[LLMParams]):
85-
"""Register a parameter manager."""
113+
"""Register a parameter manager.
114+
115+
.. deprecated:: 0.17.0
116+
This function is deprecated and will be removed in version 0.19.0.
117+
"""
118+
warnings.warn(_DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=2)
86119
_param_managers[llm_type] = manager
87120

88121

89122
def llm_params(llm: BaseLanguageModel, **kwargs):
90-
"""Returns a parameter manager for the given language model."""
123+
"""Returns a parameter manager for the given language model.
91124
125+
.. deprecated:: 0.17.0
126+
Use llm_call() with llm_params argument instead.
127+
"""
128+
warnings.warn(_DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=2)
92129
_llm_params = _param_managers.get(llm.__class__, LLMParams)
93130

94131
return _llm_params(llm, **kwargs)

tests/test_llm_params.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515

1616
import unittest
17+
import warnings
1718
from typing import Any, Dict
1819
from unittest.mock import AsyncMock, MagicMock
1920

@@ -224,6 +225,57 @@ class UnregisteredLLM(BaseModel):
224225
self.assertIsInstance(llm_params(UnregisteredLLM()), LLMParams)
225226

226227

228+
class TestLLMParamsDeprecation(unittest.TestCase):
229+
"""Test deprecation warnings for llm_params module."""
230+
231+
def test_llm_params_function_raises_deprecation_warning(self):
232+
"""Test that llm_params function raises DeprecationWarning."""
233+
llm = FakeLLM(param3="value3", model_kwargs={"param1": "value1"})
234+
235+
with warnings.catch_warnings(record=True) as w:
236+
warnings.simplefilter("always")
237+
with llm_params(llm, param1="new_value1"):
238+
pass
239+
240+
self.assertGreaterEqual(len(w), 1)
241+
self.assertTrue(
242+
any(issubclass(warning.category, DeprecationWarning) for warning in w)
243+
)
244+
self.assertTrue(
245+
any(
246+
"0.19.0" in str(warning.message)
247+
and "llm_call()" in str(warning.message)
248+
for warning in w
249+
)
250+
)
251+
252+
def test_llm_params_class_raises_deprecation_warning(self):
253+
"""Test that LLMParams class raises DeprecationWarning."""
254+
llm = FakeLLM(param3="value3", model_kwargs={"param1": "value1"})
255+
256+
with warnings.catch_warnings(record=True) as w:
257+
warnings.simplefilter("always")
258+
params = LLMParams(llm, param1="new_value1")
259+
260+
self.assertGreaterEqual(len(w), 1)
261+
self.assertTrue(issubclass(w[0].category, DeprecationWarning))
262+
self.assertIn("0.19.0", str(w[0].message))
263+
264+
def test_register_param_manager_raises_deprecation_warning(self):
265+
"""Test that register_param_manager function raises DeprecationWarning."""
266+
267+
class CustomLLMParams(LLMParams):
268+
pass
269+
270+
with warnings.catch_warnings(record=True) as w:
271+
warnings.simplefilter("always")
272+
register_param_manager(FakeLLM, CustomLLMParams)
273+
274+
self.assertGreaterEqual(len(w), 1)
275+
self.assertTrue(issubclass(w[0].category, DeprecationWarning))
276+
self.assertIn("0.19.0", str(w[0].message))
277+
278+
227279
class TestLLMParamsMigration(unittest.TestCase):
228280
"""Test migration from context manager to direct parameter passing."""
229281

tests/test_llm_params_e2e.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import os
1919
import tempfile
20+
import warnings
2021
from pathlib import Path
2122

2223
import pytest

0 commit comments

Comments
 (0)