3030
3131import logging
3232import os
33- from typing import Optional
33+ from time import time
34+ from typing import Dict , Optional
3435
3536from nemoguardrails .actions import action
37+ from nemoguardrails .context import llm_call_info_var
3638from nemoguardrails .library .jailbreak_detection .request import (
3739 jailbreak_detection_heuristics_request ,
3840 jailbreak_detection_model_request ,
3941 jailbreak_nim_request ,
4042)
43+ from nemoguardrails .llm .cache import CacheInterface
44+ from nemoguardrails .llm .cache .utils import (
45+ CacheEntry ,
46+ create_normalized_cache_key ,
47+ get_from_cache_and_restore_stats ,
48+ )
4149from nemoguardrails .llm .taskmanager import LLMTaskManager
50+ from nemoguardrails .logging .explain import LLMCallInfo
51+ from nemoguardrails .logging .processing_log import processing_log_var
4252
4353log = logging .getLogger (__name__ )
4454
@@ -89,6 +99,7 @@ async def jailbreak_detection_heuristics(
8999async def jailbreak_detection_model (
90100 llm_task_manager : LLMTaskManager ,
91101 context : Optional [dict ] = None ,
102+ model_caches : Optional [Dict [str , CacheInterface ]] = None ,
92103) -> bool :
93104 """Uses a trained classifier to determine if a user input is a jailbreak attempt"""
94105 prompt : str = ""
@@ -102,6 +113,30 @@ async def jailbreak_detection_model(
102113 if context is not None :
103114 prompt = context .get ("user_message" , "" )
104115
116+ # we do this as a hack to treat this action as an LLM call for tracing
117+ llm_call_info_var .set (LLMCallInfo (task = "jailbreak_detection_model" ))
118+
119+ cache = model_caches .get ("jailbreak_detection" ) if model_caches else None
120+
121+ if cache :
122+ cache_key = create_normalized_cache_key (prompt )
123+ cache_read_start = time ()
124+ cached_result = get_from_cache_and_restore_stats (cache , cache_key )
125+ if cached_result is not None :
126+ cache_read_duration = time () - cache_read_start
127+ llm_call_info = llm_call_info_var .get ()
128+ if llm_call_info :
129+ llm_call_info .from_cache = True
130+ llm_call_info .duration = cache_read_duration
131+ llm_call_info .started_at = time () - cache_read_duration
132+ llm_call_info .finished_at = time ()
133+
134+ log .debug ("Jailbreak detection cache hit" )
135+ return cached_result ["jailbreak" ]
136+
137+ jailbreak_result = None
138+ api_start_time = time ()
139+
105140 if not jailbreak_api_url and not nim_base_url :
106141 from nemoguardrails .library .jailbreak_detection .model_based .checks import (
107142 check_jailbreak ,
@@ -114,32 +149,64 @@ async def jailbreak_detection_model(
114149 try :
115150 jailbreak = check_jailbreak (prompt = prompt )
116151 log .info (f"Local model jailbreak detection result: { jailbreak } " )
117- return jailbreak ["jailbreak" ]
152+ jailbreak_result = jailbreak ["jailbreak" ]
118153 except RuntimeError as e :
119154 log .error (f"Jailbreak detection model not available: { e } " )
120- return False
155+ jailbreak_result = False
121156 except ImportError as e :
122157 log .error (
123158 f"Failed to import required dependencies for local model. Install scikit-learn and torch, or use NIM-based approach" ,
124159 exc_info = e ,
125160 )
126- return False
127-
128- if nim_base_url :
129- jailbreak = await jailbreak_nim_request (
130- prompt = prompt ,
131- nim_url = nim_base_url ,
132- nim_auth_token = nim_auth_token ,
133- nim_classification_path = nim_classification_path ,
134- )
135- elif jailbreak_api_url :
136- jailbreak = await jailbreak_detection_model_request (
137- prompt = prompt , api_url = jailbreak_api_url
138- )
139-
140- if jailbreak is None :
141- log .warning ("Jailbreak endpoint not set up properly." )
142- # If no result, assume not a jailbreak
143- return False
161+ jailbreak_result = False
144162 else :
145- return jailbreak
163+ if nim_base_url :
164+ jailbreak = await jailbreak_nim_request (
165+ prompt = prompt ,
166+ nim_url = nim_base_url ,
167+ nim_auth_token = nim_auth_token ,
168+ nim_classification_path = nim_classification_path ,
169+ )
170+ elif jailbreak_api_url :
171+ jailbreak = await jailbreak_detection_model_request (
172+ prompt = prompt , api_url = jailbreak_api_url
173+ )
174+
175+ if jailbreak is None :
176+ log .warning ("Jailbreak endpoint not set up properly." )
177+ jailbreak_result = False
178+ else :
179+ jailbreak_result = jailbreak
180+
181+ api_duration = time () - api_start_time
182+
183+ llm_call_info = llm_call_info_var .get ()
184+ if llm_call_info :
185+ llm_call_info .from_cache = False
186+ llm_call_info .duration = api_duration
187+ llm_call_info .started_at = api_start_time
188+ llm_call_info .finished_at = time ()
189+
190+ processing_log = processing_log_var .get ()
191+ if processing_log is not None :
192+ processing_log .append (
193+ {
194+ "type" : "llm_call_info" ,
195+ "timestamp" : time (),
196+ "data" : llm_call_info ,
197+ }
198+ )
199+
200+ if cache :
201+ from nemoguardrails .llm .cache .utils import extract_llm_metadata_for_cache
202+
203+ cache_key = create_normalized_cache_key (prompt )
204+ cache_entry : CacheEntry = {
205+ "result" : {"jailbreak" : jailbreak_result },
206+ "llm_stats" : None ,
207+ "llm_metadata" : extract_llm_metadata_for_cache (),
208+ }
209+ cache .put (cache_key , cache_entry )
210+ log .debug ("Jailbreak detection result cached" )
211+
212+ return jailbreak_result
0 commit comments