
Commit f724e19

feat(cache): add caching support for topic safety and content safety output checks (#1457)

* feat(cache): add LLM metadata caching for model and provider information

  Extends the cache system to store and restore LLM metadata (model name and provider name) alongside cache entries. This allows cached results to maintain provenance information about which model and provider generated the original response.

  - Added LLMMetadataDict and LLMCacheData TypedDict definitions for type safety
  - Extended CacheEntry to include an optional llm_metadata field
  - Implemented extract_llm_metadata_for_cache() to capture model and provider info from context
  - Implemented restore_llm_metadata_from_cache() to restore metadata when retrieving cached results
  - Updated get_from_cache_and_restore_stats() to handle metadata extraction and restoration
  - Added comprehensive test coverage for metadata caching functionality

* test(cache): add cache miss test cases for safety checks

1 parent 0eec1f6 commit f724e19
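
The metadata caching described above comes down to storing one small entry per normalized prompt and restoring its stats and provenance on a hit. Below is a minimal, self-contained sketch of that entry shape and the read/write pattern the diffs below add to the safety actions; CacheEntrySketch, SimpleCache, and check_with_cache are illustrative stand-ins, not nemoguardrails APIs (the real helpers are CacheEntry, extract_llm_metadata_for_cache(), and get_from_cache_and_restore_stats()).

# Minimal sketch of the cache read/write pattern this commit adds to the safety
# actions. The entry shape (result / llm_stats / llm_metadata) mirrors the
# CacheEntry described in the commit message; SimpleCache and check_with_cache
# are illustrative stand-ins, not part of nemoguardrails.
from typing import Any, Dict, Optional, TypedDict


class CacheEntrySketch(TypedDict):
    result: Dict[str, Any]
    llm_stats: Optional[Dict[str, int]]
    llm_metadata: Optional[Dict[str, str]]  # model/provider provenance (keys assumed for illustration)


class SimpleCache:
    """Stand-in for the per-model cache (LFUCache in the new tests)."""

    def __init__(self) -> None:
        self._store: Dict[str, CacheEntrySketch] = {}

    def get(self, key: str) -> Optional[CacheEntrySketch]:
        return self._store.get(key)

    def put(self, key: str, entry: CacheEntrySketch) -> None:
        self._store[key] = entry


def check_with_cache(prompt: str, cache: Optional[SimpleCache]) -> Dict[str, Any]:
    # 1. On a hit, return the cached result (the real helper also restores
    #    LLM stats and metadata into context variables).
    if cache:
        cached = cache.get(prompt)
        if cached is not None:
            return cached["result"]

    # 2. On a miss, run the actual check (an LLM call in the real actions).
    final_result = {"allowed": True, "policy_violations": []}

    # 3. Store the result together with token stats and provenance metadata
    #    (the values shown here are placeholders).
    if cache:
        entry: CacheEntrySketch = {
            "result": final_result,
            "llm_stats": {"total_tokens": 0, "prompt_tokens": 0, "completion_tokens": 0},
            "llm_metadata": {"model_name": "example-model", "provider_name": "example"},
        }
        cache.put(prompt, entry)

    return final_result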

File tree

5 files changed: +334 −4 lines

nemoguardrails/library/content_safety/actions.py

Lines changed: 27 additions & 2 deletions

@@ -25,6 +25,7 @@
 from nemoguardrails.llm.cache.utils import (
     CacheEntry,
     create_normalized_cache_key,
+    extract_llm_metadata_for_cache,
     extract_llm_stats_for_cache,
     get_from_cache_and_restore_stats,
 )
@@ -110,6 +111,7 @@ async def content_safety_check_input(
         cache_entry: CacheEntry = {
             "result": final_result,
             "llm_stats": extract_llm_stats_for_cache(),
+            "llm_metadata": extract_llm_metadata_for_cache(),
         }
         cache.put(cache_key, cache_entry)
         log.debug(f"Content safety result cached for model '{model_name}'")
@@ -139,6 +141,7 @@ async def content_safety_check_output(
     llm_task_manager: LLMTaskManager,
     model_name: Optional[str] = None,
     context: Optional[dict] = None,
+    model_caches: Optional[Dict[str, CacheInterface]] = None,
     **kwargs,
 ) -> dict:
     _MAX_TOKENS = 3
@@ -176,12 +179,22 @@ async def content_safety_check_output(
             "bot_response": bot_response,
         },
     )
+
     stop = llm_task_manager.get_stop_tokens(task=task)
     max_tokens = llm_task_manager.get_max_tokens(task=task)
 
+    llm_call_info_var.set(LLMCallInfo(task=task))
+
     max_tokens = max_tokens or _MAX_TOKENS
 
-    llm_call_info_var.set(LLMCallInfo(task=task))
+    cache = model_caches.get(model_name) if model_caches else None
+
+    if cache:
+        cache_key = create_normalized_cache_key(check_output_prompt)
+        cached_result = get_from_cache_and_restore_stats(cache, cache_key)
+        if cached_result is not None:
+            log.debug(f"Content safety output cache hit for model '{model_name}'")
+            return cached_result
 
     result = await llm_call(
         llm,
@@ -194,4 +207,16 @@ async def content_safety_check_output(
 
     is_safe, *violated_policies = result
 
-    return {"allowed": is_safe, "policy_violations": violated_policies}
+    final_result = {"allowed": is_safe, "policy_violations": violated_policies}
+
+    if cache:
+        cache_key = create_normalized_cache_key(check_output_prompt)
+        cache_entry: CacheEntry = {
+            "result": final_result,
+            "llm_stats": extract_llm_stats_for_cache(),
+            "llm_metadata": extract_llm_metadata_for_cache(),
+        }
+        cache.put(cache_key, cache_entry)
+        log.debug(f"Content safety output result cached for model '{model_name}'")
+
+    return final_result

nemoguardrails/library/topic_safety/actions.py

Lines changed: 31 additions & 1 deletion

@@ -21,6 +21,14 @@
 from nemoguardrails.actions.actions import action
 from nemoguardrails.actions.llm.utils import llm_call
 from nemoguardrails.context import llm_call_info_var
+from nemoguardrails.llm.cache import CacheInterface
+from nemoguardrails.llm.cache.utils import (
+    CacheEntry,
+    create_normalized_cache_key,
+    extract_llm_metadata_for_cache,
+    extract_llm_stats_for_cache,
+    get_from_cache_and_restore_stats,
+)
 from nemoguardrails.llm.filters import to_chat_messages
 from nemoguardrails.llm.taskmanager import LLMTaskManager
 from nemoguardrails.logging.explain import LLMCallInfo
@@ -35,6 +43,7 @@ async def topic_safety_check_input(
     model_name: Optional[str] = None,
     context: Optional[dict] = None,
     events: Optional[List[dict]] = None,
+    model_caches: Optional[Dict[str, CacheInterface]] = None,
     **kwargs,
 ) -> dict:
     _MAX_TOKENS = 10
@@ -102,11 +111,32 @@ async def topic_safety_check_input(
     messages.extend(conversation_history)
     messages.append({"type": "user", "content": user_input})
 
+    cache = model_caches.get(model_name) if model_caches else None
+
+    if cache:
+        cache_key = create_normalized_cache_key(messages)
+        cached_result = get_from_cache_and_restore_stats(cache, cache_key)
+        if cached_result is not None:
+            log.debug(f"Topic safety cache hit for model '{model_name}'")
+            return cached_result
+
     result = await llm_call(llm, messages, stop=stop, llm_params={"temperature": 0.01})
 
     if result.lower().strip() == "off-topic":
         on_topic = False
     else:
         on_topic = True
 
-    return {"on_topic": on_topic}
+    final_result = {"on_topic": on_topic}
+
+    if cache:
+        cache_key = create_normalized_cache_key(messages)
+        cache_entry: CacheEntry = {
+            "result": final_result,
+            "llm_stats": extract_llm_stats_for_cache(),
+            "llm_metadata": extract_llm_metadata_for_cache(),
+        }
+        cache.put(cache_key, cache_entry)
+        log.debug(f"Topic safety result cached for model '{model_name}'")
+
+    return final_result
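
Unlike the content safety output check, which keys the cache on the rendered prompt string, the topic safety check keys it on the whole message list via create_normalized_cache_key(messages). The snippet below is only an illustration of producing a stable key from structured input; normalized_key_sketch is a local stand-in, and the real create_normalized_cache_key in nemoguardrails.llm.cache.utils may normalize differently.

# Illustrative stand-in for keying a cache on a message list, as the topic
# safety check does with create_normalized_cache_key(messages). Not the
# library's implementation; it only shows the idea of deriving a stable,
# hashable key from structured input.
import hashlib
import json
from typing import Dict, List


def normalized_key_sketch(messages: List[Dict[str, str]]) -> str:
    # Serialize deterministically (sorted keys, compact separators) so that
    # equivalent message lists map to the same key.
    canonical = json.dumps(messages, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()


messages = [
    {"type": "system", "content": "You are a banking assistant."},
    {"type": "user", "content": "What is my account balance?"},
]
print(normalized_key_sketch(messages))  # stable 64-character hex digest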

nemoguardrails/llm/cache/utils.py

Lines changed: 3 additions & 0 deletions

@@ -179,6 +179,9 @@ def get_from_cache_and_restore_stats(
     if cached_metadata:
         restore_llm_metadata_from_cache(cached_metadata)
 
+    if cached_metadata:
+        restore_llm_metadata_from_cache(cached_metadata)
+
     processing_log = processing_log_var.get()
     if processing_log is not None:
         llm_call_info = llm_call_info_var.get()
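
For context, the helper touched above restores cached LLM stats and model/provider metadata before handing back the cached result. A rough, self-contained sketch of that flow, using plain dictionaries as stand-ins for the real context variables (this is not the library's code):

# Rough sketch of the get-and-restore flow in get_from_cache_and_restore_stats():
# on a hit, cached LLM stats and model/provider metadata are restored before the
# cached result is returned. The dicts below stand in for the real context
# variables (llm_stats_var, llm_call_info_var).
from typing import Any, Dict, Optional

restored_stats: Dict[str, int] = {}     # stand-in for the LLM stats context
restored_metadata: Dict[str, str] = {}  # stand-in for model/provider provenance


def get_and_restore_sketch(
    cache: Dict[str, Dict[str, Any]], key: str
) -> Optional[Dict[str, Any]]:
    entry = cache.get(key)
    if entry is None:
        return None  # cache miss: the caller falls through to the real LLM call

    if entry.get("llm_stats"):
        restored_stats.update(entry["llm_stats"])
    if entry.get("llm_metadata"):
        restored_metadata.update(entry["llm_metadata"])

    return entry["result"]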

tests/test_content_safety_cache.py

Lines changed: 110 additions & 1 deletion

@@ -18,7 +18,10 @@
 import pytest
 
 from nemoguardrails.context import llm_call_info_var, llm_stats_var
-from nemoguardrails.library.content_safety.actions import content_safety_check_input
+from nemoguardrails.library.content_safety.actions import (
+    content_safety_check_input,
+    content_safety_check_output,
+)
 from nemoguardrails.llm.cache.lfu import LFUCache
 from nemoguardrails.llm.cache.utils import create_normalized_cache_key
 from nemoguardrails.logging.explain import LLMCallInfo
@@ -95,6 +98,7 @@ async def test_content_safety_cache_retrieves_result_and_restores_stats(
             "prompt_tokens": 80,
             "completion_tokens": 20,
         },
+        "llm_metadata": None,
     }
     cache_key = create_normalized_cache_key("test prompt")
     cache.put(cache_key, cache_entry)
@@ -138,6 +142,7 @@ async def test_content_safety_cache_duration_reflects_cache_read_time(
             "prompt_tokens": 40,
             "completion_tokens": 10,
         },
+        "llm_metadata": None,
     }
     cache_key = create_normalized_cache_key("test prompt")
     cache.put(cache_key, cache_entry)
@@ -191,6 +196,7 async def test_content_safety_cache_handles_missing_stats_gracefully(
     cache_entry = {
         "result": {"allowed": True, "policy_violations": []},
         "llm_stats": None,
+        "llm_metadata": None,
     }
     cache_key = create_normalized_cache_key("test_key")
     cache.put(cache_key, cache_entry)
@@ -213,3 +219,106 @@ async def test_content_safety_cache_handles_missing_stats_gracefully(
 
     assert result["allowed"] is True
     assert llm_stats.get_stat("total_calls") == 0
+
+
+@pytest.mark.asyncio
+async def test_content_safety_check_output_cache_stores_result(
+    fake_llm_with_stats, mock_task_manager
+):
+    cache = LFUCache(maxsize=10)
+    mock_task_manager.parse_task_output.return_value = [True, "policy2"]
+
+    result = await content_safety_check_output(
+        llms=fake_llm_with_stats,
+        llm_task_manager=mock_task_manager,
+        model_name="test_model",
+        context={"user_message": "test user input", "bot_message": "test bot response"},
+        model_caches={"test_model": cache},
+    )
+
+    assert result["allowed"] is True
+    assert result["policy_violations"] == ["policy2"]
+    assert cache.size() == 1
+
+
+@pytest.mark.asyncio
+async def test_content_safety_check_output_cache_hit(
+    fake_llm_with_stats, mock_task_manager
+):
+    cache = LFUCache(maxsize=10)
+
+    cache_entry = {
+        "result": {"allowed": False, "policy_violations": ["unsafe_output"]},
+        "llm_stats": {
+            "total_tokens": 75,
+            "prompt_tokens": 60,
+            "completion_tokens": 15,
+        },
+        "llm_metadata": None,
+    }
+    cache_key = create_normalized_cache_key("test output prompt")
+    cache.put(cache_key, cache_entry)
+
+    mock_task_manager.render_task_prompt.return_value = "test output prompt"
+
+    llm_stats = LLMStats()
+    llm_stats_var.set(llm_stats)
+
+    result = await content_safety_check_output(
+        llms=fake_llm_with_stats,
+        llm_task_manager=mock_task_manager,
+        model_name="test_model",
+        context={"user_message": "user", "bot_message": "bot"},
+        model_caches={"test_model": cache},
+    )
+
+    assert result["allowed"] is False
+    assert result["policy_violations"] == ["unsafe_output"]
+    assert llm_stats.get_stat("total_calls") == 1
+    assert llm_stats.get_stat("total_tokens") == 75
+
+    llm_call_info = llm_call_info_var.get()
+    assert llm_call_info.from_cache is True
+
+
+@pytest.mark.asyncio
+async def test_content_safety_check_output_cache_miss(
+    fake_llm_with_stats, mock_task_manager
+):
+    cache = LFUCache(maxsize=10)
+
+    cache_entry = {
+        "result": {"allowed": True, "policy_violations": []},
+        "llm_stats": {
+            "total_tokens": 50,
+            "prompt_tokens": 40,
+            "completion_tokens": 10,
+        },
+        "llm_metadata": None,
+    }
+    cache_key = create_normalized_cache_key("different prompt")
+    cache.put(cache_key, cache_entry)
+
+    mock_task_manager.render_task_prompt.return_value = "new output prompt"
+    mock_task_manager.parse_task_output.return_value = [True, "policy2"]
+
+    llm_stats = LLMStats()
+    llm_stats_var.set(llm_stats)
+
+    llm_call_info = LLMCallInfo(task="content_safety_check_output $model=test_model")
+    llm_call_info_var.set(llm_call_info)
+
+    result = await content_safety_check_output(
+        llms=fake_llm_with_stats,
+        llm_task_manager=mock_task_manager,
+        model_name="test_model",
+        context={"user_message": "new user input", "bot_message": "new bot response"},
+        model_caches={"test_model": cache},
+    )
+
+    assert result["allowed"] is True
+    assert result["policy_violations"] == ["policy2"]
+    assert cache.size() == 2
+
+    llm_call_info = llm_call_info_var.get()
+    assert llm_call_info.from_cache is False
