Commit 6b2a527

fix: add simple lock to hf generation to prevent using incorrect weights (#237)
1 parent 2120112 commit 6b2a527

File tree

2 files changed: +345 -44 lines changed

mellea/backends/huggingface.py

Lines changed: 101 additions & 43 deletions
@@ -12,6 +12,7 @@
 import functools
 import inspect
 import json
+import threading
 from collections.abc import Callable, Coroutine
 from copy import deepcopy
 from typing import TYPE_CHECKING, Any, cast
@@ -182,6 +183,9 @@ def __init__(
         self._added_adapters: dict[str, LocalHFAdapter] = {}
         self._loaded_adapters: dict[str, LocalHFAdapter] = {}
 
+        self._generation_lock = threading.Lock()
+        """Used to force generation requests to be non-concurrent. Necessary for preventing issues with adapters."""
+
     async def generate_from_context(
         self,
         action: Component | CBlock,
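
The lock added above exists because the active adapter is global, mutable state on the shared transformers model: two concurrent requests can interleave their `set_adapter` calls, and one of them then generates with the other request's weights. A minimal standalone sketch of the failure mode and the fix, using only the standard library (the `FakeModel` class and function names are illustrative, not part of this commit):

```python
import threading

class FakeModel:
    """Stand-in for a shared HF model whose active adapter is process-wide state."""
    def __init__(self) -> None:
        self.active_adapter: str | None = None

_generation_lock = threading.Lock()

def generate_exclusively(model: FakeModel, adapter: str | None, prompt: str) -> str:
    # Holding one lock across BOTH the adapter switch and the forward pass means no
    # other thread can swap adapters between the two steps.
    with _generation_lock:
        model.active_adapter = adapter
        return f"generated {prompt!r} with adapter={model.active_adapter}"
```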
@@ -245,6 +249,32 @@ async def generate_from_context(
         )
         return mot, ctx.add(action).add(mot)
 
+    def _generate_with_adapter_lock(
+        self, adapter_name: str, generate_func: Callable, *args, **kwargs
+    ):
+        """Helper function for ensuring exclusive generation when adapters are present. Necessary to prevent generating with incorrect weights."""
+        with self._generation_lock:
+            if adapter_name != "":
+                self.load_adapter(adapter_name)
+                self._model.set_adapter(adapter_name)
+            else:
+                try:
+                    # `._model.disable_adapters()` doesn't seem to actually disable them or
+                    # remove them from the model's list of `.active_adapters()`.
+                    self._model.set_adapter([])
+                except ValueError as e:
+                    # If no weights have been loaded, the model will raise a ValueError:
+                    # `ValueError("No adapter loaded. Please load an adapter first.")`
+                    if "No adapter loaded" in str(e):
+                        pass
+                    else:
+                        raise e
+
+            _assert_correct_adapters(adapter_name, self._model)
+            out = generate_func(*args, **kwargs)
+            _assert_correct_adapters(adapter_name, self._model)
+            return out
+
     async def _generate_from_intrinsic(
         self, action: Intrinsic, ctx: Context, *, model_options: dict[str, Any]
     ) -> ModelOutputThunk:
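
A rough sketch of how the new helper is driven from the async generation paths, mirroring the `asyncio.to_thread` calls in the hunks below; it assumes an already-constructed `LocalHFBackend` instance named `backend` (that variable name is illustrative):

```python
import asyncio

async def locked_generate(backend, adapter_name: str, input_ids):
    # asyncio.to_thread runs the synchronous, lock-protected helper off the event loop.
    # Everything after the callable is forwarded to it verbatim as *args/**kwargs.
    return await asyncio.to_thread(
        backend._generate_with_adapter_lock,
        adapter_name,              # "" means: generate with all adapters deactivated
        backend._model.generate,   # callable executed while the lock is held
        input_ids,
        return_dict_in_generate=True,
        output_scores=True,
    )
```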
@@ -328,27 +358,21 @@ async def _generate_from_intrinsic(
         # so we will have to invalidate the cache on our side. This requires
         # us having specific caching for each Component/Message.
 
-        self.load_adapter(adapter.qualified_name)
-
-        # TODO: This modifies the underlying model. We should set a non-exclusive lock here.
-        # It should allow generate requests with the same adapter to proceed. This logic also
-        # needs to be added to the other generate functions.
-        self._model.set_adapter(adapter.qualified_name)
-
         generate_input, other_input = (
             granite_common.util.chat_completion_request_to_transformers_inputs(
                 rewritten, self._tokenizer, self._model
             )
         )
 
-        chat_response: Coroutine[Any, Any, granite_common.ChatCompletionResponse] = (
-            asyncio.to_thread(
-                granite_common.util.generate_with_transformers,
-                self._tokenizer,
-                self._model,
-                generate_input,
-                other_input,
-            )
+        chat_response = asyncio.to_thread(
+            self._generate_with_adapter_lock,
+            adapter.qualified_name,
+            granite_common.util.generate_with_transformers,
+            # Passed as args/kwargs to generate.
+            self._tokenizer,
+            self._model,
+            generate_input,
+            other_input,
         )
 
         output = ModelOutputThunk(None)
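
The TODO removed in this hunk floated a finer-grained design: let requests that need the same adapter run concurrently and only serialize across different adapters. Purely as an illustration of that idea (nothing named here exists in the codebase), such a gate could look like the sketch below; the commit instead opts for the simpler exclusive lock:

```python
import threading
from typing import Any, Callable

class AdapterGate:
    """Illustrative only: admits concurrent generations that share an adapter, and
    makes a request for a different adapter wait until in-flight work drains."""

    def __init__(self) -> None:
        self._cond = threading.Condition()
        self._active: str | None = None
        self._in_flight = 0

    def run(self, adapter: str, generate_func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        with self._cond:
            # Wait until nothing is running, or the running requests use this same adapter.
            while self._in_flight and self._active != adapter:
                self._cond.wait()
            self._active = adapter  # a real implementation would also switch the model's adapter here
            self._in_flight += 1
        try:
            return generate_func(*args, **kwargs)
        finally:
            with self._cond:
                self._in_flight -= 1
                if self._in_flight == 0:
                    self._cond.notify_all()
```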
@@ -501,7 +525,10 @@ async def _generate_from_context_standard(
         generate_options = self._filter_chat_template_only_options(model_options)
 
         chat_response = asyncio.to_thread(
+            self._generate_with_adapter_lock,
+            "",  # Empty for no adapters.
             self._model.generate,  # type: ignore
+            # Passed as args/kwargs to generate.
             input_ids,
             return_dict_in_generate=True,
             output_scores=True,
@@ -675,42 +702,41 @@ async def generate_from_raw(
             self._device
         )
 
-        if format is None:
-            outputs = await asyncio.to_thread(
-                self._model.generate,  # type: ignore
-                input_ids=inputs["input_ids"],
-                attention_mask=inputs["attention_mask"],
-                return_dict_in_generate=True,
-                output_scores=True,
-                **self._make_backend_specific_and_remove(model_opts),
-            )
-        else:
+        format_kwargs = {}
+        if format:
+            # outlines.generate.json always parses the resulting json into a python dict.
+            # We however want to keep it as a json string for later storing it in ModelOutputThunk
             schema: dict[str, Any] = format.model_json_schema()
             schema_json: str = json.dumps(schema)
-            regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema(
+            regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema(  # type: ignore
                 schema_json
             )
 
             from outlines.models.transformers import TransformerTokenizer
-            from outlines.processors import RegexLogitsProcessor
+            from outlines.processors.structured import RegexLogitsProcessor
             from transformers import LogitsProcessorList
 
-            outputs = await asyncio.to_thread(
-                self._model.generate,  # type: ignore
-                input_ids=inputs["input_ids"],
-                attention_mask=inputs["attention_mask"],
-                return_dict_in_generate=True,
-                output_scores=True,
-                logits_processor=LogitsProcessorList(
-                    [
-                        RegexLogitsProcessor(
-                            regex_str, tokenizer=TransformerTokenizer(self._tokenizer)
-                        )
-                    ]
-                ),
-                **self._make_backend_specific_and_remove(model_opts),
+            format_kwargs["logits_processor"] = LogitsProcessorList(
+                [
+                    RegexLogitsProcessor(
+                        regex_str, tokenizer=TransformerTokenizer(self._tokenizer)
+                    )
+                ]
             )
 
+        outputs = await asyncio.to_thread(
+            self._generate_with_adapter_lock,
+            "",  # Empty for no adapter.
+            self._model.generate,  # type: ignore
+            # Passed as args/kwargs to generate.
+            input_ids=inputs["input_ids"],
+            attention_mask=inputs["attention_mask"],
+            return_dict_in_generate=True,
+            output_scores=True,
+            **self._make_backend_specific_and_remove(model_opts),
+            **format_kwargs,
+        )
+
         sequences_to_decode = [
             sequence[inputs["input_ids"][i].size(0) :]  # type: ignore
             for i, sequence in enumerate(outputs.sequences)
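
The refactor in this hunk collapses the two `generate` call sites into one by accumulating optional keyword arguments first and splatting them at the end. A small self-contained sketch of that pattern using pydantic (the `Answer` model and `fake_generate` function are made up for illustration; they are not part of mellea):

```python
import json
from typing import Any
from pydantic import BaseModel

class Answer(BaseModel):
    text: str
    score: float

def fake_generate(prompt: str, **kwargs: Any) -> dict[str, Any]:
    # Stand-in for model.generate; just reports which optional kwargs it received.
    return {"prompt": prompt, "kwargs": sorted(kwargs)}

def generate_maybe_constrained(prompt: str, format: type[BaseModel] | None = None) -> dict[str, Any]:
    format_kwargs: dict[str, Any] = {}
    if format:
        # Derive the JSON schema once; attach constrained-decoding kwargs only when needed.
        format_kwargs["schema_json"] = json.dumps(format.model_json_schema())
    # A single call site now serves both the constrained and the unconstrained path.
    return fake_generate(prompt, output_scores=True, **format_kwargs)

print(generate_maybe_constrained("2 + 2?", Answer))   # includes 'schema_json'
print(generate_maybe_constrained("2 + 2?"))           # plain generation
```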
@@ -864,7 +890,7 @@ def add_adapter(self, adapter: LocalHFAdapter):
         self._added_adapters[adapter.qualified_name] = adapter
 
     def load_adapter(self, adapter_qualified_name: str):
-        """Loads the given adapter for the backend. Must have previously been added."""
+        """Loads the given adapter for the backend. Must have previously been added. Do not call when generation requests are happening."""
         adapter = self._added_adapters.get(adapter_qualified_name, None)
         if adapter is None:
             raise ValueError(
@@ -891,7 +917,7 @@ def load_adapter(self, adapter_qualified_name: str):
         # Loading an adapter activates it. We disable adapters immediately after.
         # Prefer this over `.disable_adapters()`; the disable function doesn't always
         # seem to work.
-        self._model.set_adapter([])
+        self._model.disable_adapters()
         self._loaded_adapters[adapter.qualified_name] = adapter
 
     def unload_adapter(self, adapter_qualified_name: str):
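
For readers less familiar with the adapter surface being driven here: these are the PEFT-integration methods that transformers exposes directly on the model object. A hedged sketch of the lifecycle, with placeholder model and adapter ids:

```python
from transformers import AutoModelForCausalLM

# Placeholder ids; any causal LM with a compatible LoRA adapter behaves the same way.
model = AutoModelForCausalLM.from_pretrained("base-model-id")
model.load_adapter("org/lora-adapter-id", adapter_name="my_adapter")  # loading also activates it

model.disable_adapters()          # what load_adapter() in this backend now calls right after loading
model.set_adapter("my_adapter")   # re-activate by name, as the locked helper does per request
print(model.active_adapters())    # -> ["my_adapter"]
model.set_adapter([])             # the pre-commit way this file deactivated adapters
```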
@@ -917,6 +943,38 @@ def list_adapters(self) -> list[str]:
         return list(self._loaded_adapters.keys())
 
 
+def _assert_correct_adapters(expected_state: str, model: PreTrainedModel):
+    """When generating with a huggingface model, this can be used to ensure the correct adapters are active.
+
+    Args:
+        expected_state: the current state of the lock
+        model: the model underlying the LocalHFBackend; this is the model the adapters are activated on
+    """
+    try:
+        active = model.active_adapters()
+
+        if expected_state == "":
+            assert len(active) == 0, (
+                f'no adapters should be active if expected state is "", got "{active[0]}"'
+            )
+        else:
+            assert len(active) == 1, (
+                f'one adapter should be active if expected state is "{expected_state}"'
+            )
+            assert active[0] == expected_state, (
+                f'the active adapter "{active[0]}" doesn\'t match the expected state: "{expected_state}"'
+            )
+    except ValueError as e:
+        # If no weights have been loaded, the model will raise a ValueError:
+        # `ValueError("No adapter loaded. Please load an adapter first.")`
+        if "No adapter loaded" in str(e):
+            assert expected_state == "", (
+                f'got no adapters loaded but expected state is "{expected_state}"'
+            )
+        else:
+            raise e
+
+
 class HFProcessRewardModel(PRM, abc.ABC):
     """A Process Reward Model that works with a huggingface backend."""
 
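Finally, a hedged sketch of the kind of concurrency exercise the new lock and `_assert_correct_adapters` make checkable; it assumes a constructed `LocalHFBackend` named `backend` with adapters `"adapter_a"` and `"adapter_b"` already added (the names are placeholders, and no such test file is part of this commit):

```python
import threading

def _generate_with_named_adapter(backend, adapter_name: str, prompt_ids):
    # The helper switches adapters, asserts the expected adapter is active both before
    # and after the forward pass, and holds the generation lock throughout.
    return backend._generate_with_adapter_lock(
        adapter_name,
        backend._model.generate,
        prompt_ids,
        max_new_tokens=4,
    )

def race_two_adapters(backend, prompt_ids):
    # Without the lock, one thread's set_adapter could land between the other thread's
    # adapter switch and its generate call, producing output from the wrong weights.
    threads = [
        threading.Thread(target=_generate_with_named_adapter, args=(backend, name, prompt_ids))
        for name in ("adapter_a", "adapter_b")
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
```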