 import torch
 from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.models.autoencoder_kl import AutoencoderKLOutput
-from diffusers.models.cross_attention import CrossAttention
+from diffusers.models.cross_attention import CrossAttention, LoRACrossAttnProcessor
 from diffusers.models.unet_2d_condition import UNet2DConditionOutput
 from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution
 from transformers import CLIPTextModel
@@ -58,25 +58,39 @@ def _nearest_divisor(target, start, end):
                 return divisor
         raise ValueError(f"No divisor found in range [{start}, {end}].")
 
-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+    @staticmethod
+    def _forward(
+        attn: CrossAttention,
+        hidden_states,
+        attn_matrix_target_mem_mb,
+        encoder_hidden_states=None,
+        attention_mask=None,
+        lora_cross_attn_processor=None,
+        scale=1.0,
+    ):
         batch_size, sequence_length, _ = hidden_states.shape
 
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
 
         query = attn.to_q(hidden_states)
+        if lora_cross_attn_processor is not None:
+            query += scale * lora_cross_attn_processor.to_q_lora(hidden_states)
         query = attn.head_to_batch_dim(query)
 
         encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
+        if lora_cross_attn_processor is not None:
+            key += scale * lora_cross_attn_processor.to_k_lora(encoder_hidden_states)
+            value += scale * lora_cross_attn_processor.to_v_lora(encoder_hidden_states)
         key = attn.head_to_batch_dim(key)
         value = attn.head_to_batch_dim(value)
 
         # Begin IPU modifications.
         attn_matrix_mem = query.element_size() * query.shape[0] * query.shape[1] * key.shape[1]
-        num_slices = attn_matrix_mem // (self._attn_matrix_target_mem_mb * 1024 * 1024)
+        num_slices = attn_matrix_mem // (attn_matrix_target_mem_mb * 1024 * 1024)
         num_slices = max(num_slices, 1)
-        num_slices = self._nearest_divisor(query.shape[1], num_slices, 2 * num_slices)
+        num_slices = IPUSlicedAttnProcessor._nearest_divisor(query.shape[1], num_slices, 2 * num_slices)
         slice_size = query.shape[1] // num_slices
 
         hidden_states = []
@@ -101,11 +115,38 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
+        if lora_cross_attn_processor is not None:
+            hidden_states += scale * lora_cross_attn_processor.to_out_lora(hidden_states)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
 
         return hidden_states
 
+    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        return self._forward(
+            attn, hidden_states, self._attn_matrix_target_mem_mb, encoder_hidden_states, attention_mask
+        )
+
+
+class IPULoRASlicedAttnProcessor(torch.nn.Module):
+    def __init__(self, attn_matrix_target_mem_mb: int, lora_cross_attn_processor: LoRACrossAttnProcessor):
+        super().__init__()
+        self._attn_matrix_target_mem_mb = attn_matrix_target_mem_mb
+        self._lora_cross_attn_processor = lora_cross_attn_processor
+
+    def __call__(
+        self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
+    ):
+        return IPUSlicedAttnProcessor._forward(
+            attn,
+            hidden_states,
+            self._attn_matrix_target_mem_mb,
+            encoder_hidden_states,
+            attention_mask,
+            self._lora_cross_attn_processor,
+            scale,
+        )
+
 
 class IPUCLIPTextModel(CLIPTextModel, PipelineMixin):
     def parallelize(self):
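Note (illustration, not part of the diff): the slice count chosen in `_forward` above simply divides the memory of the full Q @ K^T matrix by the per-slice budget, then rounds up to a divisor of the query length. A standalone sketch of that arithmetic, with made-up tensor shapes and a hypothetical 100 MB target:

# Illustrative sketch only: shapes and the 100 MB budget are assumptions, not values from this PR.
import torch

query = torch.empty(16, 4096, 40, dtype=torch.float16)  # (batch * heads, q_len, head_dim)
key = torch.empty(16, 4096, 40, dtype=torch.float16)    # (batch * heads, k_len, head_dim)
attn_matrix_target_mem_mb = 100

# Bytes needed for the full Q @ K^T matrix: element size * (batch * heads) * q_len * k_len.
attn_matrix_mem = query.element_size() * query.shape[0] * query.shape[1] * key.shape[1]  # 512 MiB here
num_slices = max(attn_matrix_mem // (attn_matrix_target_mem_mb * 1024 * 1024), 1)        # 5
# Mirror _nearest_divisor: search [num_slices, 2 * num_slices] for a divisor of q_len
# so every slice has the same size; for q_len=4096 this settles on 8 slices of 512 queries.
num_slices = next(d for d in range(num_slices, 2 * num_slices + 1) if query.shape[1] % d == 0)
slice_size = query.shape[1] // num_slices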
@@ -148,15 +189,23 @@ def forward(
 
 
 class IPUUNet2DConditionModel(UNet2DConditionModel, PipelineMixin):
-    def change_cross_attention_processor(self, attn_matrix_target_mem_mb):
-        for module in self.modules():
-            if isinstance(module, CrossAttention):
-                module.set_processor(IPUSlicedAttnProcessor(attn_matrix_target_mem_mb))
+    def change_cross_attention_processor(self, attn_matrix_target_mem_mb, lora_name_or_path_or_dict=None):
+        attn_processors = {}
+        for attn_processor_name, attn_processor in self.attn_processors.items():
+            if lora_name_or_path_or_dict is not None:
+                attn_processors[attn_processor_name] = IPULoRASlicedAttnProcessor(
+                    attn_matrix_target_mem_mb, attn_processor
+                )
+            else:
+                attn_processors[attn_processor_name] = IPUSlicedAttnProcessor(attn_matrix_target_mem_mb)
+        self.set_attn_processor(attn_processors)
 
-    def parallelize(self, attn_matrix_target_mem_mb=None):
+    def parallelize(self, attn_matrix_target_mem_mb=None, lora_name_or_path_or_dict=None):
         super().parallelize()
 
-        self.change_cross_attention_processor(attn_matrix_target_mem_mb)
+        self.change_cross_attention_processor(
+            attn_matrix_target_mem_mb, lora_name_or_path_or_dict=lora_name_or_path_or_dict
+        )
 
         self.conv_in = poptorch.BeginBlock(self.conv_in, "conv_in", ipu_id=0)
         self.down_blocks[2].downsamplers[0] = poptorch.BeginBlock(
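Note (illustration, not part of the diff): `change_cross_attention_processor` now follows the usual diffusers pattern of reading the UNet's `attn_processors` mapping and writing a same-keyed mapping back with `set_attn_processor`. A minimal sketch of that substitution on a hand-built dict, using the two IPU processor classes added above; the processor name and hidden size are made up for illustration:

# Sketch only: a real UNet exposes this mapping as unet.attn_processors.
from diffusers.models.cross_attention import LoRACrossAttnProcessor

existing = {
    "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor": LoRACrossAttnProcessor(
        hidden_size=320, cross_attention_dim=None
    ),
}

# With LoRA weights loaded, each processor is wrapped so sliced attention keeps applying
# the LoRA deltas; without LoRA, the plain sliced processor would be used instead.
wrapped = {
    name: IPULoRASlicedAttnProcessor(attn_matrix_target_mem_mb=100, lora_cross_attn_processor=proc)
    for name, proc in existing.items()
}
# plain_variant = {name: IPUSlicedAttnProcessor(100) for name in existing}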
@@ -269,6 +318,7 @@ def __init__(
         vae_ipu_config=None,
         safety_checker_ipu_config=None,
         common_ipu_config_kwargs=None,
+        lora_name_or_path_or_dict=None,
     ):
         default_common_ipu_config_kwargs = {
             "enable_half_partials": True,
@@ -399,7 +449,12 @@ def run_safety_checker(self, image, device, dtype):
         unet_ipu = copy.deepcopy(unet)
         unet_ipu.__class__ = IPUUNet2DConditionModel
         unet_ipu.ipu_config = unet_ipu_config
-        unet_ipu.parallelize(attn_matrix_target_mem_mb=attn_matrix_target_mem_mb)
+        if lora_name_or_path_or_dict is not None:
+            unet_ipu.load_attn_procs(lora_name_or_path_or_dict)
+        unet_ipu.parallelize(
+            attn_matrix_target_mem_mb=attn_matrix_target_mem_mb,
+            lora_name_or_path_or_dict=lora_name_or_path_or_dict,
+        )
         override_module_eps(unet_ipu)
 
         opts = unet_ipu_config.to_options(for_inference=True)
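Note (hedged usage sketch, not taken from the diff): with this change a LoRA checkpoint can be passed when the IPU pipeline is built, and it is loaded onto the UNet via `load_attn_procs` before parallelization. The pipeline class, model id, and LoRA path below are illustrative assumptions, and this assumes `from_pretrained` forwards the extra keyword argument to `__init__` as it does for the other IPU-specific options:

# Illustrative only: names and paths are placeholders, not values from this PR.
from optimum.graphcore.diffusers import IPUStableDiffusionPipeline

pipe = IPUStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    lora_name_or_path_or_dict="path/to/lora_weights",  # repo id, local dir, or a state dict
)
image = pipe("an astronaut riding a horse on the moon").images[0]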