
Commit eb408e4

feat: add lock to hf backend to prevent concurrent generation with conflicting weights

1 parent c67b90c

File tree

3 files changed: +551 -46 lines changed


mellea/backends/huggingface.py

Lines changed: 258 additions & 45 deletions
@@ -12,6 +12,7 @@
 import functools
 import inspect
 import json
+import threading
 from collections.abc import Callable, Coroutine
 from copy import deepcopy
 from typing import TYPE_CHECKING, Any, cast
@@ -181,6 +182,8 @@ def __init__(
         # Adapters can be made known to the backend (added) and loaded.
         self._added_adapters: dict[str, LocalHFAdapter] = {}
         self._loaded_adapters: dict[str, LocalHFAdapter] = {}
+        self._generate_lock = HFGenerationLock(self)
+        """Necessary for generation since adapters alter the underlying model. Use '' with regular generation requests."""

     async def generate_from_context(
         self,
@@ -317,27 +320,27 @@ async def _generate_from_intrinsic(
         # so we will have to invalidate the cache on our side. This requires
         # us having specific caching for each Component/Message.

-        self.load_adapter(adapter.qualified_name)
-
-        # TODO: This modifies the underlying model. We should set a non-exclusive lock here.
-        # It should allow generate requests with the same adapter to proceed. This logic also
-        # needs to be added to the other generate functions.
-        self._model.set_adapter(adapter.qualified_name)
-
         generate_input, other_input = (
             granite_common.util.chat_completion_request_to_transformers_inputs(
                 rewritten, self._tokenizer, self._model
             )
         )

-        chat_response: Coroutine[Any, Any, granite_common.ChatCompletionResponse] = (
-            asyncio.to_thread(
-                granite_common.util.generate_with_transformers,
-                self._tokenizer,
-                self._model,
-                generate_input,
-                other_input,
-            )
+        def generate_intrinsic_with_lock(
+            *args, **kwargs
+        ) -> granite_common.ChatCompletionResponse:
+            with self._generate_lock.get_lock(adapter.qualified_name):
+                _assert_correct_adapters(adapter.qualified_name, self._model)
+                output = granite_common.util.generate_with_transformers(*args, **kwargs)  # type: ignore
+                _assert_correct_adapters(adapter.qualified_name, self._model)
+                return output
+
+        chat_response = asyncio.to_thread(
+            generate_intrinsic_with_lock,
+            self._tokenizer,
+            self._model,
+            generate_input,
+            other_input,
         )

         output = ModelOutputThunk(None)
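
The same wrap-and-dispatch pattern recurs in `_generate_from_context_standard` and `generate_from_raw` below: the lock is acquired and released inside the worker thread created by `asyncio.to_thread`, never across an `await`, so the adapter swap and the blocking generate call stay on one thread. A minimal sketch of the pattern in isolation; `lock`, `state`, and `blocking_generate` are hypothetical stand-ins, not names from this commit:

    import asyncio

    async def generate_with_state(lock, state, blocking_generate, *args, **kwargs):
        def locked_call():
            # Blocks until `state` matches the lock's current state or nothing
            # else is generating; holds the lock for the whole blocking call.
            with lock.get_lock(state):
                return blocking_generate(*args, **kwargs)

        # Only the worker thread waits; the event loop stays free.
        return await asyncio.to_thread(locked_call)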
@@ -369,7 +372,6 @@ async def granite_common_processing(
             input_ids=generate_input["input_tokens"],
         )

-        # TODO: Post-processing should release the lock for this generation.
         output._post_process = functools.partial(
             self.post_processing,
             conversation=conversation,
@@ -489,8 +491,15 @@ async def _generate_from_context_standard(
         # Filter out chat template-only options before passing to generate()
         generate_options = self._filter_chat_template_only_options(model_options)

+        def generate_with_lock(*args, **kwargs):
+            with self._generate_lock.get_lock(""):
+                _assert_correct_adapters("", self._model)
+                output = self._model.generate(*args, **kwargs)  # type: ignore
+                _assert_correct_adapters("", self._model)
+                return output
+
         chat_response = asyncio.to_thread(
-            self._model.generate,  # type: ignore
+            generate_with_lock,
             input_ids,
             return_dict_in_generate=True,
             output_scores=True,
@@ -664,42 +673,45 @@ async def generate_from_raw(
             self._device
         )

-        if format is None:
-            outputs = await asyncio.to_thread(
-                self._model.generate,  # type: ignore
-                input_ids=inputs["input_ids"],
-                attention_mask=inputs["attention_mask"],
-                return_dict_in_generate=True,
-                output_scores=True,
-                **self._make_backend_specific_and_remove(model_opts),
-            )
-        else:
+        format_kwargs = {}
+        if format:
+            # outlines.generate.json always parses the resulting json into a python dict.
+            # We however want to keep it as a json string for later storing it in ModelOutputThunk
             schema: dict[str, Any] = format.model_json_schema()
             schema_json: str = json.dumps(schema)
-            regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema(
+            regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema(  # type: ignore
                 schema_json
            )

             from outlines.models.transformers import TransformerTokenizer
-            from outlines.processors import RegexLogitsProcessor
+            from outlines.processors.structured import RegexLogitsProcessor
             from transformers import LogitsProcessorList

-            outputs = await asyncio.to_thread(
-                self._model.generate,  # type: ignore
-                input_ids=inputs["input_ids"],
-                attention_mask=inputs["attention_mask"],
-                return_dict_in_generate=True,
-                output_scores=True,
-                logits_processor=LogitsProcessorList(
-                    [
-                        RegexLogitsProcessor(
-                            regex_str, tokenizer=TransformerTokenizer(self._tokenizer)
-                        )
-                    ]
-                ),
-                **self._make_backend_specific_and_remove(model_opts),
+            format_kwargs["logits_processor"] = LogitsProcessorList(
+                [
+                    RegexLogitsProcessor(
+                        regex_str, tokenizer=TransformerTokenizer(self._tokenizer)
+                    )
+                ]
             )

+        def generate_raw_with_lock(*args, **kwargs):
+            with self._generate_lock.get_lock(""):
+                _assert_correct_adapters("", self._model)
+                output = self._model.generate(*args, **kwargs)  # type: ignore
+                _assert_correct_adapters("", self._model)
+                return output
+
+        outputs = await asyncio.to_thread(
+            generate_raw_with_lock,
+            input_ids=inputs["input_ids"],
+            attention_mask=inputs["attention_mask"],
+            return_dict_in_generate=True,
+            output_scores=True,
+            **self._make_backend_specific_and_remove(model_opts),
+            **format_kwargs,
+        )
+
         sequences_to_decode = [
             sequence[inputs["input_ids"][i].size(0) :]  # type: ignore
             for i, sequence in enumerate(outputs.sequences)
@@ -853,13 +865,16 @@ def add_adapter(self, adapter: LocalHFAdapter):
         self._added_adapters[adapter.qualified_name] = adapter

     def load_adapter(self, adapter_qualified_name: str):
-        """Loads the given adapter for the backend. Must have previously been added."""
+        """Loads the given adapter for the backend. Must have previously been added. Do not manually call while generate requests are happening; will be called automatically."""
         adapter = self._added_adapters.get(adapter_qualified_name, None)
         if adapter is None:
             raise ValueError(
                 f"could not load adapter {adapter_qualified_name} for backend {self}: adapter was not previously added"
             )

+        if self._loaded_adapters.get(adapter_qualified_name, None) is not None:
+            return  # Exit early since it's already loaded.
+
         try:
             adapter_kwargs = {}
@@ -880,7 +895,6 @@ def load_adapter(self, adapter_qualified_name: str):
         # Loading an adapter activates it. We disable adapters immediately after.
         # Prefer this over `.disable_adapters()`; the disable function doesn't always
         # seem to work.
-        self._model.set_adapter([])
         self._loaded_adapters[adapter.qualified_name] = adapter

     def unload_adapter(self, adapter_qualified_name: str):
@@ -957,3 +971,202 @@ def stepify(self, content: str, step_separator: str) -> list[str]:
             step.strip() for step in content.split(step_separator) if step.strip != ""
         ]
         return list_of_steps
+
+
+# The current implementation of this lock must be defined in this file because the backend requires a reference to
+# it and the lock requires an understanding of LocalHFBackends. Because generation also requires loading/activating the
+# correct adapter (or no adapter), integrating those load/activate checks into the state changes of the lock made reasoning
+# and usage much easier. A similar approach could probably be implemented with multiple locks (which would keep this one
+# generic) but would potentially require users to do more work. If we eventually need to refactor this lock to support
+# other backends, we can do that at that point in time (since these APIs are all internal).
+class HFGenerationLock:
+    """A lock-like object. Used to prevent concurrent generation from different adapters on the same backend.
+
+    Note: Should only be used with `asyncio.to_thread` or `threading.Thread(...).start()`. It can block if called multiple
+    times from the same thread.
+    """
+
+    def __init__(self, backend: LocalHFBackend):
+        """A lock-like object. Used to prevent concurrent generation from different adapters on the same backend.
+
+        Notes:
+        - Should only be used with `asyncio.to_thread` or `threading.Thread(...).start()`. It can block if called multiple times from the same thread.
+        - This lock prioritizes acquirers whose state equals the current state.
+        - Typically easiest to use with `with` syntax: `with lock.get_lock(<state>): ...`
+        """
+        self.backend = backend
+        """since adapter management is included in this lock, we set the backend at init"""
+
+        self.current_state: str = ""
+        """the current state of the lock; usually reflects the model/adapter name; the empty string is the base model"""
+
+        self.num_active: int = 0
+        """counts the number of active lock holders"""
+
+        # Include a timeout to ensure there are no deadlocks caused by waiting indefinitely. No deadlocks should
+        # occur since events are appended to the list before they attempt to acquire the lock. This means that even if
+        # they fail to acquire the lock, the release caller will set their event to stop waiting.
+        # Control flow scenarios:
+        # - Fail to acquire lock -> immediately wait -> release is called elsewhere -> notified and acquire lock
+        # - Fail to acquire lock -> release is called elsewhere -> wait but see that it's already good to go -> acquire lock
+        self.timeout: float | None = 5
+        """timeout in seconds to wait before trying to acquire the lock again"""
+
+        self.lock = threading.Lock()
+        """a lock to prevent concurrent modification of internal properties"""
+
+        self.events: list[threading.Event] = []
+        """a list of waiters; allows notifying single or multiple waiters"""
+
+    class GenLock:
+        """Necessary for `with` syntax. Enables not using try-finally syntax everywhere."""
+
+        def __init__(self, state: str, lock: HFGenerationLock) -> None:
+            """Necessary for `with` syntax. Enables not using try-finally syntax everywhere.
+
+            Args:
+                state: the state associated with this locking operation
+                lock: the parent lock associated with this locking operation
+            """
+            self.state = state
+            self.lock = lock
+
+        def __enter__(self):
+            """Acquire the lock with a given state."""
+            self.lock.acquire(self.state)
+            return self
+
+        def __exit__(self, exc_type, exc_val, exc_tb):
+            """Release the lock."""
+            self.lock.release()
+
+            # Re-raise the exception if needed.
+            if exc_val is not None:
+                raise exc_val
+
+    def get_lock(self, state):
+        """Used for with statements.
+
+        Examples:
+            >>> # in a LocalHFBackend
+            >>> state = adapter.qualified_name  # or "" for base model
+            >>> with self._generate_lock.get_lock(state):
+            ...     ...  # Generate requests here.
+        """
+        return self.GenLock(state, self)
+
+    def acquire(self, state: str):
+        """Acquire the 'lock'. Only call once per thread.
+
+        Args:
+            state: the adapter qualified name, or "" for the base model
+        """
+        # Notifier for this acquirer.
+        event = threading.Event()
+        self.events.append(event)
+
+        while True:
+            if self.lock.acquire(False):
+                if self.current_state == state or self.num_active == 0:
+                    # Breaking from this loop and the operations below must be atomic; include acquiring the lock
+                    # as part of the condition.
+
+                    # When `self.current_state == state`, other generation operations with the
+                    # current state are happening. There's no need to block this request.
+
+                    # When `self.num_active == 0`, there are no other generation requests happening. Allow
+                    # a single waiter to break and set the new state.
+                    break
+                else:
+                    # We have to acquire the lock to check the variables, but we immediately release it if the comparisons fail.
+                    self.lock.release()
+
+            # Wait until notified of a release. Add a timeout just in case (see notes in `__init__`).
+            event.wait(self.timeout)
+
+            # Reset this waiter so that it patiently waits to be notified again if unable to break from the loop.
+            event.clear()
+
+        self.num_active += 1
+
+        # This waiter will never wait again. Remove its event.
+        self.events.remove(event)
+
+        if self.current_state != state:
+            assert self.num_active == 1, "only this waiter should be active"
+
+            # When swapping states, we need to make sure the correct adapter is set.
+            if state != "":
+                # Adapter.
+                try:
+                    # Ensure the adapter is loaded before setting it.
+                    self.backend.load_adapter(state)
+                    self.backend._model.set_adapter(state)
+                except Exception as e:
+                    # If something goes wrong, the internal state hasn't changed.
+                    # We also have to release the internal lock so that future requests can go through.
+                    self.lock.release()
+                    raise e
+            else:
+                # Base model.
+                try:
+                    # We can't know if adapters have been loaded / set previously.
+                    # This call will throw an exception if none have been.
+                    self.backend._model.set_adapter([])
+                except Exception:
+                    pass
+
+            # Wait to release the lock until current_state is set to the new state value.
+            self.current_state = state
+            self.lock.release()
+
+            # Notify all events. Some might be using the same model/adapter.
+            for event in self.events:
+                event.set()
+        else:
+            # Otherwise, immediately release the lock since we don't need to change current_state.
+            self.lock.release()
+
+    def release(self):
+        """Release a single hold on the lock. Should only be called once per `acquire()` call."""
+        # Grab the internal lock to prevent concurrent modifications.
+        with self.lock:
+            self.num_active -= 1
+
+            assert self.num_active > -1, f"release on {self} called too many times"
+
+            # Create a local var to track the number active. This lets us release the lock
+            # before notifying the single waiter.
+            snapshot_num_active = self.num_active
+
+        # If there are no active holds on this lock, notify a single waiter if one exists.
+        # This also likely means that no waiters with states equal to the current_state exist,
+        # and a new current_state will be set.
+        if snapshot_num_active == 0:
+            if len(self.events) > 0:
+                self.events[0].set()
+
+    def __str__(self) -> str:
+        """Stringify the HFGenerationLock."""
+        return f"{self.current_state}: {self.num_active}"
+
+
+def _assert_correct_adapters(expected_state: str, model: PreTrainedModel):
+    """When generating with a huggingface model and an hf generation lock, this can be used to ensure the correct adapters are active.
+
+    Args:
+        expected_state: the current state of the lock
+        model: the model underlying the LocalHFBackend; this is the model the adapters are activated on
+    """
+    try:
+        active = model.active_adapters()
+
+        if expected_state == "":
+            assert len(active) == 0, f"no adapters should be active if expected state is \"\", got \"{active[0]}\""
+        else:
+            assert len(active) == 1, f"one adapter should be active if expected state is \"{expected_state}\""
+            assert active[0] == expected_state, f"the active adapter \"{active[0]}\" doesn't match the expected state: \"{expected_state}\""
+    except ValueError:
+        # If no weights have been loaded, the model will raise a ValueError:
+        # `ValueError("No adapter loaded. Please load an adapter first.")`
+        assert expected_state == "", "expected state must be \"\" if no adapters have been loaded"
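
Taken together: any number of holders sharing one state may generate concurrently, while a holder with a different state waits for `num_active` to drain to zero, swaps the adapter, and becomes the new `current_state`. A rough usage sketch, assuming a constructed `backend`, a hypothetical adapter qualified name "summarizer", and a `run_generation` stand-in for a blocking generate call:

    import threading

    lock = HFGenerationLock(backend)  # normally created by the backend in __init__

    def worker(state: str) -> None:
        with lock.get_lock(state):
            # All concurrent holders in here share `state`; the matching
            # adapter (or no adapter, when state is "") is active.
            run_generation(state)

    threads = [threading.Thread(target=worker, args=(s,)) for s in ("", "", "summarizer")]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    # The two base-model workers may overlap; the "summarizer" worker waits for
    # both to finish, then loads/activates its adapter before generating.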

test/backends/test_huggingface.py

Lines changed: 1 addition & 1 deletion
@@ -120,7 +120,7 @@ def test_constraint_lora_override_does_not_override_alora(session, backend):
     # the correct actions / results in it.
     assert isinstance(val_result.context, Context)
     assert isinstance(val_result.thunk, ModelOutputThunk)
-    assert isinstance(val_result.context.previous_node.node_data, ALoraRequirement)
+    assert isinstance(val_result.context.previous_node.node_data, ALoraRequirement)  # type: ignore
     assert val_result.context.node_data is val_result.thunk

     backend.default_to_constraint_checking_alora = True
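
The only test change in this commit is a `# type: ignore`; the lock itself is untested here. One way its invariant could be probed, sketched hypothetically with a `make_lock` fixture (not part of the commit) that builds an HFGenerationLock over a backend whose adapter loading is stubbed out, is to assert that holders with different states never overlap:

    import threading
    import time

    def test_lock_serializes_conflicting_states(make_lock):  # hypothetical fixture
        lock = make_lock()
        active: list[str] = []
        overlaps: list[tuple[str, ...]] = []

        def hold(state: str) -> None:
            with lock.get_lock(state):
                active.append(state)
                if len(set(active)) > 1:
                    overlaps.append(tuple(active))
                time.sleep(0.05)
                active.remove(state)

        threads = [threading.Thread(target=hold, args=(s,)) for s in ("", "a", "", "a")]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        assert not overlaps, f"holders with different states overlapped: {overlaps}"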
