1212import functools
1313import inspect
1414import json
15+ import threading
1516from collections .abc import Callable , Coroutine
1617from copy import deepcopy
1718from typing import TYPE_CHECKING , Any , cast
@@ -181,6 +182,8 @@ def __init__(
181182 # Adapters can be made known to the backend (added) and loaded.
182183 self ._added_adapters : dict [str , LocalHFAdapter ] = {}
183184 self ._loaded_adapters : dict [str , LocalHFAdapter ] = {}
185+ self ._generate_lock = HFGenerationLock (self )
186+ """Necessary for generation since adapters alter the underlying model. Use '' with regular generation requests."""
184187
185188 async def generate_from_context (
186189 self ,
@@ -317,27 +320,27 @@ async def _generate_from_intrinsic(
317320 # so we will have to invalidate the cache on our side. This requires
318321 # us having specific caching for each Component/Message.
319322
320- self .load_adapter (adapter .qualified_name )
321-
322- # TODO: This modifies the underlying model. We should set a non-exclusive lock here.
323- # It should allow generate requests with the same adapter to proceed. This logic also
324- # needs to be added to the other generate functions.
325- self ._model .set_adapter (adapter .qualified_name )
326-
327323 generate_input , other_input = (
328324 granite_common .util .chat_completion_request_to_transformers_inputs (
329325 rewritten , self ._tokenizer , self ._model
330326 )
331327 )
332328
333- chat_response : Coroutine [Any , Any , granite_common .ChatCompletionResponse ] = (
334- asyncio .to_thread (
335- granite_common .util .generate_with_transformers ,
336- self ._tokenizer ,
337- self ._model ,
338- generate_input ,
339- other_input ,
340- )
329+ def generate_intrinsic_with_lock (
330+ * args , ** kwargs
331+ ) -> granite_common .ChatCompletionResponse :
332+ with self ._generate_lock .get_lock (adapter .qualified_name ):
333+ _assert_correct_adapters (adapter .qualified_name , self ._model )
334+ output = granite_common .util .generate_with_transformers (* args , ** kwargs ) # type: ignore
335+ _assert_correct_adapters (adapter .qualified_name , self ._model )
336+ return output
337+
338+ chat_response = asyncio .to_thread (
339+ generate_intrinsic_with_lock ,
340+ self ._tokenizer ,
341+ self ._model ,
342+ generate_input ,
343+ other_input ,
341344 )
342345
343346 output = ModelOutputThunk (None )
@@ -369,7 +372,6 @@ async def granite_common_processing(
369372 input_ids = generate_input ["input_tokens" ],
370373 )
371374
372- # TODO: Post-processing should release the lock for this generation.
373375 output ._post_process = functools .partial (
374376 self .post_processing ,
375377 conversation = conversation ,
@@ -489,8 +491,15 @@ async def _generate_from_context_standard(
489491 # Filter out chat template-only options before passing to generate()
490492 generate_options = self ._filter_chat_template_only_options (model_options )
491493
494+ def generate_with_lock (* args , ** kwargs ):
495+ with self ._generate_lock .get_lock ("" ):
496+ _assert_correct_adapters ("" , self ._model )
497+ output = self ._model .generate (* args , ** kwargs ) # type: ignore
498+ _assert_correct_adapters ("" , self ._model )
499+ return output
500+
492501 chat_response = asyncio .to_thread (
493- self . _model . generate , # type: ignore
502+ generate_with_lock ,
494503 input_ids ,
495504 return_dict_in_generate = True ,
496505 output_scores = True ,
@@ -664,42 +673,45 @@ async def generate_from_raw(
664673 self ._device
665674 )
666675
667- if format is None :
668- outputs = await asyncio .to_thread (
669- self ._model .generate , # type: ignore
670- input_ids = inputs ["input_ids" ],
671- attention_mask = inputs ["attention_mask" ],
672- return_dict_in_generate = True ,
673- output_scores = True ,
674- ** self ._make_backend_specific_and_remove (model_opts ),
675- )
676- else :
676+ format_kwargs = {}
677+ if format :
678+ # outlines.generate.json always parses the resulting json into a python dict.
679+ # We however want to keep it as a json string for later storing it in ModelOutputThunk
677680 schema : dict [str , Any ] = format .model_json_schema ()
678681 schema_json : str = json .dumps (schema )
679- regex_str : str = outlines_core .fsm .json_schema .build_regex_from_schema (
682+ regex_str : str = outlines_core .fsm .json_schema .build_regex_from_schema ( # type: ignore
680683 schema_json
681684 )
682685
683686 from outlines .models .transformers import TransformerTokenizer
684- from outlines .processors import RegexLogitsProcessor
687+ from outlines .processors . structured import RegexLogitsProcessor
685688 from transformers import LogitsProcessorList
686689
687- outputs = await asyncio .to_thread (
688- self ._model .generate , # type: ignore
689- input_ids = inputs ["input_ids" ],
690- attention_mask = inputs ["attention_mask" ],
691- return_dict_in_generate = True ,
692- output_scores = True ,
693- logits_processor = LogitsProcessorList (
694- [
695- RegexLogitsProcessor (
696- regex_str , tokenizer = TransformerTokenizer (self ._tokenizer )
697- )
698- ]
699- ),
700- ** self ._make_backend_specific_and_remove (model_opts ),
690+ format_kwargs ["logits_processor" ] = LogitsProcessorList (
691+ [
692+ RegexLogitsProcessor (
693+ regex_str , tokenizer = TransformerTokenizer (self ._tokenizer )
694+ )
695+ ]
701696 )
702697
698+ def generate_raw_with_lock (* args , ** kwargs ):
699+ with self ._generate_lock .get_lock ("" ):
700+ _assert_correct_adapters ("" , self ._model )
701+ output = self ._model .generate (* args , ** kwargs ) # type: ignore
702+ _assert_correct_adapters ("" , self ._model )
703+ return output
704+
705+ outputs = await asyncio .to_thread (
706+ generate_raw_with_lock ,
707+ input_ids = inputs ["input_ids" ],
708+ attention_mask = inputs ["attention_mask" ],
709+ return_dict_in_generate = True ,
710+ output_scores = True ,
711+ ** self ._make_backend_specific_and_remove (model_opts ),
712+ ** format_kwargs ,
713+ )
714+
703715 sequences_to_decode = [
704716 sequence [inputs ["input_ids" ][i ].size (0 ) :] # type: ignore
705717 for i , sequence in enumerate (outputs .sequences )
@@ -853,13 +865,16 @@ def add_adapter(self, adapter: LocalHFAdapter):
853865 self ._added_adapters [adapter .qualified_name ] = adapter
854866
855867 def load_adapter (self , adapter_qualified_name : str ):
856- """Loads the given adapter for the backend. Must have previously been added."""
868+ """Loads the given adapter for the backend. Must have previously been added. Do not manually call while generate requests are happening; will be called automatically. """
857869 adapter = self ._added_adapters .get (adapter_qualified_name , None )
858870 if adapter is None :
859871 raise ValueError (
860872 f"could not load adapter { adapter_qualified_name } for backend { self } : adapter was not previously added"
861873 )
862874
875+ if self ._loaded_adapters .get (adapter_qualified_name , None ) is not None :
876+ return # Exit early since it's already loaded.
877+
863878 try :
864879 adapter_kwargs = {}
865880
@@ -880,7 +895,6 @@ def load_adapter(self, adapter_qualified_name: str):
880895 # Loading an adapter activates it. We disable adapters immediately after.
881896 # Prefer this over `.disable_adapters()`; the disable function doesn't always
882897 # seem to work.
883- self ._model .set_adapter ([])
884898 self ._loaded_adapters [adapter .qualified_name ] = adapter
885899
886900 def unload_adapter (self , adapter_qualified_name : str ):
@@ -957,3 +971,202 @@ def stepify(self, content: str, step_separator: str) -> list[str]:
957971 step .strip () for step in content .split (step_separator ) if step .strip != ""
958972 ]
959973 return list_of_steps
974+
975+
# The current implementation of this lock must be defined in this file because the backend requires a reference to
# it and the lock requires an understanding of LocalHFBackends. Because generation also requires loading/activating the
# correct adapter (or no adapter), integrating those load/activate checks into the state change of the lock made
# reasoning and usage much easier. A similar approach could probably be implemented with multiple locks (which would
# keep this one generic) but would potentially require users to do more work. If we need to eventually refactor this
# lock to support other backends, we can do that at that point in time (since these APIs are all internal).
class HFGenerationLock:
    """A lock-like object preventing concurrent generation with different adapters on the same backend.

    Note: Should only be used with `asyncio.to_thread` or `threading.Thread(...).start()`. It can block
    if called multiple times from the same thread.
    """

    def __init__(self, backend: "LocalHFBackend"):
        """Create the lock for a specific backend.

        Notes:
            - Should only be used with `asyncio.to_thread` or `threading.Thread(...).start()`; it can
              block if acquired multiple times from the same thread.
            - This lock prioritizes acquirers whose state equals the current state.
            - Typically easiest to use with `with` syntax: `with lock.get_lock(<state>): ...`

        Args:
            backend: the backend whose adapters this lock manages; adapter management is part of the
                lock's state changes, so the backend must be known at init.
        """
        # Since adapter management is included in this lock, we set the backend at init.
        self.backend = backend

        # The current state of the lock; usually reflects the adapter name. "" means the base model.
        self.current_state: str = ""

        # Counts the number of active lock holders (all sharing `current_state`).
        self.num_active: int = 0

        # Include a timeout to ensure there are no deadlocks caused by infinitely waiting. No deadlocks
        # should occur since events are appended to the list before they attempt to acquire the lock.
        # This means even if they fail to acquire the lock, the release caller will set their event to
        # stop waiting.
        # Control flow scenarios:
        # - Fail to acquire lock -> immediately wait -> release is called elsewhere -> notified and acquire lock
        # - Fail to acquire lock -> release is called elsewhere -> wait but see that it's already good to go -> acquire lock
        self.timeout: float | None = 5

        # Guards concurrent modification of the internal properties above.
        self.lock = threading.Lock()

        # List of waiters; allows notifying a single waiter or all of them.
        self.events: list[threading.Event] = []

    class GenLock:
        """Context-manager wrapper enabling `with` syntax instead of try/finally everywhere."""

        def __init__(self, state: str, lock: "HFGenerationLock") -> None:
            """Bind a state to a parent lock for one locking operation.

            Args:
                state: the state associated with this locking operation
                lock: the parent lock associated with this locking operation
            """
            self.state = state
            self.lock = lock

        def __enter__(self):
            """Acquire the lock with the bound state."""
            self.lock.acquire(self.state)
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            """Release the lock.

            Returning a falsy value lets any in-flight exception propagate naturally;
            explicitly re-raising it here would duplicate the traceback.
            """
            self.lock.release()
            return False

    def get_lock(self, state):
        """Return a context manager for `with` statements.

        Examples:
            >>> # in a LocalHFBackend
            >>> state = adapter.qualified_name  # or "" for base model
            >>> with self._generate_lock.get_lock(state):
            ...     ...  # Generate requests here.
        """
        return self.GenLock(state, self)

    def acquire(self, state: str):
        """Acquire the 'lock'. Only call once per thread.

        Args:
            state: the adapter qualified name, or "" for the base model

        Raises:
            Exception: re-raises any error from loading/activating the requested adapter;
                internal counters are rolled back so other requests can still proceed.
        """
        # Notifier for this acquirer.
        event = threading.Event()
        self.events.append(event)

        while True:
            if self.lock.acquire(False):
                if self.current_state == state or self.num_active == 0:
                    # Breaking from this loop and the operations below must be atomic; include
                    # holding the internal lock as part of the condition.
                    #
                    # When `self.current_state == state`, other generation operations with the
                    # same state are happening; there's no need to block this request.
                    #
                    # When `self.num_active == 0`, no other generation requests are happening;
                    # allow a single waiter to break and set the new state.
                    break
                # Had to acquire the lock to check the variables; release it when the checks fail.
                self.lock.release()

            # Wait until notified of a release. The timeout guards against missed wakeups (see __init__).
            event.wait(self.timeout)

            # Reset this waiter so it patiently waits again if it couldn't break from the loop.
            event.clear()

        self.num_active += 1

        # This waiter will never wait again. Remove its event.
        self.events.remove(event)

        if self.current_state != state:
            assert self.num_active == 1, "only this waiter should be active"

            try:
                if state != "":
                    # Adapter: ensure it's loaded before activating it.
                    self.backend.load_adapter(state)
                    self.backend._model.set_adapter(state)
                else:
                    # Base model. We can't know whether adapters were ever loaded/set previously;
                    # `set_adapter([])` raises if none have been, which is safe to ignore.
                    try:
                        self.backend._model.set_adapter([])
                    except Exception:
                        pass
            except Exception:
                # Roll back: this acquirer never became a holder. Without decrementing here,
                # `num_active` would stay at 1 forever and block every other state's acquirers.
                self.num_active -= 1
                self.lock.release()
                # Wake a waiter so pending requests don't have to wait out their timeout.
                if self.events:
                    self.events[0].set()
                raise

            # Publish the new state before releasing the internal lock.
            self.current_state = state
            self.lock.release()

            # Notify all waiters; some might be using the same model/adapter.
            for waiter in self.events:
                waiter.set()
        else:
            # No state change needed; immediately release the internal lock.
            self.lock.release()

    def release(self):
        """Release a single hold on the lock. Call exactly once per successful `acquire()`."""
        # Grab the internal lock to prevent concurrent modifications.
        with self.lock:
            self.num_active -= 1

            assert self.num_active > -1, f"release on {self} called too many times"

            # If there are no active holds on this lock, notify a single waiter if one exists.
            # This also likely means no waiters with states equal to current_state exist,
            # and a new current_state will be set.
            if self.num_active == 0 and self.events:
                self.events[0].set()

    def __str__(self) -> str:
        """Stringify the HFGenerationLock as '<state>: <active count>'."""
        return f"{self.current_state}: {self.num_active}"
1152+
1153+
1154+ def _assert_correct_adapters (expected_state : str , model : PreTrainedModel ):
1155+ """When generating with a huggingface model and a hf generation lock, this can be used to ensure the correct adapters are active.
1156+
1157+ Args:
1158+ expected_state: the current state of the lock
1159+ model: the model underlying the LocalHFBackend; this is the model the adapters are activated on
1160+ """
1161+ try :
1162+ active = model .active_adapters ()
1163+
1164+ if expected_state == "" :
1165+ assert len (active ) == 0 , f"no adapters should be active if expected state is \" \" , got \" { active [0 ]} \" "
1166+ else :
1167+ assert len (active ) == 1 , f"one adapater should be active if expected state is \" { expected_state } \" "
1168+ assert active [0 ] == expected_state , f"the active adapter \" { active [0 ]} \" doesn't match the expected state: \" { expected_state } \" "
1169+ except ValueError :
1170+ # If no weights have been loaded, the model will raise a ValueError:
1171+ # `ValueError("No adapter loaded. Please load an adapter first.")`
1172+ assert expected_state == "" , "expected state must be \" \" if no adapters have been loaded"
0 commit comments