1
+ import inspect
2
+ import logging
3
+ import os
4
+ import time
5
+ from typing import List , Optional , Tuple
6
+
7
+ import PIL
8
+ import torch
9
+ from diffusers import FluxPipeline
10
+ from diffusers .pipelines import ImagePipelineOutput
11
+
12
+ from app .pipelines .base import Pipeline
13
+
14
+ logger = logging .getLogger (__name__ )
15
+
16
+ class LPFluxPipeline (Pipeline ):
17
+ def __init__ (self , model_id : str , device_map : str , torch_device : any , ** kwargs ):
18
+ self .lp_device_map = device_map
19
+ self .ldm = None
20
+ self .ldm2 = None
21
+ if self .lp_device_map == "FLUX_DEVICE_MAP_2_GPU" :
22
+ #setup transformer for GPU 0
23
+ self .ldm = FluxPipeline .from_pretrained (model_id , text_encoder = None , text_encoder_2 = None , tokenizer = None , tokenizer_2 = None , vae = None , ** kwargs ).to ("cuda:0" )
24
+ #setup pipeline for all other components on GPU 1
25
+ self .ldm2 = FluxPipeline .from_pretrained (model_id , unet = None , transformer = None , ** kwargs ).to ("cuda:1" )
26
+ elif self .lp_device_map == "FLUX_DEVICE_MAP_1_GPU" :
27
+ self .ldm = FluxPipeline .from_pretrained (model_id , ** kwargs )
28
+ self .ldm .enable_model_cpu_offload ()
29
+ elif "device_map" in kwargs :
30
+ self .ldm = FluxPipeline .from_pretrained (model_id , ** kwargs )
31
+ else :
32
+ self .ldm = FluxPipeline .from_pretrained (model_id , ** kwargs ).to (torch_device )
33
+
34
+ def __getattr__ (self , name ):
35
+ # Redirect attribute access to self.ldm if it exists there
36
+ try :
37
+ if name not in dir (self ):
38
+ return getattr (self .ldm , name )
39
+ else :
40
+ return super ().__getattr__ (name )
41
+
42
+ except AttributeError :
43
+ raise AttributeError (f"'{ type (self ).__name__ } ' object has no attribute '{ name } '" )
44
+
45
+ def __setattr__ (self , name , value ):
46
+ # Handle setting attributes
47
+ if name not in dir (self ):
48
+ # Redirect to ldm if attribute doesn't exist in this instance
49
+ setattr (self .ldm , name , value )
50
+ else :
51
+ super ().__setattr__ (name , value )
52
+
53
+ def __call__ (
54
+ self , prompt : str , ** kwargs
55
+ ) -> Tuple [List [PIL .Image ], List [Optional [bool ]]]:
56
+ outputs = None
57
+ if self .lp_device_map == "FLUX_DEVICE_MAP_2_GPU" :
58
+ with torch .no_grad ():
59
+ #generate prompt embeddings on GPU 1
60
+ start = time .time ()
61
+ prompt_embeds = pooled_prompt_embeds = text_ids = None
62
+ encode_prompt_kwargs = inspect .signature (self .ldm2 .encode_prompt ).parameters .keys ()
63
+ prompt_2 = kwargs .pop ("prompt_2" , "" )
64
+ prompt_embeds , pooled_prompt_embeds , text_ids = self .ldm2 .encode_prompt (prompt , prompt_2 , ** {k : v for k , v in kwargs .items () if k in encode_prompt_kwargs })
65
+ logger .info (f"encode_prompt took: { time .time ()- start } seconds" )
66
+ #generate the image with transformer, return latents
67
+ start = time .time ()
68
+ prompt_embeds = prompt_embeds .to (self .ldm ._execution_device )
69
+ pooled_prompt_embeds = pooled_prompt_embeds .to (self .ldm ._execution_device )
70
+ logger .info (f"prompt embeds conversion took: { time .time ()- start } seconds" )
71
+ start = time .time ()
72
+ ldm_kwargs = inspect .signature (self .ldm .__call__ ).parameters .keys ()
73
+ latents = self .ldm (prompt = None , prompt_2 = None ,
74
+ prompt_embeds = prompt_embeds .to (self .ldm ._execution_device ),
75
+ pooled_prompt_embeds = pooled_prompt_embeds .to (self .ldm ._execution_device ),
76
+ output_type = "latent" , return_dict = False ,
77
+ ** {k : v for k , v in kwargs .items () if k in ldm_kwargs })
78
+ logger .info (f"transformer took: { time .time ()- start } seconds" )
79
+ #use the VAE on GPU 1 to process the image
80
+ #copied from diffusers/pipelines/flux/pipeline_flux.py L760
81
+ start = time .time ()
82
+ latents = latents [0 ].to (self .ldm2 ._execution_device )
83
+ logger .info (f"latents conversion took: { time .time ()- start } seconds" )
84
+ start = time .time ()
85
+ latents = self .ldm2 ._unpack_latents (latents , kwargs ["height" ], kwargs ["width" ], self .ldm2 .vae_scale_factor )
86
+ latents = (latents / self .ldm2 .vae .config .scaling_factor ) + self .ldm2 .vae .config .shift_factor
87
+ image = self .ldm2 .vae .decode (latents , return_dict = False )[0 ]
88
+ image = self .ldm2 .image_processor .postprocess (image ) #only support default output_type="pil"
89
+ logger .info (f"vae decode took: { time .time ()- start } seconds" )
90
+
91
+ outputs = ImagePipelineOutput (images = image )
92
+ else :
93
+ outputs = self .ldm (prompt = prompt , ** kwargs )
94
+
95
+ return outputs
0 commit comments