 import functools
 import inspect
 import json
+import threading
 from collections.abc import Callable, Coroutine
 from copy import deepcopy
 from typing import TYPE_CHECKING, Any, cast
@@ -182,6 +183,9 @@ def __init__(
         self._added_adapters: dict[str, LocalHFAdapter] = {}
         self._loaded_adapters: dict[str, LocalHFAdapter] = {}

+        self._generation_lock = threading.Lock()
+        """Serializes generation requests; required so concurrent requests can't generate with the wrong adapter weights."""
+
     async def generate_from_context(
         self,
         action: Component | CBlock,
@@ -245,6 +249,32 @@ async def generate_from_context(
         )
         return mot, ctx.add(action).add(mot)

+    def _generate_with_adapter_lock(
+        self, adapter_name: str, generate_func: Callable, *args, **kwargs
+    ):
+        """Runs `generate_func` under the generation lock, ensuring exclusive access to the model's adapter state. Prevents generating with incorrect weights."""
+        with self._generation_lock:
+            if adapter_name != "":
+                self.load_adapter(adapter_name)
+                self._model.set_adapter(adapter_name)
+            else:
+                try:
+                    # `._model.disable_adapters()` doesn't seem to actually disable them or
+                    # remove them from the model's list of `.active_adapters()`.
+                    self._model.set_adapter([])
+                except ValueError as e:
+                    # If no weights have been loaded, the model raises:
+                    # `ValueError("No adapter loaded. Please load an adapter first.")`
+                    if "No adapter loaded" not in str(e):
+                        raise
+
+            _assert_correct_adapters(adapter_name, self._model)
+            out = generate_func(*args, **kwargs)
+            _assert_correct_adapters(adapter_name, self._model)
+            return out
+
     async def _generate_from_intrinsic(
         self, action: Intrinsic, ctx: Context, *, model_options: dict[str, Any]
     ) -> ModelOutputThunk:
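To make the wiring in the hunks below easier to follow: the lock wrapper is signature-agnostic because everything after `generate_func` is forwarded unchanged, and `asyncio.to_thread` forwards its own trailing arguments the same way. A minimal, self-contained sketch of that pattern (`locked_call` and the use of `sum` are illustrative, not part of this PR):

```python
import asyncio
import threading

lock = threading.Lock()


def locked_call(name: str, func, *args, **kwargs):
    # Hold the lock for the entire call so no other thread can swap
    # shared state (here, adapters) mid-generation.
    with lock:
        print(f"generating with adapter: {name or 'none'}")
        return func(*args, **kwargs)


async def main():
    # asyncio.to_thread(fn, *args) runs fn(*args) in a worker thread;
    # the trailing arguments pass through locked_call to the wrapped func.
    result = await asyncio.to_thread(locked_call, "", sum, [1, 2, 3])
    print(result)  # 6


asyncio.run(main())
```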
@@ -328,27 +358,21 @@ async def _generate_from_intrinsic(
         # so we will have to invalidate the cache on our side. This requires
         # us having specific caching for each Component/Message.

-        self.load_adapter(adapter.qualified_name)
-
-        # TODO: This modifies the underlying model. We should set a non-exclusive lock here.
-        # It should allow generate requests with the same adapter to proceed. This logic also
-        # needs to be added to the other generate functions.
-        self._model.set_adapter(adapter.qualified_name)
-
         generate_input, other_input = (
             granite_common.util.chat_completion_request_to_transformers_inputs(
                 rewritten, self._tokenizer, self._model
             )
         )

-        chat_response: Coroutine[Any, Any, granite_common.ChatCompletionResponse] = (
-            asyncio.to_thread(
-                granite_common.util.generate_with_transformers,
-                self._tokenizer,
-                self._model,
-                generate_input,
-                other_input,
-            )
+        chat_response = asyncio.to_thread(
+            self._generate_with_adapter_lock,
+            adapter.qualified_name,
+            granite_common.util.generate_with_transformers,
+            # Passed as args/kwargs to generate.
+            self._tokenizer,
+            self._model,
+            generate_input,
+            other_input,
         )

         output = ModelOutputThunk(None)
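Note that dropping the old `Coroutine[...]` annotation doesn't change behavior: `asyncio.to_thread(...)` returns a coroutine, and the blocking call doesn't start until that coroutine is awaited or scheduled as a task. A quick illustration (`blocking_work` is illustrative):

```python
import asyncio
import time


def blocking_work() -> str:
    time.sleep(0.1)
    return "done"


async def main():
    coro = asyncio.to_thread(blocking_work)  # nothing is running yet
    # The worker thread only starts once the coroutine is awaited, so the
    # full request can be assembled before generation begins.
    print(await coro)


asyncio.run(main())
```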
@@ -501,7 +525,10 @@ async def _generate_from_context_standard(
         generate_options = self._filter_chat_template_only_options(model_options)

         chat_response = asyncio.to_thread(
+            self._generate_with_adapter_lock,
+            "",  # Empty for no adapters.
             self._model.generate,  # type: ignore
+            # Passed as args/kwargs to generate.
             input_ids,
             return_dict_in_generate=True,
             output_scores=True,
@@ -675,42 +702,41 @@ async def generate_from_raw(
             self._device
         )

-        if format is None:
-            outputs = await asyncio.to_thread(
-                self._model.generate,  # type: ignore
-                input_ids=inputs["input_ids"],
-                attention_mask=inputs["attention_mask"],
-                return_dict_in_generate=True,
-                output_scores=True,
-                **self._make_backend_specific_and_remove(model_opts),
-            )
-        else:
+        format_kwargs = {}
+        if format:
+            # outlines.generate.json always parses the resulting json into a python dict.
+            # However, we want to keep it as a json string so it can later be stored in the ModelOutputThunk.
             schema: dict[str, Any] = format.model_json_schema()
             schema_json: str = json.dumps(schema)
-            regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema(
+            regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema(  # type: ignore
                 schema_json
             )

             from outlines.models.transformers import TransformerTokenizer
-            from outlines.processors import RegexLogitsProcessor
+            from outlines.processors.structured import RegexLogitsProcessor
             from transformers import LogitsProcessorList

-            outputs = await asyncio.to_thread(
-                self._model.generate,  # type: ignore
-                input_ids=inputs["input_ids"],
-                attention_mask=inputs["attention_mask"],
-                return_dict_in_generate=True,
-                output_scores=True,
-                logits_processor=LogitsProcessorList(
-                    [
-                        RegexLogitsProcessor(
-                            regex_str, tokenizer=TransformerTokenizer(self._tokenizer)
-                        )
-                    ]
-                ),
-                **self._make_backend_specific_and_remove(model_opts),
+            format_kwargs["logits_processor"] = LogitsProcessorList(
+                [
+                    RegexLogitsProcessor(
+                        regex_str, tokenizer=TransformerTokenizer(self._tokenizer)
+                    )
+                ]
             )

+        outputs = await asyncio.to_thread(
+            self._generate_with_adapter_lock,
+            "",  # Empty for no adapter.
+            self._model.generate,  # type: ignore
+            # Passed as args/kwargs to generate.
+            input_ids=inputs["input_ids"],
+            attention_mask=inputs["attention_mask"],
+            return_dict_in_generate=True,
+            output_scores=True,
+            **self._make_backend_specific_and_remove(model_opts),
+            **format_kwargs,
+        )
+
         sequences_to_decode = [
             sequence[inputs["input_ids"][i].size(0) :]  # type: ignore
             for i, sequence in enumerate(outputs.sequences)
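For readers unfamiliar with the constrained-decoding path that `format_kwargs` consolidates: the JSON schema is compiled to a regex, and a logits processor masks any token that would violate it during decoding. A standalone sketch using the same `build_regex_from_schema` call the diff relies on (the pydantic `Person` model is illustrative, and exact outlines module paths vary by version):

```python
import json

from outlines_core.fsm.json_schema import build_regex_from_schema
from pydantic import BaseModel


class Person(BaseModel):
    name: str
    age: int


# Compile the schema into a regular expression describing every valid
# JSON serialization of a Person.
schema_json = json.dumps(Person.model_json_schema())
regex_str = build_regex_from_schema(schema_json)

# At generate time this regex drives a RegexLogitsProcessor (wrapped in a
# LogitsProcessorList), exactly as in the diff, forcing the model's output
# to parse as a Person.
print(regex_str[:80], "...")
```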
@@ -864,7 +890,7 @@ def add_adapter(self, adapter: LocalHFAdapter):
         self._added_adapters[adapter.qualified_name] = adapter

     def load_adapter(self, adapter_qualified_name: str):
-        """Loads the given adapter for the backend. Must have previously been added."""
+        """Loads the given adapter for the backend. Must have previously been added. Do not call while generation requests are in flight."""
         adapter = self._added_adapters.get(adapter_qualified_name, None)
         if adapter is None:
             raise ValueError(
@@ -891,7 +917,7 @@ def load_adapter(self, adapter_qualified_name: str):
         # Loading an adapter activates it. We disable adapters immediately after.
-        # Prefer this over `.disable_adapters()`; the disable function doesn't always
-        # seem to work.
-        self._model.set_adapter([])
+        self._model.disable_adapters()
         self._loaded_adapters[adapter.qualified_name] = adapter

     def unload_adapter(self, adapter_qualified_name: str):
@@ -917,6 +943,38 @@ def list_adapters(self) -> list[str]:
         return list(self._loaded_adapters.keys())


+def _assert_correct_adapters(expected_state: str, model: PreTrainedModel):
+    """When generating with a huggingface model, this can be used to ensure the correct adapters are active.
+
+    Args:
+        expected_state: the qualified name of the adapter expected to be active, or "" if none should be
+        model: the model underlying the LocalHFBackend; this is the model the adapters are activated on
+    """
+    try:
+        active = model.active_adapters()
+
+        if expected_state == "":
+            assert len(active) == 0, (
+                f'no adapters should be active if expected state is "", got "{active[0]}"'
+            )
+        else:
+            assert len(active) == 1, (
+                f'one adapter should be active if expected state is "{expected_state}"'
+            )
+            assert active[0] == expected_state, (
+                f'the active adapter "{active[0]}" doesn\'t match the expected state: "{expected_state}"'
+            )
+    except ValueError as e:
+        # If no weights have been loaded, the model raises:
+        # `ValueError("No adapter loaded. Please load an adapter first.")`
+        if "No adapter loaded" in str(e):
+            assert expected_state == "", (
+                f'got no adapters loaded but expected state is "{expected_state}"'
+            )
+        else:
+            raise
+
+
 class HFProcessRewardModel(PRM, abc.ABC):
     """A Process Reward Model that works with a huggingface backend."""

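To illustrate what `_assert_correct_adapters` guards against, here is a hypothetical sketch with a `FakeModel` stand-in (not part of this PR) that mimics the three states the helper distinguishes:

```python
class FakeModel:
    """Hypothetical stand-in for PreTrainedModel; only mimics active_adapters()."""

    def __init__(self, active: list[str] | None):
        self._active = active

    def active_adapters(self) -> list[str]:
        if self._active is None:
            # Mirrors the error raised when no adapter was ever loaded.
            raise ValueError("No adapter loaded. Please load an adapter first.")
        return self._active


_assert_correct_adapters("", FakeModel(None))      # passes: nothing loaded, none expected
_assert_correct_adapters("ra", FakeModel(["ra"]))  # passes: expected adapter is active
# _assert_correct_adapters("", FakeModel(["ra"]))  # would fail: a stray adapter is active
```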