Commit c244538

create custom device map for FLUX pipeline
1 parent 0f5b66f commit c244538

File tree

3 files changed: +100 -12 lines changed

runner/app/pipelines/device_maps/__init__.py
runner/app/pipelines/device_maps/flux.py
runner/app/pipelines/text_to_image.py

runner/app/pipelines/device_maps/__init__.py (new file)

Lines changed: 1 addition & 0 deletions

from .flux import LPFluxPipeline

runner/app/pipelines/device_maps/flux.py (new file)

Lines changed: 95 additions & 0 deletions

import inspect
import logging
import time
from typing import List, Optional, Tuple

import PIL
import torch
from diffusers import FluxPipeline
from diffusers.pipelines import ImagePipelineOutput

from app.pipelines.base import Pipeline

logger = logging.getLogger(__name__)


class LPFluxPipeline(Pipeline):
    def __init__(self, model_id: str, device_map: str, torch_device: torch.device, **kwargs):
        self.lp_device_map = device_map
        self.ldm = None
        self.ldm2 = None
        if self.lp_device_map == "FLUX_DEVICE_MAP_2_GPU":
            # Load only the transformer on GPU 0.
            self.ldm = FluxPipeline.from_pretrained(
                model_id,
                text_encoder=None,
                text_encoder_2=None,
                tokenizer=None,
                tokenizer_2=None,
                vae=None,
                **kwargs,
            ).to("cuda:0")
            # Load all other components (text encoders, tokenizers, VAE) on GPU 1.
            self.ldm2 = FluxPipeline.from_pretrained(
                model_id, unet=None, transformer=None, **kwargs
            ).to("cuda:1")
        elif self.lp_device_map == "FLUX_DEVICE_MAP_1_GPU":
            self.ldm = FluxPipeline.from_pretrained(model_id, **kwargs)
            self.ldm.enable_model_cpu_offload()
        elif "device_map" in kwargs:
            # A diffusers-native device_map was supplied; let from_pretrained
            # handle module placement.
            self.ldm = FluxPipeline.from_pretrained(model_id, **kwargs)
        else:
            self.ldm = FluxPipeline.from_pretrained(model_id, **kwargs).to(torch_device)

    def __getattr__(self, name):
        # Called only when normal lookup fails: delegate attribute access to
        # the wrapped pipeline. Go through __dict__ so lookups that fire
        # before __init__ has set self.ldm cannot recurse.
        ldm = self.__dict__.get("ldm")
        if ldm is not None:
            try:
                return getattr(ldm, name)
            except AttributeError:
                pass
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    def __setattr__(self, name, value):
        # Keep the wrapper's own attributes on this instance (this also avoids
        # recursing into __getattr__ during __init__, before self.ldm exists);
        # forward every other assignment to the wrapped pipeline.
        if name in ("lp_device_map", "ldm", "ldm2"):
            super().__setattr__(name, value)
        else:
            setattr(self.ldm, name, value)

    def __call__(
        self, prompt: str, **kwargs
    ) -> Tuple[List[PIL.Image.Image], List[Optional[bool]]]:
        outputs = None
        if self.lp_device_map == "FLUX_DEVICE_MAP_2_GPU":
            with torch.no_grad():
                # Generate prompt embeddings on GPU 1.
                start = time.time()
                encode_prompt_kwargs = inspect.signature(self.ldm2.encode_prompt).parameters.keys()
                prompt_2 = kwargs.pop("prompt_2", "")
                prompt_embeds, pooled_prompt_embeds, text_ids = self.ldm2.encode_prompt(
                    prompt,
                    prompt_2,
                    **{k: v for k, v in kwargs.items() if k in encode_prompt_kwargs},
                )
                logger.info(f"encode_prompt took: {time.time() - start} seconds")

                # Move the embeddings to GPU 0, where the transformer lives.
                start = time.time()
                prompt_embeds = prompt_embeds.to(self.ldm._execution_device)
                pooled_prompt_embeds = pooled_prompt_embeds.to(self.ldm._execution_device)
                logger.info(f"prompt embeds conversion took: {time.time() - start} seconds")

                # Run the transformer and return latents instead of a decoded image.
                start = time.time()
                ldm_kwargs = inspect.signature(self.ldm.__call__).parameters.keys()
                latents = self.ldm(
                    prompt=None,
                    prompt_2=None,
                    prompt_embeds=prompt_embeds,
                    pooled_prompt_embeds=pooled_prompt_embeds,
                    output_type="latent",
                    return_dict=False,
                    **{k: v for k, v in kwargs.items() if k in ldm_kwargs},
                )
                logger.info(f"transformer took: {time.time() - start} seconds")

                # Decode the latents with the VAE on GPU 1; height and width
                # are required in kwargs here.
                # Copied from diffusers/pipelines/flux/pipeline_flux.py L760.
                start = time.time()
                latents = latents[0].to(self.ldm2._execution_device)
                logger.info(f"latents conversion took: {time.time() - start} seconds")
                start = time.time()
                latents = self.ldm2._unpack_latents(
                    latents, kwargs["height"], kwargs["width"], self.ldm2.vae_scale_factor
                )
                latents = (latents / self.ldm2.vae.config.scaling_factor) + self.ldm2.vae.config.shift_factor
                image = self.ldm2.vae.decode(latents, return_dict=False)[0]
                # Only the default output_type="pil" is supported.
                image = self.ldm2.image_processor.postprocess(image)
                logger.info(f"vae decode took: {time.time() - start} seconds")

                outputs = ImagePipelineOutput(images=image)
        else:
            outputs = self.ldm(prompt=prompt, **kwargs)

        return outputs
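
For reference, a minimal usage sketch of the wrapper in its two-GPU configuration follows. The model id, dtype, and generation parameters are illustrative assumptions, not part of this commit:

import torch

from app.pipelines.device_maps import LPFluxPipeline

# Hypothetical usage sketch; the checkpoint id is an assumption (any
# FluxPipeline-compatible model id should work).
pipe = LPFluxPipeline(
    "black-forest-labs/FLUX.1-schnell",
    "FLUX_DEVICE_MAP_2_GPU",   # transformer on cuda:0; encoders/VAE on cuda:1
    torch.device("cuda"),      # fallback device; unused by this map
    torch_dtype=torch.bfloat16,
)

# height and width are required in the 2-GPU path (consumed by _unpack_latents).
outputs = pipe("an astronaut riding a horse", height=1024, width=1024)
images = outputs.images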

runner/app/pipelines/text_to_image.py

Lines changed: 4 additions & 12 deletions

@@ -8,7 +8,6 @@
 from diffusers import (
     AutoPipelineForText2Image,
     EulerDiscreteScheduler,
-    FluxPipeline,
     StableDiffusion3Pipeline,
     StableDiffusionXLPipeline,
     UNet2DConditionModel,
@@ -17,17 +16,10 @@
 from huggingface_hub import file_download, hf_hub_download
 from safetensors.torch import load_file
 
-from app.pipelines.base import Pipeline
-from app.pipelines.utils import (
-    LoraLoader,
-    SafetyChecker,
-    get_model_dir,
-    get_torch_device,
-    is_lightning_model,
-    is_turbo_model,
-    split_prompt,
+
+from app.pipelines.device_maps import (
+    LPFluxPipeline,
 )
-from app.utils.errors import InferenceError
 
 logger = logging.getLogger(__name__)
 
@@ -137,7 +129,7 @@ def __init__(self, model_id: str):
         ):
             # Decrease precision to prevent OOM errors.
             kwargs["torch_dtype"] = torch.bfloat16
-            self.ldm = FluxPipeline.from_pretrained(model_id, **kwargs).to(torch_device)
+            self.ldm = LPFluxPipeline(model_id, os.environ.get("DEVICE_MAP", ""), torch_device, **kwargs)
         else:
             self.ldm = AutoPipelineForText2Image.from_pretrained(model_id, **kwargs).to(
                 torch_device

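The wrapper's behavior is selected through the DEVICE_MAP environment variable read above. A sketch of the values this commit handles (setting the variable in-process like this is for illustration only; it must happen before TextToImagePipeline is constructed):

import os

os.environ["DEVICE_MAP"] = "FLUX_DEVICE_MAP_2_GPU"  # split transformer / encoders+VAE across two GPUs
# os.environ["DEVICE_MAP"] = "FLUX_DEVICE_MAP_1_GPU"  # single GPU with model CPU offload
# Leaving DEVICE_MAP unset (the "" default) loads the whole pipeline onto torch_device.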