Commit fab35c0

feat(cache): add caching support for topic safety and content safety output checks
Extends the LLM caching system to support topic safety input checks and content safety output checks. Both actions now cache their results along with LLM stats and metadata to improve performance on repeated queries.

Changes:
- Added caching support to topic_safety_check_input() with cache hit/miss logic
- Added caching support to content_safety_check_output() with cache hit/miss logic
- Both actions now extract and store LLM metadata alongside stats in cache entries
- Added model_caches parameter to both actions for optional cache injection
- Comprehensive test coverage for both new caching implementations
- Tests verify cache hits, stats restoration, and metadata handling
- pre-commits
1 parent 8e0f7a0 commit fab35c0
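Both actions follow the same hit/miss flow: build a normalized key from the rendered prompt (content safety) or message list (topic safety), return a cached result early on a hit, and otherwise cache the fresh result together with LLM stats and metadata. Below is a minimal sketch of that flow, distilled from the diffs in this commit; the helper names come from nemoguardrails.llm.cache.utils as used here, while run_llm_check is a placeholder standing in for the real llm_call plus output parsing.

from typing import Dict, Optional

from nemoguardrails.llm.cache import CacheInterface
from nemoguardrails.llm.cache.utils import (
    CacheEntry,
    create_normalized_cache_key,
    extract_llm_metadata_for_cache,
    extract_llm_stats_for_cache,
    get_from_cache_and_restore_stats,
)


async def run_llm_check(prompt: str) -> dict:
    # Placeholder for the real LLM call and result parsing inside each action.
    return {"allowed": True, "policy_violations": []}


async def cached_safety_check(
    prompt: str,
    model_name: str,
    model_caches: Optional[Dict[str, CacheInterface]] = None,
) -> dict:
    # Caching is opt-in: only active when a cache was injected for this model.
    cache = model_caches.get(model_name) if model_caches else None

    if cache:
        cache_key = create_normalized_cache_key(prompt)
        cached_result = get_from_cache_and_restore_stats(cache, cache_key)
        if cached_result is not None:
            # Cache hit: the helper also restores the stored LLM stats.
            return cached_result

    final_result = await run_llm_check(prompt)

    if cache:
        # Cache miss: store the result together with LLM stats and metadata.
        cache_entry: CacheEntry = {
            "result": final_result,
            "llm_stats": extract_llm_stats_for_cache(),
            "llm_metadata": extract_llm_metadata_for_cache(),
        }
        cache.put(create_normalized_cache_key(prompt), cache_entry)

    return final_result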

File tree: 4 files changed, +246 -4 lines changed

- nemoguardrails/library/content_safety/actions.py
- nemoguardrails/library/topic_safety/actions.py
- tests/test_content_safety_cache.py
- tests/test_topic_safety_cache.py


nemoguardrails/library/content_safety/actions.py

Lines changed: 27 additions & 2 deletions
@@ -25,6 +25,7 @@
 from nemoguardrails.llm.cache.utils import (
     CacheEntry,
     create_normalized_cache_key,
+    extract_llm_metadata_for_cache,
     extract_llm_stats_for_cache,
     get_from_cache_and_restore_stats,
 )
@@ -110,6 +111,7 @@ async def content_safety_check_input(
         cache_entry: CacheEntry = {
             "result": final_result,
             "llm_stats": extract_llm_stats_for_cache(),
+            "llm_metadata": extract_llm_metadata_for_cache(),
         }
         cache.put(cache_key, cache_entry)
         log.debug(f"Content safety result cached for model '{model_name}'")
@@ -139,6 +141,7 @@ async def content_safety_check_output(
     llm_task_manager: LLMTaskManager,
     model_name: Optional[str] = None,
     context: Optional[dict] = None,
+    model_caches: Optional[Dict[str, CacheInterface]] = None,
     **kwargs,
 ) -> dict:
     _MAX_TOKENS = 3
@@ -176,12 +179,22 @@ async def content_safety_check_output(
             "bot_response": bot_response,
         },
     )
+
     stop = llm_task_manager.get_stop_tokens(task=task)
     max_tokens = llm_task_manager.get_max_tokens(task=task)

+    llm_call_info_var.set(LLMCallInfo(task=task))
+
     max_tokens = max_tokens or _MAX_TOKENS

-    llm_call_info_var.set(LLMCallInfo(task=task))
+    cache = model_caches.get(model_name) if model_caches else None
+
+    if cache:
+        cache_key = create_normalized_cache_key(check_output_prompt)
+        cached_result = get_from_cache_and_restore_stats(cache, cache_key)
+        if cached_result is not None:
+            log.debug(f"Content safety output cache hit for model '{model_name}'")
+            return cached_result

     result = await llm_call(
         llm,
@@ -194,4 +207,16 @@ async def content_safety_check_output(

     is_safe, *violated_policies = result

-    return {"allowed": is_safe, "policy_violations": violated_policies}
+    final_result = {"allowed": is_safe, "policy_violations": violated_policies}
+
+    if cache:
+        cache_key = create_normalized_cache_key(check_output_prompt)
+        cache_entry: CacheEntry = {
+            "result": final_result,
+            "llm_stats": extract_llm_stats_for_cache(),
+            "llm_metadata": extract_llm_metadata_for_cache(),
+        }
+        cache.put(cache_key, cache_entry)
+        log.debug(f"Content safety output result cached for model '{model_name}'")
+
+    return final_result

nemoguardrails/library/topic_safety/actions.py

Lines changed: 31 additions & 1 deletion
@@ -21,6 +21,14 @@
 from nemoguardrails.actions.actions import action
 from nemoguardrails.actions.llm.utils import llm_call
 from nemoguardrails.context import llm_call_info_var
+from nemoguardrails.llm.cache import CacheInterface
+from nemoguardrails.llm.cache.utils import (
+    CacheEntry,
+    create_normalized_cache_key,
+    extract_llm_metadata_for_cache,
+    extract_llm_stats_for_cache,
+    get_from_cache_and_restore_stats,
+)
 from nemoguardrails.llm.filters import to_chat_messages
 from nemoguardrails.llm.taskmanager import LLMTaskManager
 from nemoguardrails.logging.explain import LLMCallInfo
@@ -35,6 +43,7 @@ async def topic_safety_check_input(
     model_name: Optional[str] = None,
     context: Optional[dict] = None,
     events: Optional[List[dict]] = None,
+    model_caches: Optional[Dict[str, CacheInterface]] = None,
     **kwargs,
 ) -> dict:
     _MAX_TOKENS = 10
@@ -102,11 +111,32 @@ async def topic_safety_check_input(
     messages.extend(conversation_history)
     messages.append({"type": "user", "content": user_input})

+    cache = model_caches.get(model_name) if model_caches else None
+
+    if cache:
+        cache_key = create_normalized_cache_key(messages)
+        cached_result = get_from_cache_and_restore_stats(cache, cache_key)
+        if cached_result is not None:
+            log.debug(f"Topic safety cache hit for model '{model_name}'")
+            return cached_result
+
     result = await llm_call(llm, messages, stop=stop, llm_params={"temperature": 0.01})

     if result.lower().strip() == "off-topic":
         on_topic = False
     else:
         on_topic = True

-    return {"on_topic": on_topic}
+    final_result = {"on_topic": on_topic}
+
+    if cache:
+        cache_key = create_normalized_cache_key(messages)
+        cache_entry: CacheEntry = {
+            "result": final_result,
+            "llm_stats": extract_llm_stats_for_cache(),
+            "llm_metadata": extract_llm_metadata_for_cache(),
+        }
+        cache.put(cache_key, cache_entry)
+        log.debug(f"Topic safety result cached for model '{model_name}'")
+
+    return final_result

tests/test_content_safety_cache.py

Lines changed: 67 additions & 1 deletion
@@ -18,7 +18,10 @@
 import pytest

 from nemoguardrails.context import llm_call_info_var, llm_stats_var
-from nemoguardrails.library.content_safety.actions import content_safety_check_input
+from nemoguardrails.library.content_safety.actions import (
+    content_safety_check_input,
+    content_safety_check_output,
+)
 from nemoguardrails.llm.cache.lfu import LFUCache
 from nemoguardrails.llm.cache.utils import create_normalized_cache_key
 from nemoguardrails.logging.explain import LLMCallInfo
@@ -95,6 +98,7 @@ async def test_content_safety_cache_retrieves_result_and_restores_stats(
             "prompt_tokens": 80,
             "completion_tokens": 20,
         },
+        "llm_metadata": None,
     }
     cache_key = create_normalized_cache_key("test prompt")
     cache.put(cache_key, cache_entry)
@@ -138,6 +142,7 @@ async def test_content_safety_cache_duration_reflects_cache_read_time(
             "prompt_tokens": 40,
             "completion_tokens": 10,
         },
+        "llm_metadata": None,
     }
     cache_key = create_normalized_cache_key("test prompt")
     cache.put(cache_key, cache_entry)
@@ -191,6 +196,7 @@ async def test_content_safety_cache_handles_missing_stats_gracefully(
     cache_entry = {
         "result": {"allowed": True, "policy_violations": []},
         "llm_stats": None,
+        "llm_metadata": None,
     }
     cache_key = create_normalized_cache_key("test_key")
     cache.put(cache_key, cache_entry)
@@ -213,3 +219,63 @@ async def test_content_safety_cache_handles_missing_stats_gracefully(

     assert result["allowed"] is True
     assert llm_stats.get_stat("total_calls") == 0
+
+
+@pytest.mark.asyncio
+async def test_content_safety_check_output_cache_stores_result(
+    fake_llm_with_stats, mock_task_manager
+):
+    cache = LFUCache(maxsize=10)
+    mock_task_manager.parse_task_output.return_value = [True, "policy2"]
+
+    result = await content_safety_check_output(
+        llms=fake_llm_with_stats,
+        llm_task_manager=mock_task_manager,
+        model_name="test_model",
+        context={"user_message": "test user input", "bot_message": "test bot response"},
+        model_caches={"test_model": cache},
+    )
+
+    assert result["allowed"] is True
+    assert result["policy_violations"] == ["policy2"]
+    assert cache.size() == 1
+
+
+@pytest.mark.asyncio
+async def test_content_safety_check_output_cache_hit(
+    fake_llm_with_stats, mock_task_manager
+):
+    cache = LFUCache(maxsize=10)
+
+    cache_entry = {
+        "result": {"allowed": False, "policy_violations": ["unsafe_output"]},
+        "llm_stats": {
+            "total_tokens": 75,
+            "prompt_tokens": 60,
+            "completion_tokens": 15,
+        },
+        "llm_metadata": None,
+    }
+    cache_key = create_normalized_cache_key("test output prompt")
+    cache.put(cache_key, cache_entry)
+
+    mock_task_manager.render_task_prompt.return_value = "test output prompt"
+
+    llm_stats = LLMStats()
+    llm_stats_var.set(llm_stats)
+
+    result = await content_safety_check_output(
+        llms=fake_llm_with_stats,
+        llm_task_manager=mock_task_manager,
+        model_name="test_model",
+        context={"user_message": "user", "bot_message": "bot"},
+        model_caches={"test_model": cache},
+    )
+
+    assert result["allowed"] is False
+    assert result["policy_violations"] == ["unsafe_output"]
+    assert llm_stats.get_stat("total_calls") == 1
+    assert llm_stats.get_stat("total_tokens") == 75
+
+    llm_call_info = llm_call_info_var.get()
+    assert llm_call_info.from_cache is True

tests/test_topic_safety_cache.py

Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import MagicMock
+
+import pytest
+
+from nemoguardrails.context import llm_call_info_var, llm_stats_var
+from nemoguardrails.library.topic_safety.actions import topic_safety_check_input
+from nemoguardrails.llm.cache.lfu import LFUCache
+from nemoguardrails.llm.cache.utils import create_normalized_cache_key
+from nemoguardrails.logging.explain import LLMCallInfo
+from nemoguardrails.logging.stats import LLMStats
+from tests.utils import FakeLLM
+
+
+@pytest.fixture
+def mock_task_manager():
+    tm = MagicMock()
+    tm.render_task_prompt.return_value = "Is this on topic?"
+    tm.get_stop_tokens.return_value = []
+    tm.get_max_tokens.return_value = 10
+    return tm
+
+
+@pytest.fixture
+def fake_llm_topic():
+    llm = FakeLLM(responses=["on-topic"])
+    return {"test_model": llm}
+
+
+@pytest.mark.asyncio
+async def test_topic_safety_cache_stores_result(fake_llm_topic, mock_task_manager):
+    cache = LFUCache(maxsize=10)
+
+    result = await topic_safety_check_input(
+        llms=fake_llm_topic,
+        llm_task_manager=mock_task_manager,
+        model_name="test_model",
+        context={"user_message": "What is AI?"},
+        events=[],
+        model_caches={"test_model": cache},
+    )
+
+    assert result["on_topic"] is True
+    assert cache.size() == 1
+
+
+@pytest.mark.asyncio
+async def test_topic_safety_cache_hit(fake_llm_topic, mock_task_manager):
+    cache = LFUCache(maxsize=10)
+
+    system_prompt_with_restriction = (
+        "Is this on topic?\n\n"
+        'If any of the above conditions are violated, please respond with "off-topic". '
+        'Otherwise, respond with "on-topic". '
+        'You must respond with "on-topic" or "off-topic".'
+    )
+
+    messages = [
+        {"type": "system", "content": system_prompt_with_restriction},
+        {"type": "user", "content": "What is AI?"},
+    ]
+    cache_entry = {
+        "result": {"on_topic": False},
+        "llm_stats": {
+            "total_tokens": 50,
+            "prompt_tokens": 40,
+            "completion_tokens": 10,
+        },
+        "llm_metadata": None,
+    }
+    cache_key = create_normalized_cache_key(messages)
+    cache.put(cache_key, cache_entry)
+
+    llm_stats = LLMStats()
+    llm_stats_var.set(llm_stats)
+
+    llm_call_info = LLMCallInfo(task="topic_safety_check_input $model=test_model")
+    llm_call_info_var.set(llm_call_info)
+
+    result = await topic_safety_check_input(
+        llms=fake_llm_topic,
+        llm_task_manager=mock_task_manager,
+        model_name="test_model",
+        context={"user_message": "What is AI?"},
+        events=[],
+        model_caches={"test_model": cache},
+    )
+
+    assert result["on_topic"] is False
+    assert llm_stats.get_stat("total_calls") == 1
+    assert llm_stats.get_stat("total_tokens") == 50
+
+    llm_call_info = llm_call_info_var.get()
+    assert llm_call_info.from_cache is True
+
+
+@pytest.mark.asyncio
+async def test_topic_safety_without_cache(fake_llm_topic, mock_task_manager):
+    result = await topic_safety_check_input(
+        llms=fake_llm_topic,
+        llm_task_manager=mock_task_manager,
+        model_name="test_model",
+        context={"user_message": "What is AI?"},
+        events=[],
+    )
+
+    assert result["on_topic"] is True
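One detail worth noting from the cache-hit test above: the topic safety cache key is computed from the full message list, so the rendered system prompt (including the appended on-topic/off-topic instruction) and the conversation history presumably have to match, up to whatever normalization create_normalized_cache_key applies, for a lookup to hit. A small illustration, with example message contents only:

from nemoguardrails.llm.cache.utils import create_normalized_cache_key

# Example messages only; the key covers the entire list, so any change to the
# rendered system prompt or to the conversation history yields a different key.
messages = [
    {"type": "system", "content": "Is this on topic?"},
    {"type": "user", "content": "What is AI?"},
]
key = create_normalized_cache_key(messages)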
