
Commit bcf0181

[Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)
* Distrifusion Support source
* comp comm overlap optimization
* sd3 benchmark
* pixart distrifusion bug fix
* sd3 bug fix and benchmark
* generation bug fix
* naming fix
* add docstring, fix counter and shape error
* add reference
* readme and requirement
1 parent 7b38964 commit bcf0181

File tree

15 files changed (+1089 -16 lines)

colossalai/inference/README.md

Lines changed: 11 additions & 1 deletion
@@ -18,7 +18,7 @@


 ## 📌 Introduction
-ColossalAI-Inference is a module which offers acceleration to the inference execution of Transformers models, especially LLMs. In ColossalAI-Inference, we leverage high-performance kernels, KV cache, paged attention, continous batching and other techniques to accelerate the inference of LLMs. We also provide simple and unified APIs for the sake of user-friendliness. [[blog]](https://hpc-ai.com/blog/colossal-inference)
+ColossalAI-Inference is a module which offers acceleration to the inference execution of Transformers models, especially LLMs and DiT Diffusion Models. In ColossalAI-Inference, we leverage high-performance kernels, KV cache, paged attention, continous batching and other techniques to accelerate the inference of LLMs. We also provide simple and unified APIs for the sake of user-friendliness. [[blog]](https://hpc-ai.com/blog/colossal-inference)

 <p align="center">
 <img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/colossal-inference-v1-1.png" width=1000/>

@@ -310,4 +310,14 @@ If you wish to cite relevant research papars, you can find the reference below.
     journal={arXiv},
     year={2023}
 }
+
+# Distrifusion
+@InProceedings{Li_2024_CVPR,
+    author={Li, Muyang and Cai, Tianle and Cao, Jiaxin and Zhang, Qinsheng and Cai, Han and Bai, Junjie and Jia, Yangqing and Li, Kai and Han, Song},
+    title={DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models},
+    booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
+    month={June},
+    year={2024},
+    pages={7183-7193}
+}
 ```

colossalai/inference/config.py

Lines changed: 16 additions & 0 deletions
@@ -186,6 +186,7 @@ class InferenceConfig(RPC_PARAM):
         enable_streamingllm(bool): Whether to use StreamingLLM, the relevant algorithms refer to the paper at https://arxiv.org/pdf/2309.17453 for implementation.
         start_token_size(int): The size of the start tokens, when using StreamingLLM.
         generated_token_size(int): The size of the generated tokens, When using StreamingLLM.
+        patched_parallelism_size(int): Patched Parallelism Size, When using Distrifusion
     """

     # NOTE: arrange configs according to their importance and frequency of usage

@@ -245,6 +246,11 @@
     start_token_size: int = 4
     generated_token_size: int = 512

+    # Acceleration for Diffusion Model(PipeFusion or Distrifusion)
+    patched_parallelism_size: int = 1  # for distrifusion
+    # pipeFusion_m_size: int = 1  # for pipefusion
+    # pipeFusion_n_size: int = 1  # for pipefusion
+
     def __post_init__(self):
         self.max_context_len_to_capture = self.max_input_len + self.max_output_len
         self._verify_config()

@@ -288,6 +294,14 @@ def _verify_config(self) -> None:
             # Thereafter, we swap out tokens in units of blocks, and always swapping out the second block when the generated tokens exceeded the limit.
             self.start_token_size = self.block_size

+        # check Distrifusion
+        # TODO(@lry89757) need more detailed check
+        if self.patched_parallelism_size > 1:
+            # self.use_patched_parallelism = True
+            self.tp_size = (
+                self.patched_parallelism_size
+            )  # this is not a real tp, because some annoying check, so we have to set this to patched_parallelism_size
+
         # check prompt template
         if self.prompt_template is None:
             return

@@ -324,6 +338,7 @@ def to_model_shard_inference_config(self) -> "ModelShardInferenceConfig":
             use_cuda_kernel=self.use_cuda_kernel,
             use_spec_dec=self.use_spec_dec,
             use_flash_attn=use_flash_attn,
+            patched_parallelism_size=self.patched_parallelism_size,
         )
         return model_inference_config


@@ -396,6 +411,7 @@ class ModelShardInferenceConfig:
     use_cuda_kernel: bool = False
     use_spec_dec: bool = False
     use_flash_attn: bool = False
+    patched_parallelism_size: int = 1  # for diffusion model, Distrifusion Technique


 @dataclass
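
For context, a minimal usage sketch (not part of this commit): the new `patched_parallelism_size` field is the only switch needed to request Distrifusion-style patched parallelism, and `_verify_config()` mirrors it into `tp_size` as shown in the hunk above. Only the field names come from the diff; all other `InferenceConfig` values are assumed to keep their defaults.

```python
# Hedged sketch, not from this commit: enabling Distrifusion through the new
# InferenceConfig field. Only `patched_parallelism_size` and the tp_size
# mirroring are taken from the diff above; everything else assumes defaults.
from colossalai.inference.config import InferenceConfig

config = InferenceConfig(
    patched_parallelism_size=2,  # split the work into patches across 2 devices
)

# __post_init__ -> _verify_config() copies the value into tp_size
# (not real tensor parallelism, just to satisfy existing checks).
assert config.tp_size == config.patched_parallelism_size == 2
```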

colossalai/inference/core/diffusion_engine.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
 from colossalai.accelerator import get_accelerator
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig, ModelShardInferenceConfig
-from colossalai.inference.modeling.models.diffusion import DiffusionPipe
+from colossalai.inference.modeling.layers.diffusion import DiffusionPipe
 from colossalai.inference.modeling.policy import model_policy_map
 from colossalai.inference.struct import DiffusionSequence
 from colossalai.inference.utils import get_model_size, get_model_type
