This repository was archived by the owner on Sep 15, 2025. It is now read-only.

Commit e1ccb00

add LoRA layer support to IPU SD Pipeline
1 parent 8c4a1dd commit e1ccb00

1 file changed (+66, -11)

optimum/graphcore/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_mixin.py

Lines changed: 66 additions & 11 deletions
```diff
@@ -19,7 +19,7 @@
 import torch
 from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.models.autoencoder_kl import AutoencoderKLOutput
-from diffusers.models.cross_attention import CrossAttention
+from diffusers.models.cross_attention import CrossAttention, LoRACrossAttnProcessor
 from diffusers.models.unet_2d_condition import UNet2DConditionOutput
 from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution
 from transformers import CLIPTextModel
```

```diff
@@ -58,25 +58,39 @@ def _nearest_divisor(target, start, end):
                 return divisor
         raise ValueError(f"No divisor found in range [{start}, {end}].")
 
-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+    @staticmethod
+    def _forward(
+        attn: CrossAttention,
+        hidden_states,
+        attn_matrix_target_mem_mb,
+        encoder_hidden_states=None,
+        attention_mask=None,
+        lora_cross_attn_processor=None,
+        scale=1.0,
+    ):
         batch_size, sequence_length, _ = hidden_states.shape
 
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
 
         query = attn.to_q(hidden_states)
+        if lora_cross_attn_processor is not None:
+            query += scale * lora_cross_attn_processor.to_q_lora(hidden_states)
         query = attn.head_to_batch_dim(query)
 
         encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
+        if lora_cross_attn_processor is not None:
+            key += scale * lora_cross_attn_processor.to_k_lora(encoder_hidden_states)
+            value += scale * lora_cross_attn_processor.to_v_lora(encoder_hidden_states)
         key = attn.head_to_batch_dim(key)
         value = attn.head_to_batch_dim(value)
 
         # Begin IPU modifications.
         attn_matrix_mem = query.element_size() * query.shape[0] * query.shape[1] * key.shape[1]
-        num_slices = attn_matrix_mem // (self._attn_matrix_target_mem_mb * 1024 * 1024)
+        num_slices = attn_matrix_mem // (attn_matrix_target_mem_mb * 1024 * 1024)
         num_slices = max(num_slices, 1)
-        num_slices = self._nearest_divisor(query.shape[1], num_slices, 2 * num_slices)
+        num_slices = IPUSlicedAttnProcessor._nearest_divisor(query.shape[1], num_slices, 2 * num_slices)
         slice_size = query.shape[1] // num_slices
 
         hidden_states = []
```

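The slicing logic above keeps each materialised attention-score slice below roughly `attn_matrix_target_mem_mb`. A quick worked example of that arithmetic, with purely illustrative shapes (not values from this commit):

```python
# Illustrative numbers: fp16 activations, 16 batched heads, 4096 query/key tokens.
element_size = 2      # bytes per fp16 element (query.element_size())
batch_heads = 16      # query.shape[0] after head_to_batch_dim
q_len = 4096          # query.shape[1]
kv_len = 4096         # key.shape[1]

attn_matrix_mem = element_size * batch_heads * q_len * kv_len  # 536_870_912 bytes = 512 MiB
attn_matrix_target_mem_mb = 100

num_slices = attn_matrix_mem // (attn_matrix_target_mem_mb * 1024 * 1024)  # 5
num_slices = max(num_slices, 1)
# _nearest_divisor(q_len, 5, 10) then picks the nearest divisor of 4096 in that
# range, i.e. 8, so each of the 8 slices covers 4096 // 8 = 512 query rows
# (~64 MiB of attention scores per slice instead of 512 MiB at once).
```
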
```diff
@@ -101,11 +115,38 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
+        if lora_cross_attn_processor is not None:
+            hidden_states += scale * lora_cross_attn_processor.to_out_lora(hidden_states)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
 
         return hidden_states
 
+    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        return self._forward(
+            attn, hidden_states, self._attn_matrix_target_mem_mb, encoder_hidden_states, attention_mask
+        )
+
+
+class IPULoRASlicedAttnProcessor(torch.nn.Module):
+    def __init__(self, attn_matrix_target_mem_mb: int, lora_cross_attn_processor: LoRACrossAttnProcessor):
+        super().__init__()
+        self._attn_matrix_target_mem_mb = attn_matrix_target_mem_mb
+        self._lora_cross_attn_processor = lora_cross_attn_processor
+
+    def __call__(
+        self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
+    ):
+        return IPUSlicedAttnProcessor._forward(
+            attn,
+            hidden_states,
+            self._attn_matrix_target_mem_mb,
+            encoder_hidden_states,
+            attention_mask,
+            self._lora_cross_attn_processor,
+            scale,
+        )
+
 
 class IPUCLIPTextModel(CLIPTextModel, PipelineMixin):
     def parallelize(self):
```

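The `to_q_lora` / `to_k_lora` / `to_v_lora` / `to_out_lora` modules referenced in `_forward` belong to the wrapped diffusers `LoRACrossAttnProcessor`; each applies a low-rank update that is scaled and added on top of the frozen base projection. A minimal, self-contained sketch of that pattern (a toy adapter, not the diffusers implementation):

```python
import torch


class ToyLoRALinearLayer(torch.nn.Module):
    """Toy low-rank adapter: x -> up(down(x)) with rank << feature width."""

    def __init__(self, in_features: int, out_features: int, rank: int = 4):
        super().__init__()
        self.down = torch.nn.Linear(in_features, rank, bias=False)
        self.up = torch.nn.Linear(rank, out_features, bias=False)
        torch.nn.init.normal_(self.down.weight, std=1.0 / rank)
        torch.nn.init.zeros_(self.up.weight)  # zero init -> starts as a no-op update

    def forward(self, hidden_states):
        return self.up(self.down(hidden_states))


# Mirrors the pattern used in _forward above:
#   query = attn.to_q(hidden_states)
#   query += scale * lora_cross_attn_processor.to_q_lora(hidden_states)
base_to_q = torch.nn.Linear(320, 320)
to_q_lora = ToyLoRALinearLayer(320, 320, rank=4)

hidden_states = torch.randn(2, 64, 320)  # (batch, tokens, channels)
scale = 1.0
query = base_to_q(hidden_states) + scale * to_q_lora(hidden_states)
print(query.shape)  # torch.Size([2, 64, 320])
```
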
```diff
@@ -148,15 +189,23 @@ def forward(
 
 
 class IPUUNet2DConditionModel(UNet2DConditionModel, PipelineMixin):
-    def change_cross_attention_processor(self, attn_matrix_target_mem_mb):
-        for module in self.modules():
-            if isinstance(module, CrossAttention):
-                module.set_processor(IPUSlicedAttnProcessor(attn_matrix_target_mem_mb))
+    def change_cross_attention_processor(self, attn_matrix_target_mem_mb, lora_name_or_path_or_dict=None):
+        attn_processors = {}
+        for attn_processor_name, attn_processor in self.attn_processors.items():
+            if lora_name_or_path_or_dict is not None:
+                attn_processors[attn_processor_name] = IPULoRASlicedAttnProcessor(
+                    attn_matrix_target_mem_mb, attn_processor
+                )
+            else:
+                attn_processors[attn_processor_name] = IPUSlicedAttnProcessor(attn_matrix_target_mem_mb)
+        self.set_attn_processor(attn_processors)
 
-    def parallelize(self, attn_matrix_target_mem_mb=None):
+    def parallelize(self, attn_matrix_target_mem_mb=None, lora_name_or_path_or_dict=None):
         super().parallelize()
 
-        self.change_cross_attention_processor(attn_matrix_target_mem_mb)
+        self.change_cross_attention_processor(
+            attn_matrix_target_mem_mb, lora_name_or_path_or_dict=lora_name_or_path_or_dict
+        )
 
         self.conv_in = poptorch.BeginBlock(self.conv_in, "conv_in", ipu_id=0)
         self.down_blocks[2].downsamplers[0] = poptorch.BeginBlock(
```

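`change_cross_attention_processor` builds on the `attn_processors` property and `set_attn_processor` method that `UNet2DConditionModel` already provides: it walks the current per-layer processors (after `load_attn_procs` these are `LoRACrossAttnProcessor` instances) and swaps in an IPU-sliced equivalent, wrapping the LoRA processor when one is present. A hedged usage sketch, assuming a UNet whose class has already been switched to `IPUUNet2DConditionModel` as done later in this file:

```python
# Sketch only; `unet_ipu` and the 100 MB target are illustrative, not from the commit.
unet_ipu.change_cross_attention_processor(
    attn_matrix_target_mem_mb=100,
    lora_name_or_path_or_dict=None,  # None -> plain IPUSlicedAttnProcessor everywhere
)

# Every cross-attention layer should now report one of the IPU processors.
assert all(
    type(p).__name__ in {"IPUSlicedAttnProcessor", "IPULoRASlicedAttnProcessor"}
    for p in unet_ipu.attn_processors.values()
)
```
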
```diff
@@ -269,6 +318,7 @@ def __init__(
         vae_ipu_config=None,
         safety_checker_ipu_config=None,
         common_ipu_config_kwargs=None,
+        lora_name_or_path_or_dict=None,
     ):
         default_common_ipu_config_kwargs = {
             "enable_half_partials": True,
```

```diff
@@ -399,7 +449,12 @@ def run_safety_checker(self, image, device, dtype):
         unet_ipu = copy.deepcopy(unet)
         unet_ipu.__class__ = IPUUNet2DConditionModel
         unet_ipu.ipu_config = unet_ipu_config
-        unet_ipu.parallelize(attn_matrix_target_mem_mb=attn_matrix_target_mem_mb)
+        if lora_name_or_path_or_dict is not None:
+            unet_ipu.load_attn_procs(lora_name_or_path_or_dict)
+        unet_ipu.parallelize(
+            attn_matrix_target_mem_mb=attn_matrix_target_mem_mb,
+            lora_name_or_path_or_dict=lora_name_or_path_or_dict,
+        )
         override_module_eps(unet_ipu)
 
         opts = unet_ipu_config.to_options(for_inference=True)
```

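Taken together, the new `lora_name_or_path_or_dict` argument flows from the pipeline constructor to the UNet: `load_attn_procs` first loads the LoRA attention weights, and `parallelize` then wraps each loaded `LoRACrossAttnProcessor` in an `IPULoRASlicedAttnProcessor`. A hedged end-to-end sketch; the pipeline class name, checkpoint IDs, and config values below are assumptions rather than values from this commit, and it presumes `from_pretrained` forwards the extra keyword to the `__init__` shown above:

```python
import torch
from optimum.graphcore.diffusers import IPUStableDiffusionPipeline  # assumed entry point

pipe = IPUStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",                                  # assumed base model
    torch_dtype=torch.float16,
    lora_name_or_path_or_dict="sayakpaul/sd-model-finetuned-lora-t4",  # assumed LoRA weights
    common_ipu_config_kwargs={"executable_cache_dir": "./exe_cache"},  # assumed IPU config
)

image = pipe("a photo of an astronaut riding a horse", guidance_scale=7.5).images[0]
image.save("lora_sample.png")
```
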
