2222import time
2323import warnings
2424from contextlib import asynccontextmanager
25- from typing import Any , List , Optional
25+ from typing import Any , Callable , List , Optional
2626
2727from fastapi import FastAPI , Request
2828from fastapi .middleware .cors import CORSMiddleware
4242logging .basicConfig (level = logging .INFO )
4343log = logging .getLogger (__name__ )
4444
45+
46+ class GuardrailsApp (FastAPI ):
47+ """Custom FastAPI subclass with additional attributes for Guardrails server."""
48+
49+ def __init__ (self , * args , ** kwargs ):
50+ super ().__init__ (* args , ** kwargs )
51+ # Initialize custom attributes
52+ self .default_config_id : Optional [str ] = None
53+ self .rails_config_path : str = ""
54+ self .disable_chat_ui : bool = False
55+ self .auto_reload : bool = False
56+ self .stop_signal : bool = False
57+ self .single_config_mode : bool = False
58+ self .single_config_id : Optional [str ] = None
59+ self .loop : Optional [asyncio .AbstractEventLoop ] = None
60+ self .task : Optional [asyncio .Future ] = None
61+
62+
4563# The list of registered loggers. Can be used to send logs to various
4664# backends and storage engines.
47- registered_loggers = []
65+ registered_loggers : List [ Callable ] = []
4866
4967api_description = """Guardrails Sever API."""
5068
5169# The headers for each request
52- api_request_headers = contextvars .ContextVar ("headers" )
70+ api_request_headers : contextvars . ContextVar = contextvars .ContextVar ("headers" )
5371
5472# The datastore that the Server should use.
5573# This is currently used only for storing threads.
5977
6078
6179@asynccontextmanager
62- async def lifespan (app : FastAPI ):
80+ async def lifespan (app : GuardrailsApp ):
6381 # Startup logic here
6482 """Register any additional challenges, if available at startup."""
6583 challenges_files = os .path .join (app .rails_config_path , "challenges.json" )
@@ -82,8 +100,11 @@ async def lifespan(app: FastAPI):
82100 if os .path .exists (filepath ):
83101 filename = os .path .basename (filepath )
84102 spec = importlib .util .spec_from_file_location (filename , filepath )
85- config_module = importlib .util .module_from_spec (spec )
86- spec .loader .exec_module (config_module )
103+ if spec is not None and spec .loader is not None :
104+ config_module = importlib .util .module_from_spec (spec )
105+ spec .loader .exec_module (config_module )
106+ else :
107+ config_module = None
87108
88109 # If there is an `init` function, we call it with the reference to the app.
89110 if config_module is not None and hasattr (config_module , "init" ):
@@ -110,21 +131,22 @@ async def root_handler():
110131
111132 if app .auto_reload :
112133 app .loop = asyncio .get_running_loop ()
134+ # Store the future directly as task
113135 app .task = app .loop .run_in_executor (None , start_auto_reload_monitoring )
114136
115137 yield
116138
117139 # Shutdown logic here
118140 if app .auto_reload :
119141 app .stop_signal = True
120- if hasattr (app , "task" ):
142+ if hasattr (app , "task" ) and app . task is not None :
121143 app .task .cancel ()
122144 log .info ("Shutting down file observer" )
123145 else :
124146 pass
125147
126148
127- app = FastAPI (
149+ app = GuardrailsApp (
128150 title = "Guardrails Server API" ,
129151 description = api_description ,
130152 version = "0.1.0" ,
@@ -186,7 +208,7 @@ class RequestBody(BaseModel):
186208 max_length = 255 ,
187209 description = "The id of an existing thread to which the messages should be added." ,
188210 )
189- messages : List [dict ] = Field (
211+ messages : Optional [ List [dict ] ] = Field (
190212 default = None , description = "The list of messages in the current conversation."
191213 )
192214 context : Optional [dict ] = Field (
@@ -232,7 +254,7 @@ def ensure_config_ids(cls, v, values):
232254
233255
234256class ResponseBody (BaseModel ):
235- messages : List [dict ] = Field (
257+ messages : Optional [ List [dict ] ] = Field (
236258 default = None , description = "The new messages in the conversation"
237259 )
238260 llm_output : Optional [dict ] = Field (
@@ -282,8 +304,8 @@ async def get_rails_configs():
282304
283305
284306# One instance of LLMRails per config id
285- llm_rails_instances = {}
286- llm_rails_events_history_cache = {}
307+ llm_rails_instances : dict [ str , LLMRails ] = {}
308+ llm_rails_events_history_cache : dict [ str , dict ] = {}
287309
288310
289311def _generate_cache_key (config_ids : List [str ]) -> str :
@@ -310,7 +332,7 @@ def _get_rails(config_ids: List[str]) -> LLMRails:
310332 # get the same thing.
311333 config_ids = ["" ]
312334
313- full_llm_rails_config = None
335+ full_llm_rails_config : Optional [ RailsConfig ] = None
314336
315337 for config_id in config_ids :
316338 base_path = os .path .abspath (app .rails_config_path )
@@ -330,6 +352,9 @@ def _get_rails(config_ids: List[str]) -> LLMRails:
330352 else :
331353 full_llm_rails_config += rails_config
332354
355+ if full_llm_rails_config is None :
356+ raise ValueError ("No valid rails configuration found." )
357+
333358 llm_rails = LLMRails (config = full_llm_rails_config , verbose = True )
334359 llm_rails_instances [configs_cache_key ] = llm_rails
335360
@@ -360,30 +385,33 @@ async def chat_completion(body: RequestBody, request: Request):
360385 # Save the request headers in a context variable.
361386 api_request_headers .set (request .headers )
362387
388+ # Use Request config_ids if set, otherwise use the FastAPI default config.
389+ # If neither is available we can't generate any completions as we have no config_id
363390 config_ids = body .config_ids
364- if not config_ids and app .default_config_id :
365- config_ids = [app .default_config_id ]
366- elif not config_ids and not app .default_config_id :
367- raise GuardrailsConfigurationError (
368- "No 'config_id' provided and no default configuration is set for the server. "
369- "You must set a 'config_id' in your request or set use --default-config-id when . "
370- )
391+ if not config_ids :
392+ if app .default_config_id :
393+ config_ids = [app .default_config_id ]
394+ else :
395+ raise GuardrailsConfigurationError (
396+ "No request config_ids provided and server has no default configuration"
397+ )
398+
371399 try :
372400 llm_rails = _get_rails (config_ids )
373401 except ValueError as ex :
374402 log .exception (ex )
375- return {
376- " messages" : [
403+ return ResponseBody (
404+ messages = [
377405 {
378406 "role" : "assistant" ,
379407 "content" : f"Could not load the { config_ids } guardrails configuration. "
380408 f"An internal error has occurred." ,
381409 }
382410 ]
383- }
411+ )
384412
385413 try :
386- messages = body .messages
414+ messages = body .messages or []
387415 if body .context :
388416 messages .insert (0 , {"role" : "context" , "content" : body .context })
389417
@@ -396,14 +424,14 @@ async def chat_completion(body: RequestBody, request: Request):
396424
397425 # We make sure the `thread_id` meets the minimum complexity requirement.
398426 if len (body .thread_id ) < 16 :
399- return {
400- " messages" : [
427+ return ResponseBody (
428+ messages = [
401429 {
402430 "role" : "assistant" ,
403431 "content" : "The `thread_id` must have a minimum length of 16 characters." ,
404432 }
405433 ]
406- }
434+ )
407435
408436 # Fetch the existing thread messages. For easier management, we prepend
409437 # the string `thread-` to all thread keys.
@@ -440,32 +468,37 @@ async def chat_completion(body: RequestBody, request: Request):
440468 )
441469
442470 if isinstance (res , GenerationResponse ):
443- bot_message = res .response [0 ]
471+ bot_message_content = res .response [0 ]
472+ # Ensure bot_message is always a dict
473+ if isinstance (bot_message_content , str ):
474+ bot_message = {"role" : "assistant" , "content" : bot_message_content }
475+ else :
476+ bot_message = bot_message_content
444477 else :
445478 assert isinstance (res , dict )
446479 bot_message = res
447480
448481 # If we're using threads, we also need to update the data before returning
449482 # the message.
450- if body .thread_id :
483+ if body .thread_id and datastore is not None and datastore_key is not None :
451484 await datastore .set (datastore_key , json .dumps (messages + [bot_message ]))
452485
453- result = { " messages" : [bot_message ]}
486+ result = ResponseBody ( messages = [bot_message ])
454487
455488 # If we have additional GenerationResponse fields, we return as well
456489 if isinstance (res , GenerationResponse ):
457- result [ " llm_output" ] = res .llm_output
458- result [ " output_data" ] = res .output_data
459- result [ " log" ] = res .log
460- result [ " state" ] = res .state
490+ result . llm_output = res .llm_output
491+ result . output_data = res .output_data
492+ result . log = res .log
493+ result . state = res .state
461494
462495 return result
463496
464497 except Exception as ex :
465498 log .exception (ex )
466- return {
467- " messages" : [{"role" : "assistant" , "content" : "Internal server error." }]
468- }
499+ return ResponseBody (
500+ messages = [{"role" : "assistant" , "content" : "Internal server error." }]
501+ )
469502
470503
471504# By default, there are no challenges
@@ -498,7 +531,7 @@ def register_datastore(datastore_instance: DataStore):
498531 datastore = datastore_instance
499532
500533
501- def register_logger (logger : callable ):
534+ def register_logger (logger : Callable ):
502535 """Register an additional logger"""
503536 registered_loggers .append (logger )
504537
@@ -510,8 +543,7 @@ def start_auto_reload_monitoring():
510543 from watchdog .observers import Observer
511544
512545 class Handler (FileSystemEventHandler ):
513- @staticmethod
514- def on_any_event (event ):
546+ def on_any_event (self , event ):
515547 if event .is_directory :
516548 return None
517549
@@ -521,7 +553,8 @@ def on_any_event(event):
521553 )
522554
523555 # Compute the relative path
524- rel_path = os .path .relpath (event .src_path , app .rails_config_path )
556+ src_path_str = str (event .src_path )
557+ rel_path = os .path .relpath (src_path_str , app .rails_config_path )
525558
526559 # The config_id is the first component
527560 parts = rel_path .split (os .path .sep )
@@ -530,7 +563,7 @@ def on_any_event(event):
530563 if (
531564 not parts [- 1 ].startswith ("." )
532565 and ".ipynb_checkpoints" not in parts
533- and os .path .isfile (event . src_path )
566+ and os .path .isfile (src_path_str )
534567 ):
535568 # We just remove the config from the cache so that a new one is used next time
536569 if config_id in llm_rails_instances :
0 commit comments