Changes from all commits
44 commits
bbfc5f4
Basic implementation of request scheduling
FredyRivera-dev Sep 6, 2025
a308e3e
Basic editing in SD and Flux Pipelines
FredyRivera-dev Sep 7, 2025
4799b8e
Small Fix
FredyRivera-dev Sep 7, 2025
eda5847
Fix
FredyRivera-dev Sep 7, 2025
6b5e6be
Update for more pipelines
FredyRivera-dev Sep 7, 2025
df2933f
Add examples/server-async
FredyRivera-dev Sep 7, 2025
5c7c7c6
Add examples/server-async
FredyRivera-dev Sep 7, 2025
e3cd368
Merge branch 'huggingface:main' into main
FredyRivera-dev Sep 10, 2025
09bf796
Merge branch 'huggingface:main' into main
FredyRivera-dev Sep 10, 2025
bd3e48a
Updated RequestScopedPipeline to handle a single tokenizer lock to av…
FredyRivera-dev Sep 10, 2025
534710c
Fix
FredyRivera-dev Sep 10, 2025
4d7c64f
Fix _TokenizerLockWrapper
FredyRivera-dev Sep 10, 2025
18db9e6
Fix _TokenizerLockWrapper
FredyRivera-dev Sep 10, 2025
8f0efb1
Delete _TokenizerLockWrapper
FredyRivera-dev Sep 10, 2025
b479039
Fix tokenizer
FredyRivera-dev Sep 10, 2025
e676b34
Merge branch 'huggingface:main' into main
FredyRivera-dev Sep 10, 2025
0beab1c
Update examples/server-async
FredyRivera-dev Sep 11, 2025
840f0e4
Fix server-async
FredyRivera-dev Sep 11, 2025
bb41c2b
Merge branch 'huggingface:main' into main
FredyRivera-dev Sep 12, 2025
8a238c3
Merge branch 'huggingface:main' into main
FredyRivera-dev Sep 12, 2025
ed617fe
Optimizations in examples/server-async
FredyRivera-dev Sep 13, 2025
b052d27
We keep the implementation simple in examples/server-async
FredyRivera-dev Sep 14, 2025
0f63f4d
Update examples/server-async/README.md
FredyRivera-dev Sep 14, 2025
a9666b1
Update examples/server-async/README.md for changes to tokenizer locks…
FredyRivera-dev Sep 14, 2025
06bb136
The changes to the diffusers core have been undone and all logic is b…
FredyRivera-dev Sep 15, 2025
a519915
Update examples/server-async/utils/*
FredyRivera-dev Sep 15, 2025
7cfee77
Fix BaseAsyncScheduler
FredyRivera-dev Sep 15, 2025
e574f07
Rollback in the core of the diffusers
FredyRivera-dev Sep 15, 2025
05d7936
Merge branch 'huggingface:main' into main
FredyRivera-dev Sep 15, 2025
1049663
Update examples/server-async/README.md
FredyRivera-dev Sep 15, 2025
5316620
Complete rollback of diffusers core files
FredyRivera-dev Sep 15, 2025
0ecdfc3
Simple implementation of an asynchronous server compatible with SD3-3…
FredyRivera-dev Sep 17, 2025
ac5c9e6
Update examples/server-async/README.md
FredyRivera-dev Sep 17, 2025
72e0215
Fixed import errors in 'examples/server-async/serverasync.py'
FredyRivera-dev Sep 17, 2025
edd550b
Flux Pipeline Discard
FredyRivera-dev Sep 17, 2025
6b69367
Update examples/server-async/README.md
FredyRivera-dev Sep 17, 2025
5598557
Merge branch 'main' into main
sayakpaul Sep 18, 2025
7c4f883
Apply style fixes
github-actions[bot] Sep 18, 2025
c91e6f4
Merge branch 'huggingface:main' into main
FredyRivera-dev Oct 20, 2025
f2e9f02
Add thread-safe wrappers for components in pipeline
FredyRivera-dev Oct 21, 2025
489da5d
Add wrappers.py
FredyRivera-dev Oct 21, 2025
8072fba
Merge branch 'main' into main
sayakpaul Oct 21, 2025
581847f
Apply style fixes
github-actions[bot] Oct 21, 2025
7f01e69
Merge branch 'main' into main
sayakpaul Oct 21, 2025
134 changes: 87 additions & 47 deletions examples/server-async/utils/requestscopedpipeline.py
@@ -7,16 +7,12 @@
 from diffusers.utils import logging
 
 from .scheduler import BaseAsyncScheduler, async_retrieve_timesteps
+from .wrappers import ThreadSafeImageProcessorWrapper, ThreadSafeTokenizerWrapper, ThreadSafeVAEWrapper
 
 
 logger = logging.get_logger(__name__)
 
 
-def safe_tokenize(tokenizer, *args, lock, **kwargs):
-    with lock:
-        return tokenizer(*args, **kwargs)
-
-
 class RequestScopedPipeline:
     DEFAULT_MUTABLE_ATTRS = [
         "_all_hooks",
@@ -38,23 +34,40 @@ def __init__(
         wrap_scheduler: bool = True,
     ):
         self._base = pipeline
 
         self.unet = getattr(pipeline, "unet", None)
         self.vae = getattr(pipeline, "vae", None)
         self.text_encoder = getattr(pipeline, "text_encoder", None)
         self.components = getattr(pipeline, "components", None)
 
+        self.transformer = getattr(pipeline, "transformer", None)
 
         if wrap_scheduler and hasattr(pipeline, "scheduler") and pipeline.scheduler is not None:
             if not isinstance(pipeline.scheduler, BaseAsyncScheduler):
                 pipeline.scheduler = BaseAsyncScheduler(pipeline.scheduler)
 
         self._mutable_attrs = list(mutable_attrs) if mutable_attrs is not None else list(self.DEFAULT_MUTABLE_ATTRS)
 
         self._tokenizer_lock = tokenizer_lock if tokenizer_lock is not None else threading.Lock()
 
+        self._vae_lock = threading.Lock()
+        self._image_lock = threading.Lock()
 
         self._auto_detect_mutables = bool(auto_detect_mutables)
         self._tensor_numel_threshold = int(tensor_numel_threshold)
 
         self._auto_detected_attrs: List[str] = []
 
+    def _detect_kernel_pipeline(self, pipeline) -> bool:
+        kernel_indicators = [
+            "text_encoding_cache",
+            "memory_manager",
+            "enable_optimizations",
+            "_create_request_context",
+            "get_optimization_stats",
+        ]
+
+        return any(hasattr(pipeline, attr) for attr in kernel_indicators)
 
     def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str] = None, **clone_kwargs):
         base_sched = getattr(self._base, "scheduler", None)
         if base_sched is None:
@@ -70,11 +83,21 @@ def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str]
                 num_inference_steps=num_inference_steps, device=device, **clone_kwargs
             )
         except Exception as e:
-            logger.debug(f"clone_for_request failed: {e}; falling back to deepcopy()")
+            logger.debug(f"clone_for_request failed: {e}; trying shallow copy fallback")
         try:
-            return copy.deepcopy(wrapped_scheduler)
-        except Exception as e:
-            logger.warning(f"Deepcopy of scheduler failed: {e}. Returning original scheduler (*risky*).")
+            if hasattr(wrapped_scheduler, "scheduler"):
+                try:
+                    copied_scheduler = copy.copy(wrapped_scheduler.scheduler)
+                    return BaseAsyncScheduler(copied_scheduler)
+                except Exception:
+                    return wrapped_scheduler
+            else:
+                copied_scheduler = copy.copy(wrapped_scheduler)
+                return BaseAsyncScheduler(copied_scheduler)
+        except Exception as e2:
+            logger.warning(
+                f"Shallow copy of scheduler also failed: {e2}. Using original scheduler (*thread-unsafe but functional*)."
+            )
             return wrapped_scheduler
 
     def _autodetect_mutables(self, max_attrs: int = 40):
@@ -86,25 +109,25 @@ def _autodetect_mutables(self, max_attrs: int = 40):
 
         candidates: List[str] = []
         seen = set()
 
         for name in dir(self._base):
             if name.startswith("__"):
                 continue
             if name in self._mutable_attrs:
                 continue
             if name in ("to", "save_pretrained", "from_pretrained"):
                 continue
 
             try:
                 val = getattr(self._base, name)
             except Exception:
                 continue
 
             import types
 
             # skip callables and modules
             if callable(val) or isinstance(val, (types.ModuleType, types.FunctionType, types.MethodType)):
                 continue
 
             # containers -> candidate
             if isinstance(val, (dict, list, set, tuple, bytearray)):
                 candidates.append(name)
                 seen.add(name)
@@ -205,6 +228,9 @@ def _is_tokenizer_component(self, component) -> bool:
 
         return has_tokenizer_methods and (has_tokenizer_in_name or has_tokenizer_attrs)
 
+    def _should_wrap_tokenizers(self) -> bool:
+        return True
+
     def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] = None, **kwargs):
         local_scheduler = self._make_local_scheduler(num_inference_steps=num_inference_steps, device=device)
 
@@ -214,6 +240,25 @@ def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] =
             logger.warning(f"copy.copy(self._base) failed: {e}. Falling back to deepcopy (may increase memory).")
             local_pipe = copy.deepcopy(self._base)
 
+        try:
+            if (
+                hasattr(local_pipe, "vae")
+                and local_pipe.vae is not None
+                and not isinstance(local_pipe.vae, ThreadSafeVAEWrapper)
+            ):
+                local_pipe.vae = ThreadSafeVAEWrapper(local_pipe.vae, self._vae_lock)
+
+            if (
+                hasattr(local_pipe, "image_processor")
+                and local_pipe.image_processor is not None
+                and not isinstance(local_pipe.image_processor, ThreadSafeImageProcessorWrapper)
+            ):
+                local_pipe.image_processor = ThreadSafeImageProcessorWrapper(
+                    local_pipe.image_processor, self._image_lock
+                )
+        except Exception as e:
+            logger.debug(f"Could not wrap vae/image_processor: {e}")
+
         if local_scheduler is not None:
             try:
                 timesteps, num_steps, configured_scheduler = async_retrieve_timesteps(
@@ -231,66 +276,61 @@ def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] =
 
         self._clone_mutable_attrs(self._base, local_pipe)
 
-        # 4) wrap tokenizers on the local pipe with the lock wrapper
-        tokenizer_wrappers = {}  # name -> original_tokenizer
-        try:
-            # a) wrap direct tokenizer attributes (tokenizer, tokenizer_2, ...)
-            for name in dir(local_pipe):
-                if "tokenizer" in name and not name.startswith("_"):
-                    tok = getattr(local_pipe, name, None)
-                    if tok is not None and self._is_tokenizer_component(tok):
-                        tokenizer_wrappers[name] = tok
-                        setattr(
-                            local_pipe,
-                            name,
-                            lambda *args, tok=tok, **kwargs: safe_tokenize(
-                                tok, *args, lock=self._tokenizer_lock, **kwargs
-                            ),
-                        )
-
-            # b) wrap tokenizers in components dict
-            if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
-                for key, val in local_pipe.components.items():
-                    if val is None:
-                        continue
-
-                    if self._is_tokenizer_component(val):
-                        tokenizer_wrappers[f"components[{key}]"] = val
-                        local_pipe.components[key] = lambda *args, tokenizer=val, **kwargs: safe_tokenize(
-                            tokenizer, *args, lock=self._tokenizer_lock, **kwargs
-                        )
-        except Exception as e:
-            logger.debug(f"Tokenizer wrapping step encountered an error: {e}")
+        original_tokenizers = {}
+
+        if self._should_wrap_tokenizers():
+            try:
+                for name in dir(local_pipe):
+                    if "tokenizer" in name and not name.startswith("_"):
+                        tok = getattr(local_pipe, name, None)
+                        if tok is not None and self._is_tokenizer_component(tok):
+                            if not isinstance(tok, ThreadSafeTokenizerWrapper):
+                                original_tokenizers[name] = tok
+                                wrapped_tokenizer = ThreadSafeTokenizerWrapper(tok, self._tokenizer_lock)
+                                setattr(local_pipe, name, wrapped_tokenizer)
+
+                if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
+                    for key, val in local_pipe.components.items():
+                        if val is None:
+                            continue
+
+                        if self._is_tokenizer_component(val):
+                            if not isinstance(val, ThreadSafeTokenizerWrapper):
+                                original_tokenizers[f"components[{key}]"] = val
+                                wrapped_tokenizer = ThreadSafeTokenizerWrapper(val, self._tokenizer_lock)
+                                local_pipe.components[key] = wrapped_tokenizer
+
+            except Exception as e:
+                logger.debug(f"Tokenizer wrapping step encountered an error: {e}")
 
         result = None
         cm = getattr(local_pipe, "model_cpu_offload_context", None)
 
         try:
             if callable(cm):
                 try:
                     with cm():
                         result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
                 except TypeError:
                     # cm might be a context manager instance rather than callable
                     try:
                         with cm:
                             result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
                     except Exception as e:
                         logger.debug(f"model_cpu_offload_context usage failed: {e}. Proceeding without it.")
                         result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
             else:
                 # no offload context available — call directly
                 result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
 
             return result
 
         finally:
             try:
-                for name, tok in tokenizer_wrappers.items():
+                for name, tok in original_tokenizers.items():
                     if name.startswith("components["):
                         key = name[len("components[") : -1]
-                        local_pipe.components[key] = tok
+                        if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
+                            local_pipe.components[key] = tok
                     else:
                         setattr(local_pipe, name, tok)
             except Exception as e:
-                logger.debug(f"Error restoring wrapped tokenizers: {e}")
+                logger.debug(f"Error restoring original tokenizers: {e}")
86 changes: 86 additions & 0 deletions examples/server-async/utils/wrappers.py
@@ -0,0 +1,86 @@
class ThreadSafeTokenizerWrapper:
    def __init__(self, tokenizer, lock):
        self._tokenizer = tokenizer
        self._lock = lock

        self._thread_safe_methods = {
            "__call__",
            "encode",
            "decode",
            "tokenize",
            "encode_plus",
            "batch_encode_plus",
            "batch_decode",
        }

    def __getattr__(self, name):
        attr = getattr(self._tokenizer, name)

        if name in self._thread_safe_methods and callable(attr):

            def wrapped_method(*args, **kwargs):
                with self._lock:
                    return attr(*args, **kwargs)

            return wrapped_method

        return attr

    def __call__(self, *args, **kwargs):
        with self._lock:
            return self._tokenizer(*args, **kwargs)

    def __setattr__(self, name, value):
        if name.startswith("_"):
            super().__setattr__(name, value)
        else:
            setattr(self._tokenizer, name, value)

    def __dir__(self):
        return dir(self._tokenizer)


class ThreadSafeVAEWrapper:
    def __init__(self, vae, lock):
        self._vae = vae
        self._lock = lock

    def __getattr__(self, name):
        attr = getattr(self._vae, name)
        if name in {"decode", "encode", "forward"} and callable(attr):

            def wrapped(*args, **kwargs):
                with self._lock:
                    return attr(*args, **kwargs)

            return wrapped
        return attr

    def __setattr__(self, name, value):
        if name.startswith("_"):
            super().__setattr__(name, value)
        else:
            setattr(self._vae, name, value)


class ThreadSafeImageProcessorWrapper:
    def __init__(self, proc, lock):
        self._proc = proc
        self._lock = lock

    def __getattr__(self, name):
        attr = getattr(self._proc, name)
        if name in {"postprocess", "preprocess"} and callable(attr):

            def wrapped(*args, **kwargs):
                with self._lock:
                    return attr(*args, **kwargs)

            return wrapped
        return attr

    def __setattr__(self, name, value):
        if name.startswith("_"):
            super().__setattr__(name, value)
        else:
            setattr(self._proc, name, value)
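
Note (editor): a self-contained sketch of the locking behavior these wrappers provide; the toy tokenizer and its counters are invented for illustration, and only ThreadSafeTokenizerWrapper comes from this file.

import threading
from concurrent.futures import ThreadPoolExecutor

from utils.wrappers import ThreadSafeTokenizerWrapper


class ToyTokenizer:
    # Stand-in for a real tokenizer; records how many threads are inside __call__ at once.
    def __init__(self):
        self.active = 0
        self.max_active = 0

    def __call__(self, text):
        self.active += 1
        self.max_active = max(self.max_active, self.active)
        tokens = text.split()
        self.active -= 1
        return tokens


tok = ToyTokenizer()
safe_tok = ThreadSafeTokenizerWrapper(tok, threading.Lock())

with ThreadPoolExecutor(max_workers=8) as pool:
    list(pool.map(safe_tok, ["a photo of a cat"] * 200))

# The wrapper serializes __call__ behind the shared lock, so at most one
# thread is ever inside the underlying tokenizer at a time.
print(tok.max_active)  # -> 1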