 import functools
 import inspect
 import json
+import threading
 from collections.abc import Callable, Coroutine
 from copy import deepcopy
 from typing import TYPE_CHECKING, Any, cast
@@ -182,6 +183,9 @@ def __init__(
         self._added_adapters: dict[str, LocalHFAdapter] = {}
         self._loaded_adapters: dict[str, LocalHFAdapter] = {}

+        self._generation_lock = threading.Lock()
+        """Used to force generation requests to be non-concurrent; prevents races on the model's active adapter state."""
+
     async def generate_from_context(
         self,
         action: Component | CBlock,
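
A note on the lock choice: `threading.Lock` (rather than `asyncio.Lock`) is the right primitive here because the hunks below dispatch generation to worker threads with `asyncio.to_thread`, so mutual exclusion must hold across threads, not just across event-loop tasks. A minimal runnable sketch of that interaction, with hypothetical names:

    import asyncio
    import threading

    lock = threading.Lock()

    def blocking_generate(tag: str) -> str:
        # Runs in a worker thread; the threading lock serializes the
        # workers even though the awaiting coroutines run concurrently.
        with lock:
            return f"generated:{tag}"

    async def main():
        results = await asyncio.gather(
            asyncio.to_thread(blocking_generate, "a"),
            asyncio.to_thread(blocking_generate, "b"),
        )
        print(results)  # ['generated:a', 'generated:b']

    asyncio.run(main())
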
@@ -245,6 +249,32 @@ async def generate_from_context(
         )
         return mot, ctx.add(action).add(mot)

+    def _generate_with_adapter_lock(
+        self, adapter_name: str, generate_func: Callable, *args, **kwargs
+    ):
+        """Helper for ensuring exclusive generation when adapters are present; prevents generating with incorrect weights."""
+        with self._generation_lock:
+            if adapter_name != "":
+                self.load_adapter(adapter_name)
+                self._model.set_adapter(adapter_name)
+            else:
+                try:
+                    # `._model.disable_adapters()` doesn't seem to actually disable them or
+                    # remove them from the model's list of `.active_adapters()`.
+                    self._model.set_adapter([])
+                except ValueError as e:
+                    # If no adapter weights have been loaded, the model raises:
+                    # `ValueError("No adapter loaded. Please load an adapter first.")`
+                    if "No adapter loaded" in str(e):
+                        pass
+                    else:
+                        raise e
+
+            _assert_correct_adapters(adapter_name, self._model)
+            out = generate_func(*args, **kwargs)
+            _assert_correct_adapters(adapter_name, self._model)
+            return out
+
     async def _generate_from_intrinsic(
         self, action: Intrinsic, ctx: Context, *, model_options: dict[str, Any]
     ) -> ModelOutputThunk:
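
For orientation, a sketch of how the new helper is driven from the async paths in the hunks that follow: `asyncio.to_thread(fn, *args, **kwargs)` forwards everything after `fn`, so the adapter name and the real generate function simply ride in front of generate's own arguments. (`run_locked_generate` is a hypothetical wrapper, not code from this change.)

    import asyncio

    async def run_locked_generate(backend, input_ids):
        outputs = await asyncio.to_thread(
            backend._generate_with_adapter_lock,  # runs in a worker thread
            "",                                   # adapter name; "" means none
            backend._model.generate,              # the wrapped generate function
            input_ids,                            # *args forwarded to generate
            return_dict_in_generate=True,         # **kwargs forwarded to generate
        )
        return outputs
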
@@ -317,27 +347,21 @@ async def _generate_from_intrinsic(
         # so we will have to invalidate the cache on our side. This requires
         # us having specific caching for each Component/Message.

-        self.load_adapter(adapter.qualified_name)
-
-        # TODO: This modifies the underlying model. We should set a non-exclusive lock here.
-        # It should allow generate requests with the same adapter to proceed. This logic also
-        # needs to be added to the other generate functions.
-        self._model.set_adapter(adapter.qualified_name)
-
         generate_input, other_input = (
             granite_common.util.chat_completion_request_to_transformers_inputs(
                 rewritten, self._tokenizer, self._model
             )
         )

-        chat_response: Coroutine[Any, Any, granite_common.ChatCompletionResponse] = (
-            asyncio.to_thread(
-                granite_common.util.generate_with_transformers,
-                self._tokenizer,
-                self._model,
-                generate_input,
-                other_input,
-            )
+        chat_response = asyncio.to_thread(
+            self._generate_with_adapter_lock,
+            adapter.qualified_name,
+            granite_common.util.generate_with_transformers,
+            # Passed as args/kwargs to generate.
+            self._tokenizer,
+            self._model,
+            generate_input,
+            other_input,
         )

         output = ModelOutputThunk(None)
@@ -490,7 +514,10 @@ async def _generate_from_context_standard(
         generate_options = self._filter_chat_template_only_options(model_options)

         chat_response = asyncio.to_thread(
+            self._generate_with_adapter_lock,
+            "",  # Empty for no adapters.
             self._model.generate,  # type: ignore
+            # Passed as args/kwargs to generate.
             input_ids,
             return_dict_in_generate=True,
             output_scores=True,
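
Even this adapter-free path takes the lock. If it bypassed it, a plain request could call `set_adapter([])` between another request's `set_adapter(name)` and its `generate`, pairing that request with the wrong weights. A toy, model-free simulation of the invariant (all names hypothetical):

    import threading

    lock = threading.Lock()
    active: list[str] = []  # stands in for model.active_adapters()

    def locked_generate(adapter_name: str) -> None:
        with lock:
            # The set_adapter(name) / set_adapter([]) step:
            active[:] = [adapter_name] if adapter_name else []
            # generate() would run here; since every request holds the lock,
            # no other request can swap `active` before this one finishes.
            assert active == ([adapter_name] if adapter_name else [])

    threads = [
        threading.Thread(target=locked_generate, args=(name,))
        for name in ("", "adapter_a", "", "adapter_b")
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
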
@@ -664,42 +691,41 @@ async def generate_from_raw(
             self._device
         )

-        if format is None:
-            outputs = await asyncio.to_thread(
-                self._model.generate,  # type: ignore
-                input_ids=inputs["input_ids"],
-                attention_mask=inputs["attention_mask"],
-                return_dict_in_generate=True,
-                output_scores=True,
-                **self._make_backend_specific_and_remove(model_opts),
-            )
-        else:
+        format_kwargs = {}
+        if format:
+            # outlines.generate.json always parses the resulting JSON into a Python dict;
+            # we want to keep it as a JSON string so it can be stored in the ModelOutputThunk later.
             schema: dict[str, Any] = format.model_json_schema()
             schema_json: str = json.dumps(schema)
-            regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema(
+            regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema(  # type: ignore
                 schema_json
             )

             from outlines.models.transformers import TransformerTokenizer
-            from outlines.processors import RegexLogitsProcessor
+            from outlines.processors.structured import RegexLogitsProcessor
             from transformers import LogitsProcessorList

-            outputs = await asyncio.to_thread(
-                self._model.generate,  # type: ignore
-                input_ids=inputs["input_ids"],
-                attention_mask=inputs["attention_mask"],
-                return_dict_in_generate=True,
-                output_scores=True,
-                logits_processor=LogitsProcessorList(
-                    [
-                        RegexLogitsProcessor(
-                            regex_str, tokenizer=TransformerTokenizer(self._tokenizer)
-                        )
-                    ]
-                ),
-                **self._make_backend_specific_and_remove(model_opts),
+            format_kwargs["logits_processor"] = LogitsProcessorList(
+                [
+                    RegexLogitsProcessor(
+                        regex_str, tokenizer=TransformerTokenizer(self._tokenizer)
+                    )
+                ]
             )

+        outputs = await asyncio.to_thread(
+            self._generate_with_adapter_lock,
+            "",  # Empty for no adapter.
+            self._model.generate,  # type: ignore
+            # Passed as args/kwargs to generate.
+            input_ids=inputs["input_ids"],
+            attention_mask=inputs["attention_mask"],
+            return_dict_in_generate=True,
+            output_scores=True,
+            **self._make_backend_specific_and_remove(model_opts),
+            **format_kwargs,
+        )
+
         sequences_to_decode = [
             sequence[inputs["input_ids"][i].size(0) :]  # type: ignore
             for i, sequence in enumerate(outputs.sequences)
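
The refactor above folds the outlines-based constrained decoding into `format_kwargs`. Isolated from the method, the schema-to-regex step looks roughly like this (a sketch assuming `outlines_core` and `pydantic` are installed; `Person` is a made-up schema):

    import json

    from outlines_core.fsm.json_schema import build_regex_from_schema
    from pydantic import BaseModel

    class Person(BaseModel):
        name: str
        age: int

    schema_json = json.dumps(Person.model_json_schema())
    regex_str = build_regex_from_schema(schema_json)
    # regex_str matches exactly the JSON strings valid under the schema;
    # wrapped in a RegexLogitsProcessor (as above), it masks any token
    # that would step outside that language during decoding.
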
@@ -853,7 +879,7 @@ def add_adapter(self, adapter: LocalHFAdapter):
         self._added_adapters[adapter.qualified_name] = adapter

     def load_adapter(self, adapter_qualified_name: str):
-        """Loads the given adapter for the backend. Must have previously been added."""
+        """Loads the given adapter for the backend. Must have previously been added. Do not call while generation requests are in flight."""
         adapter = self._added_adapters.get(adapter_qualified_name, None)
         if adapter is None:
             raise ValueError(
@@ -880,7 +906,5 @@ def load_adapter(self, adapter_qualified_name: str):
         # Loading an adapter activates it. We disable adapters immediately after.
-        # Prefer this over `.disable_adapters()`; the disable function doesn't always
-        # seem to work.
-        self._model.set_adapter([])
+        self._model.disable_adapters()
         self._loaded_adapters[adapter.qualified_name] = adapter

     def unload_adapter(self, adapter_qualified_name: str):
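
For context, the adapter calls in `load_adapter` come from transformers' PEFT integration; the rough lifecycle is sketched below (requires `peft`; the repo ids and adapter name are placeholders):

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("org/base-model")
    model.load_adapter("org/some-lora-adapter", adapter_name="my_adapter")

    model.set_adapter("my_adapter")  # activate for subsequent generate calls
    print(model.active_adapters())   # ["my_adapter"]
    model.disable_adapters()         # deactivate; weights stay resident
    # Note: per the comments in this diff, disable_adapters() has not always
    # cleared active_adapters() in practice, hence the assertions added below.
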
@@ -906,6 +930,38 @@ def list_adapters(self) -> list[str]:
         return list(self._loaded_adapters.keys())


+def _assert_correct_adapters(expected_state: str, model: PreTrainedModel):
+    """When generating with a huggingface model, this can be used to ensure the correct adapters are active.
+
+    Args:
+        expected_state: the qualified name of the adapter expected to be active, or "" if none should be
+        model: the model underlying the LocalHFBackend; this is the model the adapters are activated on
+    """
+    try:
+        active = model.active_adapters()
+
+        if expected_state == "":
+            assert len(active) == 0, (
+                f'no adapters should be active if expected state is "", got "{active[0]}"'
+            )
+        else:
+            assert len(active) == 1, (
+                f'one adapter should be active if expected state is "{expected_state}"'
+            )
+            assert active[0] == expected_state, (
+                f'the active adapter "{active[0]}" doesn\'t match the expected state: "{expected_state}"'
+            )
+    except ValueError as e:
+        # If no adapter weights have been loaded, the model raises:
+        # `ValueError("No adapter loaded. Please load an adapter first.")`
+        if "No adapter loaded" in str(e):
+            assert expected_state == "", (
+                f'got no adapters loaded but expected state is "{expected_state}"'
+            )
+        else:
+            raise e
+
+
 class HFProcessRewardModel(PRM, abc.ABC):
     """A Process Reward Model that works with a huggingface backend."""

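
Taken together, the pieces compose roughly like this (a sketch; `backend`, `adapter`, and `input_ids` are assumed to be a LocalHFBackend, a LocalHFAdapter, and tokenized inputs):

    def generate_with(backend, adapter, input_ids):
        backend.add_adapter(adapter)        # register; loading happens lazily
        return backend._generate_with_adapter_lock(
            adapter.qualified_name,         # loads + activates under the lock
            backend._model.generate,
            input_ids,
        )
    # _assert_correct_adapters runs immediately before and after the wrapped
    # generate call, so a concurrent request cannot swap adapters mid-call.
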