
Commit a50fd16

fix: add simple lock to hf generation to prevent using incorrect weights
1 parent 3087051 commit a50fd16
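
The diff below serializes every HuggingFace generation call behind a single threading.Lock so that one request cannot load or swap an adapter while another request is mid-generate; the blocking, locked call is then dispatched off the event loop with asyncio.to_thread. The following is a minimal standalone sketch of that pattern only (FakeBackend, the prompt strings, and the printed output are illustrative and not part of this commit):

import asyncio
import threading


class FakeBackend:
    """Illustrative stand-in for a backend that shares one underlying model."""

    def __init__(self) -> None:
        self._generation_lock = threading.Lock()
        self._active_adapter = ""  # stand-in for model.set_adapter(...)

    def _generate_with_lock(self, adapter_name: str, prompt: str) -> str:
        # Hold the lock across adapter activation *and* generation so no other
        # thread can change the active adapter between the two steps.
        with self._generation_lock:
            self._active_adapter = adapter_name
            # ... a blocking model.generate(...) call would run here ...
            return f"[{self._active_adapter or 'base'}] {prompt}"

    async def generate(self, adapter_name: str, prompt: str) -> str:
        # Push the blocking, locked call onto a worker thread; the lock makes
        # concurrent requests take turns instead of interleaving.
        return await asyncio.to_thread(self._generate_with_lock, adapter_name, prompt)


async def main() -> None:
    backend = FakeBackend()
    results = await asyncio.gather(
        backend.generate("my_adapter", "hello"),
        backend.generate("", "world"),
    )
    print(results)


asyncio.run(main())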

File tree

2 files changed: +320 -44 lines changed


mellea/backends/huggingface.py

Lines changed: 101 additions & 43 deletions
@@ -12,6 +12,7 @@
 import functools
 import inspect
 import json
+import threading
 from collections.abc import Callable, Coroutine
 from copy import deepcopy
 from typing import TYPE_CHECKING, Any, cast
@@ -182,6 +183,9 @@ def __init__(
         self._added_adapters: dict[str, LocalHFAdapter] = {}
         self._loaded_adapters: dict[str, LocalHFAdapter] = {}

+        self._generation_lock = threading.Lock()
+        """Used to force generation requests to be non-concurrent. Necessary for preventing issues with adapters."""
+
     async def generate_from_context(
         self,
         action: Component | CBlock,
@@ -245,6 +249,32 @@ async def generate_from_context
         )
         return mot, ctx.add(action).add(mot)

+    def _generate_with_adapter_lock(
+        self, adapter_name: str, generate_func: Callable, *args, **kwargs
+    ):
+        """Helper function for ensuring exclusive generation when adapters are present. Necessary to prevent generating with incorrect weights."""
+        with self._generation_lock:
+            if adapter_name != "":
+                self.load_adapter(adapter_name)
+                self._model.set_adapter(adapter_name)
+            else:
+                try:
+                    # `._model.disable_adapters()` doesn't seem to actually disable them or
+                    # remove them from the model's list of `.active_adapters()`.
+                    self._model.set_adapter([])
+                except ValueError as e:
+                    # If no weights have been loaded, the model will raise a ValueError:
+                    # `ValueError("No adapter loaded. Please load an adapter first.")`
+                    if "No adapter loaded" in str(e):
+                        pass
+                    else:
+                        raise e
+
+            _assert_correct_adapters(adapter_name, self._model)
+            out = generate_func(*args, **kwargs)
+            _assert_correct_adapters(adapter_name, self._model)
+            return out
+
     async def _generate_from_intrinsic(
         self, action: Intrinsic, ctx: Context, *, model_options: dict[str, Any]
     ) -> ModelOutputThunk:
@@ -317,27 +347,21 @@ async def _generate_from_intrinsic
         # so we will have to invalidate the cache on our side. This requires
         # us having specific caching for each Component/Message.

-        self.load_adapter(adapter.qualified_name)
-
-        # TODO: This modifies the underlying model. We should set a non-exclusive lock here.
-        # It should allow generate requests with the same adapter to proceed. This logic also
-        # needs to be added to the other generate functions.
-        self._model.set_adapter(adapter.qualified_name)
-
         generate_input, other_input = (
             granite_common.util.chat_completion_request_to_transformers_inputs(
                 rewritten, self._tokenizer, self._model
             )
         )

-        chat_response: Coroutine[Any, Any, granite_common.ChatCompletionResponse] = (
-            asyncio.to_thread(
-                granite_common.util.generate_with_transformers,
-                self._tokenizer,
-                self._model,
-                generate_input,
-                other_input,
-            )
+        chat_response = asyncio.to_thread(
+            self._generate_with_adapter_lock,
+            adapter.qualified_name,
+            granite_common.util.generate_with_transformers,
+            # Passed as args/kwargs to generate.
+            self._tokenizer,
+            self._model,
+            generate_input,
+            other_input,
         )

         output = ModelOutputThunk(None)
@@ -490,7 +514,10 @@ async def _generate_from_context_standard
         generate_options = self._filter_chat_template_only_options(model_options)

         chat_response = asyncio.to_thread(
+            self._generate_with_adapter_lock,
+            "",  # Empty for no adapters.
             self._model.generate,  # type: ignore
+            # Passed as args/kwargs to generate.
             input_ids,
             return_dict_in_generate=True,
             output_scores=True,
@@ -664,42 +691,41 @@ async def generate_from_raw
             self._device
         )

-        if format is None:
-            outputs = await asyncio.to_thread(
-                self._model.generate,  # type: ignore
-                input_ids=inputs["input_ids"],
-                attention_mask=inputs["attention_mask"],
-                return_dict_in_generate=True,
-                output_scores=True,
-                **self._make_backend_specific_and_remove(model_opts),
-            )
-        else:
+        format_kwargs = {}
+        if format:
+            # outlines.generate.json always parses the resulting json into a python dict.
+            # We however want to keep it as a json string for later storing it in ModelOutputThunk
             schema: dict[str, Any] = format.model_json_schema()
             schema_json: str = json.dumps(schema)
-            regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema(
+            regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema(  # type: ignore
                 schema_json
             )

             from outlines.models.transformers import TransformerTokenizer
-            from outlines.processors import RegexLogitsProcessor
+            from outlines.processors.structured import RegexLogitsProcessor
             from transformers import LogitsProcessorList

-            outputs = await asyncio.to_thread(
-                self._model.generate,  # type: ignore
-                input_ids=inputs["input_ids"],
-                attention_mask=inputs["attention_mask"],
-                return_dict_in_generate=True,
-                output_scores=True,
-                logits_processor=LogitsProcessorList(
-                    [
-                        RegexLogitsProcessor(
-                            regex_str, tokenizer=TransformerTokenizer(self._tokenizer)
-                        )
-                    ]
-                ),
-                **self._make_backend_specific_and_remove(model_opts),
+            format_kwargs["logits_processor"] = LogitsProcessorList(
+                [
+                    RegexLogitsProcessor(
+                        regex_str, tokenizer=TransformerTokenizer(self._tokenizer)
+                    )
+                ]
             )

+        outputs = await asyncio.to_thread(
+            self._generate_with_adapter_lock,
+            "",  # Empty for no adapter.
+            self._model.generate,  # type: ignore
+            # Passed as args/kwargs to generate.
+            input_ids=inputs["input_ids"],
+            attention_mask=inputs["attention_mask"],
+            return_dict_in_generate=True,
+            output_scores=True,
+            **self._make_backend_specific_and_remove(model_opts),
+            **format_kwargs,
+        )
+
         sequences_to_decode = [
             sequence[inputs["input_ids"][i].size(0) :]  # type: ignore
             for i, sequence in enumerate(outputs.sequences)
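
The hunk above folds the schema-constrained path into format_kwargs and routes the actual generate call through the adapter lock. Below is a minimal end-to-end sketch of the same constrained-decoding calls (an outlines regex logits processor applied to a transformers model); the model id, prompt, and Person schema are placeholder assumptions, not taken from this commit, and the snippet assumes transformers, outlines, outlines_core, and pydantic are installed:

import json

import outlines_core.fsm.json_schema
from outlines.models.transformers import TransformerTokenizer
from outlines.processors.structured import RegexLogitsProcessor
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList


class Person(BaseModel):
    name: str
    age: int


model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # assumption: any small causal LM works here
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Turn the pydantic JSON schema into a regex, then constrain decoding with a
# logits processor, mirroring the calls used in generate_from_raw above.
schema_json = json.dumps(Person.model_json_schema())
regex_str = outlines_core.fsm.json_schema.build_regex_from_schema(schema_json)
logits_processor = LogitsProcessorList(
    [RegexLogitsProcessor(regex_str, tokenizer=TransformerTokenizer(tokenizer))]
)

inputs = tokenizer("Return one JSON object describing a person.", return_tensors="pt")
outputs = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_new_tokens=64,
    logits_processor=logits_processor,
)
new_tokens = outputs[0][inputs["input_ids"].shape[1] :]
print(tokenizer.decode(new_tokens, skip_special_tokens=True))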
@@ -853,7 +879,7 @@ def add_adapter(self, adapter: LocalHFAdapter):
         self._added_adapters[adapter.qualified_name] = adapter

     def load_adapter(self, adapter_qualified_name: str):
-        """Loads the given adapter for the backend. Must have previously been added."""
+        """Loads the given adapter for the backend. Must have previously been added. Do not call when generation requests are happening."""
         adapter = self._added_adapters.get(adapter_qualified_name, None)
         if adapter is None:
             raise ValueError(
@@ -880,7 +906,7 @@ def load_adapter(self, adapter_qualified_name: str):
         # Loading an adapter activates it. We disable adapters immediately after.
         # Prefer this over `.disable_adapters()`; the disable function doesn't always
         # seem to work.
-        self._model.set_adapter([])
+        self._model.disable_adapters()
         self._loaded_adapters[adapter.qualified_name] = adapter

     def unload_adapter(self, adapter_qualified_name: str):
@@ -906,6 +932,38 @@ def list_adapters(self) -> list[str]:
         return list(self._loaded_adapters.keys())


+def _assert_correct_adapters(expected_state: str, model: PreTrainedModel):
+    """When generating with a huggingface model, this can be used to ensure the correct adapters are active.
+
+    Args:
+        expected_state: the current state of the lock
+        model: the model underlying the LocalHFBackend; this is the model the adapters are activated on
+    """
+    try:
+        active = model.active_adapters()
+
+        if expected_state == "":
+            assert len(active) == 0, (
+                f'no adapters should be active if expected state is "", got "{active[0]}"'
+            )
+        else:
+            assert len(active) == 1, (
+                f'one adapter should be active if expected state is "{expected_state}"'
+            )
+            assert active[0] == expected_state, (
+                f'the active adapter "{active[0]}" doesn\'t match the expected state: "{expected_state}"'
+            )
+    except ValueError as e:
+        # If no weights have been loaded, the model will raise a ValueError:
+        # `ValueError("No adapter loaded. Please load an adapter first.")`
+        if "No adapter loaded" in str(e):
+            assert expected_state == "", (
+                f'got no adapters loaded but expected state is "{expected_state}"'
+            )
+        else:
+            raise e
+
+
 class HFProcessRewardModel(PRM, abc.ABC):
     """A Process Reward Model that works with a huggingface backend."""
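
The new module-level _assert_correct_adapters guard can also be exercised in isolation. A small illustrative check follows, using unittest.mock as a stand-in for a real PEFT-wrapped model (the mock and the "stale_adapter" name are assumptions, not part of this commit):

from unittest.mock import MagicMock

# Fake a model whose active_adapters() reports a stale adapter even though the
# lock state says no adapter should be active; the guard should trip.
fake_model = MagicMock()
fake_model.active_adapters.return_value = ["stale_adapter"]

try:
    _assert_correct_adapters("", fake_model)
except AssertionError as err:
    print(f"caught adapter mismatch: {err}")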
