From b6c0f20b0952e38998f771325b79a10614a119b7 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 20 Jun 2025 07:49:28 +0530 Subject: [PATCH 01/12] fix memory address problem --- src/diffusers/hooks/group_offloading.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 1f3a36b5d177..1f3cd6c64046 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import hashlib import os from contextlib import contextmanager, nullcontext from typing import Dict, List, Optional, Set, Tuple, Union @@ -62,6 +63,7 @@ def __init__( low_cpu_mem_usage: bool = False, onload_self: bool = True, offload_to_disk_path: Optional[str] = None, + _group_id: Optional[int] = None, ) -> None: self.modules = modules self.offload_device = offload_device @@ -80,7 +82,9 @@ def __init__( self._is_offloaded_to_disk = False if self.offload_to_disk_path: - self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.safetensors") + self._group_id = _group_id + short_hash = self._compute_group_hash(self._group_id) + self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{short_hash}.safetensors") all_tensors = [] for module in self.modules: @@ -260,6 +264,11 @@ def offload_(self): for buffer in self.buffers: buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking) + def _compute_group_hash(self, group_id): + hashed_id = hashlib.sha256(group_id.encode("utf-8")).hexdigest() + # first 16 characters for a reasonably short but unique name + return hashed_id[:16] + class GroupOffloadingHook(ModelHook): r""" @@ -603,6 +612,9 @@ def _apply_group_offloading_block_level( for i in range(0, len(submodule), num_blocks_per_group): current_modules = submodule[i : i + num_blocks_per_group] + start_idx = i + end_idx = i + len(current_modules) - 1 + group_id = f"{name}.{start_idx}_to_{end_idx}" group = ModuleGroup( modules=current_modules, offload_device=offload_device, @@ -615,6 +627,7 @@ def _apply_group_offloading_block_level( record_stream=record_stream, low_cpu_mem_usage=low_cpu_mem_usage, onload_self=True, + _group_id=group_id, ) matched_module_groups.append(group) for j in range(i, i + len(current_modules)): @@ -649,6 +662,7 @@ def _apply_group_offloading_block_level( stream=None, record_stream=False, onload_self=True, + _group_id="top_level_unmatched_modules", ) if stream is None: _apply_group_offloading_hook(module, unmatched_group, None) @@ -715,6 +729,7 @@ def _apply_group_offloading_leaf_level( record_stream=record_stream, low_cpu_mem_usage=low_cpu_mem_usage, onload_self=True, + _group_id=name, ) _apply_group_offloading_hook(submodule, group, None) modules_with_group_offloading.add(name) @@ -762,6 +777,7 @@ def _apply_group_offloading_leaf_level( record_stream=record_stream, low_cpu_mem_usage=low_cpu_mem_usage, onload_self=True, + _group_id=name, ) _apply_group_offloading_hook(parent_module, group, None) @@ -783,6 +799,7 @@ def _apply_group_offloading_leaf_level( record_stream=False, low_cpu_mem_usage=low_cpu_mem_usage, onload_self=True, + name="lazy_leafs", ) _apply_lazy_group_offloading_hook(module, unmatched_group, None) From 7c8fc645085caa41d3f0675ca7b52a40cf055bb8 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 20 Jun 2025 08:18:16 +0530 Subject: [PATCH 02/12] add more tests --- src/diffusers/hooks/group_offloading.py | 88 +++++++++++++++++++++++++ tests/models/test_modeling_common.py | 28 +++++++- 2 files changed, 113 insertions(+), 3 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 1f3cd6c64046..fded29347fc6 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import glob import hashlib import os from contextlib import contextmanager, nullcontext @@ -907,3 +908,90 @@ def _get_group_onload_device(module: torch.nn.Module) -> torch.device: if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None: return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device raise ValueError("Group offloading is not enabled for the provided module.") + + +def _get_expected_safetensors_files( + module: torch.nn.Module, + offload_to_disk_path: str, + offload_type: str, + num_blocks_per_group: Optional[int] = None, +) -> Set[str]: + expected_files = set() + + def get_hashed_filename(group_id: str) -> str: + hashed_id = hashlib.sha256(group_id.encode("utf-8")).hexdigest() + short_hash = hashed_id[:16] + return os.path.join(offload_to_disk_path, f"group_{short_hash}.safetensors") + + if offload_type == "block_level": + if num_blocks_per_group is None: + raise ValueError("num_blocks_per_group must be provided for 'block_level' offloading.") + + # Handle groups of ModuleList and Sequential blocks + for name, submodule in module.named_children(): + if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): + continue + + for i in range(0, len(submodule), num_blocks_per_group): + current_modules = submodule[i : i + num_blocks_per_group] + if not current_modules: + continue + start_idx = i + end_idx = i + len(current_modules) - 1 + group_id = f"{name}.{start_idx}_to_{end_idx}" + expected_files.add(get_hashed_filename(group_id)) + + # Handle the group for unmatched top-level modules and parameters + group_id = "top_level_unmatched_modules" + expected_files.add(get_hashed_filename(group_id)) + + elif offload_type == "leaf_level": + # Handle leaf-level module groups + for name, submodule in module.named_modules(): + if isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS): + # These groups will always have parameters, so a file is expected + expected_files.add(get_hashed_filename(name)) + + # Handle groups for non-leaf parameters/buffers + modules_with_group_offloading = { + name for name, sm in module.named_modules() if isinstance(sm, _SUPPORTED_PYTORCH_LAYERS) + } + parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading) + buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading) + + all_orphans = parameters + buffers + if all_orphans: + parent_to_tensors = {} + module_dict = dict(module.named_modules()) + for tensor_name, _ in all_orphans: + parent_name = _find_parent_module_in_module_dict(tensor_name, module_dict) + if parent_name not in parent_to_tensors: + parent_to_tensors[parent_name] = [] + parent_to_tensors[parent_name].append(tensor_name) + + for parent_name in parent_to_tensors: + # A file is expected for each parent that gathers orphaned tensors + expected_files.add(get_hashed_filename(parent_name)) + + else: + raise ValueError(f"Unsupported offload_type: {offload_type}") + + return expected_files + + +def _check_safetensors_serialization( + module: torch.nn.Module, + offload_to_disk_path: str, + offload_type: str, + num_blocks_per_group: Optional[int] = None, +) -> bool: + if not os.path.isdir(offload_to_disk_path): + return False, None, None + + expected_files = _get_expected_safetensors_files(module, offload_to_disk_path, offload_type, num_blocks_per_group) + actual_files = set(glob.glob(os.path.join(offload_to_disk_path, "*.safetensors"))) + missing_files = expected_files - actual_files + extra_files = actual_files - expected_files + + is_correct = not missing_files and not extra_files + return is_correct, extra_files, missing_files diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index a2ad89101a29..87c484eccb6a 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -40,6 +40,7 @@ from parameterized import parameterized from requests.exceptions import HTTPError +from diffusers.hooks.group_offloading import _check_safetensors_serialization from diffusers.models import SD3Transformer2DModel, UNet2DConditionModel from diffusers.models.attention_processor import ( AttnProcessor, @@ -1697,6 +1698,7 @@ def test_group_offloading_with_layerwise_casting(self, record_stream, offload_ty @parameterized.expand([(False, "block_level"), (True, "leaf_level")]) @require_torch_accelerator @torch.no_grad() + @torch.inference_mode() def test_group_offloading_with_disk(self, record_stream, offload_type): torch.manual_seed(0) init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -1705,11 +1707,15 @@ def test_group_offloading_with_disk(self, record_stream, offload_type): if not getattr(model, "_supports_group_offloading", True): return + model.eval() + output_without_group_offloading = model(**inputs_dict)[0] + torch.manual_seed(0) init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict) model.eval() - additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1} + num_blocks_per_group = None if offload_type == "leaf_level" else 1 + additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": num_blocks_per_group} with tempfile.TemporaryDirectory() as tmpdir: model.enable_group_offload( torch_device, @@ -1720,8 +1726,24 @@ def test_group_offloading_with_disk(self, record_stream, offload_type): **additional_kwargs, ) has_safetensors = glob.glob(f"{tmpdir}/*.safetensors") - assert has_safetensors, "No safetensors found in the directory." - _ = model(**inputs_dict)[0] + # Group offloading with disk support related checks. + self.assertTrue(has_safetensors, "No safetensors found in the directory.") + is_correct, extra_files, missing_files = _check_safetensors_serialization( + module=model, + offload_to_disk_path=tmpdir, + offload_type=offload_type, + num_blocks_per_group=num_blocks_per_group, + ) + if not is_correct: + if extra_files: + raise ValueError(f"Found extra files: {', '.join(extra_files)}") + elif missing_files: + raise ValueError(f"Following files are missing: {', '.join(missing_files)}") + + output_with_group_offloading = model(**inputs_dict)[0] + self.assertTrue( + torch.allclose(output_without_group_offloading, output_with_group_offloading, atol=1e-4, rtol=1 - 4) + ) def test_auto_model(self, expected_max_diff=5e-5): if self.forward_requires_fresh_args: From 6639f2530387ced815e1a86d36fcca7be8a0b6f4 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 20 Jun 2025 09:03:14 +0530 Subject: [PATCH 03/12] updates --- src/diffusers/hooks/group_offloading.py | 26 +++++++++++++------------ tests/models/test_modeling_common.py | 10 +++++----- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index fded29347fc6..26646b924649 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -37,7 +37,8 @@ _GROUP_OFFLOADING = "group_offloading" _LAYER_EXECUTION_TRACKER = "layer_execution_tracker" _LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading" - +_GROUP_ID_LAZY_LEAF = "lazy_leafs" +_GROUP_ID_UNMATCHED_GROUP = "top_level_unmatched_modules" _SUPPORTED_PYTORCH_LAYERS = ( torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d, @@ -84,7 +85,7 @@ def __init__( if self.offload_to_disk_path: self._group_id = _group_id - short_hash = self._compute_group_hash(self._group_id) + short_hash = _compute_group_hash(self._group_id) self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{short_hash}.safetensors") all_tensors = [] @@ -265,11 +266,6 @@ def offload_(self): for buffer in self.buffers: buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking) - def _compute_group_hash(self, group_id): - hashed_id = hashlib.sha256(group_id.encode("utf-8")).hexdigest() - # first 16 characters for a reasonably short but unique name - return hashed_id[:16] - class GroupOffloadingHook(ModelHook): r""" @@ -663,7 +659,7 @@ def _apply_group_offloading_block_level( stream=None, record_stream=False, onload_self=True, - _group_id="top_level_unmatched_modules", + _group_id=_GROUP_ID_UNMATCHED_GROUP, ) if stream is None: _apply_group_offloading_hook(module, unmatched_group, None) @@ -800,7 +796,7 @@ def _apply_group_offloading_leaf_level( record_stream=False, low_cpu_mem_usage=low_cpu_mem_usage, onload_self=True, - name="lazy_leafs", + _group_id=_GROUP_ID_LAZY_LEAF, ) _apply_lazy_group_offloading_hook(module, unmatched_group, None) @@ -910,6 +906,12 @@ def _get_group_onload_device(module: torch.nn.Module) -> torch.device: raise ValueError("Group offloading is not enabled for the provided module.") +def _compute_group_hash(group_id): + hashed_id = hashlib.sha256(group_id.encode("utf-8")).hexdigest() + # first 16 characters for a reasonably short but unique name + return hashed_id[:16] + + def _get_expected_safetensors_files( module: torch.nn.Module, offload_to_disk_path: str, @@ -919,8 +921,7 @@ def _get_expected_safetensors_files( expected_files = set() def get_hashed_filename(group_id: str) -> str: - hashed_id = hashlib.sha256(group_id.encode("utf-8")).hexdigest() - short_hash = hashed_id[:16] + short_hash = _compute_group_hash(group_id) return os.path.join(offload_to_disk_path, f"group_{short_hash}.safetensors") if offload_type == "block_level": @@ -942,7 +943,7 @@ def get_hashed_filename(group_id: str) -> str: expected_files.add(get_hashed_filename(group_id)) # Handle the group for unmatched top-level modules and parameters - group_id = "top_level_unmatched_modules" + group_id = _GROUP_ID_UNMATCHED_GROUP expected_files.add(get_hashed_filename(group_id)) elif offload_type == "leaf_level": @@ -972,6 +973,7 @@ def get_hashed_filename(group_id: str) -> str: for parent_name in parent_to_tensors: # A file is expected for each parent that gathers orphaned tensors expected_files.add(get_hashed_filename(parent_name)) + expected_files.add(get_hashed_filename(_GROUP_ID_LAZY_LEAF)) else: raise ValueError(f"Unsupported offload_type: {offload_type}") diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 87c484eccb6a..3c2c05f2b6cd 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1346,7 +1346,6 @@ def test_model_parallelism(self): new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory) # Making sure part of the model will actually end up offloaded self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1}) - print(f" new_model.hf_device_map:{new_model.hf_device_map}") self.check_device_map_is_respected(new_model, new_model.hf_device_map) @@ -1708,6 +1707,7 @@ def test_group_offloading_with_disk(self, record_stream, offload_type): return model.eval() + model.to(torch_device) output_without_group_offloading = model(**inputs_dict)[0] torch.manual_seed(0) @@ -1740,10 +1740,10 @@ def test_group_offloading_with_disk(self, record_stream, offload_type): elif missing_files: raise ValueError(f"Following files are missing: {', '.join(missing_files)}") - output_with_group_offloading = model(**inputs_dict)[0] - self.assertTrue( - torch.allclose(output_without_group_offloading, output_with_group_offloading, atol=1e-4, rtol=1 - 4) - ) + output_with_group_offloading = model(**inputs_dict)[0] + self.assertTrue( + torch.allclose(output_without_group_offloading, output_with_group_offloading, atol=1e-4, rtol=1e-4) + ) def test_auto_model(self, expected_max_diff=5e-5): if self.forward_requires_fresh_args: From a9b7abe7bcf8d7d6427c6b4fefeec6ca0194ba15 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sun, 22 Jun 2025 09:58:37 +0530 Subject: [PATCH 04/12] updates --- tests/models/test_modeling_common.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 3c2c05f2b6cd..fcb4b9813792 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1728,17 +1728,21 @@ def test_group_offloading_with_disk(self, record_stream, offload_type): has_safetensors = glob.glob(f"{tmpdir}/*.safetensors") # Group offloading with disk support related checks. self.assertTrue(has_safetensors, "No safetensors found in the directory.") - is_correct, extra_files, missing_files = _check_safetensors_serialization( - module=model, - offload_to_disk_path=tmpdir, - offload_type=offload_type, - num_blocks_per_group=num_blocks_per_group, - ) - if not is_correct: - if extra_files: - raise ValueError(f"Found extra files: {', '.join(extra_files)}") - elif missing_files: - raise ValueError(f"Following files are missing: {', '.join(missing_files)}") + + # For "leaf-level", there is a prefetching hook which makes this check a bit non-deterministic + # in nature. So, skip it. + if offload_type != "leaf_level": + is_correct, extra_files, missing_files = _check_safetensors_serialization( + module=model, + offload_to_disk_path=tmpdir, + offload_type=offload_type, + num_blocks_per_group=num_blocks_per_group, + ) + if not is_correct: + if extra_files: + raise ValueError(f"Found extra files: {', '.join(extra_files)}") + elif missing_files: + raise ValueError(f"Following files are missing: {', '.join(missing_files)}") output_with_group_offloading = model(**inputs_dict)[0] self.assertTrue( From e0bfef988382f802ded5192839b68e61fd1c1e04 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sun, 22 Jun 2025 16:02:26 +0530 Subject: [PATCH 05/12] update --- tests/models/test_modeling_common.py | 33 +++++++++++++++++++--------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index fcb4b9813792..429d23a3c18c 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -38,6 +38,7 @@ from huggingface_hub import ModelCard, delete_repo, snapshot_download from huggingface_hub.utils import is_jinja_available from parameterized import parameterized +from pytest import skip from requests.exceptions import HTTPError from diffusers.hooks.group_offloading import _check_safetensors_serialization @@ -1694,13 +1695,28 @@ def test_group_offloading_with_layerwise_casting(self, record_stream, offload_ty model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype) _ = model(**inputs_dict)[0] - @parameterized.expand([(False, "block_level"), (True, "leaf_level")]) + @parameterized.expand([("block_level", False), ("leaf_level", True)]) @require_torch_accelerator @torch.no_grad() @torch.inference_mode() - def test_group_offloading_with_disk(self, record_stream, offload_type): - torch.manual_seed(0) + def test_group_offloading_with_disk(self, offload_type, record_stream, atol=1e-5): + def _has_generator_arg(model): + sig = inspect.signature(model.forward) + params = sig.parameters + return "generator" in params + + def _run_forward(model, inputs_dict): + accepts_generator = _has_generator_arg(model) + if accepts_generator: + inputs_dict["generator"] = torch.manual_seed(0) + torch.manual_seed(0) + return model(**inputs_dict)[0] + + if self.__class__.__name__ == "AutoencoderKLCosmosTests" and offload_type == "leaf_level": + skip("With `leaf_type` as the offloading type, it fails. Needs investigation.") + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + torch.manual_seed(0) model = self.model_class(**init_dict) if not getattr(model, "_supports_group_offloading", True): @@ -1708,12 +1724,12 @@ def test_group_offloading_with_disk(self, record_stream, offload_type): model.eval() model.to(torch_device) - output_without_group_offloading = model(**inputs_dict)[0] + output_without_group_offloading = _run_forward(model, inputs_dict) torch.manual_seed(0) - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict) model.eval() + num_blocks_per_group = None if offload_type == "leaf_level" else 1 additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": num_blocks_per_group} with tempfile.TemporaryDirectory() as tmpdir: @@ -1726,7 +1742,6 @@ def test_group_offloading_with_disk(self, record_stream, offload_type): **additional_kwargs, ) has_safetensors = glob.glob(f"{tmpdir}/*.safetensors") - # Group offloading with disk support related checks. self.assertTrue(has_safetensors, "No safetensors found in the directory.") # For "leaf-level", there is a prefetching hook which makes this check a bit non-deterministic @@ -1744,10 +1759,8 @@ def test_group_offloading_with_disk(self, record_stream, offload_type): elif missing_files: raise ValueError(f"Following files are missing: {', '.join(missing_files)}") - output_with_group_offloading = model(**inputs_dict)[0] - self.assertTrue( - torch.allclose(output_without_group_offloading, output_with_group_offloading, atol=1e-4, rtol=1e-4) - ) + output_with_group_offloading = _run_forward(model, inputs_dict) + self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading, atol=atol)) def test_auto_model(self, expected_max_diff=5e-5): if self.forward_requires_fresh_args: From e37d2b0305f40c420aaf0b3d36caa8a5f6a92af3 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 24 Jun 2025 07:49:52 +0530 Subject: [PATCH 06/12] _group_id = group_id --- src/diffusers/hooks/group_offloading.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 26646b924649..7e26e255377b 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -65,7 +65,7 @@ def __init__( low_cpu_mem_usage: bool = False, onload_self: bool = True, offload_to_disk_path: Optional[str] = None, - _group_id: Optional[int] = None, + group_id: Optional[int] = None, ) -> None: self.modules = modules self.offload_device = offload_device @@ -84,8 +84,8 @@ def __init__( self._is_offloaded_to_disk = False if self.offload_to_disk_path: - self._group_id = _group_id - short_hash = _compute_group_hash(self._group_id) + self.group_id = group_id + short_hash = _compute_group_hash(self.group_id) self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{short_hash}.safetensors") all_tensors = [] From 99d5ad5ef64f5054ab02ef8c87c30d2e7abc1a7f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 24 Jun 2025 07:50:47 +0530 Subject: [PATCH 07/12] update --- src/diffusers/hooks/group_offloading.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 7e26e255377b..95e136090de6 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -84,7 +84,7 @@ def __init__( self._is_offloaded_to_disk = False if self.offload_to_disk_path: - self.group_id = group_id + self.group_id = group_id or id(self) short_hash = _compute_group_hash(self.group_id) self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{short_hash}.safetensors") From 4f081dc4dc501d981d44155b83a9d1e330ee81eb Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 24 Jun 2025 07:52:45 +0530 Subject: [PATCH 08/12] Apply suggestions from code review Co-authored-by: Dhruv Nair --- src/diffusers/hooks/group_offloading.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 95e136090de6..f175b99b3362 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -611,7 +611,7 @@ def _apply_group_offloading_block_level( current_modules = submodule[i : i + num_blocks_per_group] start_idx = i end_idx = i + len(current_modules) - 1 - group_id = f"{name}.{start_idx}_to_{end_idx}" + group_id = f"{name}_{i}_{i+len(current_modules)-1}" group = ModuleGroup( modules=current_modules, offload_device=offload_device, From ab2eff7c06f2f3d690842bde14d56e96ad80952a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 24 Jun 2025 07:56:17 +0530 Subject: [PATCH 09/12] update --- src/diffusers/hooks/group_offloading.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index f175b99b3362..f9e6eee84f2c 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -609,8 +609,6 @@ def _apply_group_offloading_block_level( for i in range(0, len(submodule), num_blocks_per_group): current_modules = submodule[i : i + num_blocks_per_group] - start_idx = i - end_idx = i + len(current_modules) - 1 group_id = f"{name}_{i}_{i+len(current_modules)-1}" group = ModuleGroup( modules=current_modules, From 9710bbc550905930743afd76b65cc6d06a7754c7 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 24 Jun 2025 08:08:54 +0530 Subject: [PATCH 10/12] update --- src/diffusers/hooks/group_offloading.py | 103 ++---------------------- src/diffusers/utils/testing_utils.py | 100 ++++++++++++++++++++++- tests/models/test_modeling_common.py | 2 +- 3 files changed, 107 insertions(+), 98 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index f9e6eee84f2c..9bc2b30d9b0d 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import glob import hashlib import os from contextlib import contextmanager, nullcontext @@ -37,8 +36,7 @@ _GROUP_OFFLOADING = "group_offloading" _LAYER_EXECUTION_TRACKER = "layer_execution_tracker" _LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading" -_GROUP_ID_LAZY_LEAF = "lazy_leafs" -_GROUP_ID_UNMATCHED_GROUP = "top_level_unmatched_modules" +GROUP_ID_LAZY_LEAF = "lazy_leafs" _SUPPORTED_PYTORCH_LAYERS = ( torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d, @@ -609,7 +607,7 @@ def _apply_group_offloading_block_level( for i in range(0, len(submodule), num_blocks_per_group): current_modules = submodule[i : i + num_blocks_per_group] - group_id = f"{name}_{i}_{i+len(current_modules)-1}" + group_id = f"{name}_{i}_{i + len(current_modules) - 1}" group = ModuleGroup( modules=current_modules, offload_device=offload_device, @@ -622,7 +620,7 @@ def _apply_group_offloading_block_level( record_stream=record_stream, low_cpu_mem_usage=low_cpu_mem_usage, onload_self=True, - _group_id=group_id, + group_id=group_id, ) matched_module_groups.append(group) for j in range(i, i + len(current_modules)): @@ -657,7 +655,7 @@ def _apply_group_offloading_block_level( stream=None, record_stream=False, onload_self=True, - _group_id=_GROUP_ID_UNMATCHED_GROUP, + group_id=f"{module.__class__.__name__}_unmatched_group", ) if stream is None: _apply_group_offloading_hook(module, unmatched_group, None) @@ -724,7 +722,7 @@ def _apply_group_offloading_leaf_level( record_stream=record_stream, low_cpu_mem_usage=low_cpu_mem_usage, onload_self=True, - _group_id=name, + group_id=name, ) _apply_group_offloading_hook(submodule, group, None) modules_with_group_offloading.add(name) @@ -772,7 +770,7 @@ def _apply_group_offloading_leaf_level( record_stream=record_stream, low_cpu_mem_usage=low_cpu_mem_usage, onload_self=True, - _group_id=name, + group_id=name, ) _apply_group_offloading_hook(parent_module, group, None) @@ -794,7 +792,7 @@ def _apply_group_offloading_leaf_level( record_stream=False, low_cpu_mem_usage=low_cpu_mem_usage, onload_self=True, - _group_id=_GROUP_ID_LAZY_LEAF, + group_id=GROUP_ID_LAZY_LEAF, ) _apply_lazy_group_offloading_hook(module, unmatched_group, None) @@ -908,90 +906,3 @@ def _compute_group_hash(group_id): hashed_id = hashlib.sha256(group_id.encode("utf-8")).hexdigest() # first 16 characters for a reasonably short but unique name return hashed_id[:16] - - -def _get_expected_safetensors_files( - module: torch.nn.Module, - offload_to_disk_path: str, - offload_type: str, - num_blocks_per_group: Optional[int] = None, -) -> Set[str]: - expected_files = set() - - def get_hashed_filename(group_id: str) -> str: - short_hash = _compute_group_hash(group_id) - return os.path.join(offload_to_disk_path, f"group_{short_hash}.safetensors") - - if offload_type == "block_level": - if num_blocks_per_group is None: - raise ValueError("num_blocks_per_group must be provided for 'block_level' offloading.") - - # Handle groups of ModuleList and Sequential blocks - for name, submodule in module.named_children(): - if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): - continue - - for i in range(0, len(submodule), num_blocks_per_group): - current_modules = submodule[i : i + num_blocks_per_group] - if not current_modules: - continue - start_idx = i - end_idx = i + len(current_modules) - 1 - group_id = f"{name}.{start_idx}_to_{end_idx}" - expected_files.add(get_hashed_filename(group_id)) - - # Handle the group for unmatched top-level modules and parameters - group_id = _GROUP_ID_UNMATCHED_GROUP - expected_files.add(get_hashed_filename(group_id)) - - elif offload_type == "leaf_level": - # Handle leaf-level module groups - for name, submodule in module.named_modules(): - if isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS): - # These groups will always have parameters, so a file is expected - expected_files.add(get_hashed_filename(name)) - - # Handle groups for non-leaf parameters/buffers - modules_with_group_offloading = { - name for name, sm in module.named_modules() if isinstance(sm, _SUPPORTED_PYTORCH_LAYERS) - } - parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading) - buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading) - - all_orphans = parameters + buffers - if all_orphans: - parent_to_tensors = {} - module_dict = dict(module.named_modules()) - for tensor_name, _ in all_orphans: - parent_name = _find_parent_module_in_module_dict(tensor_name, module_dict) - if parent_name not in parent_to_tensors: - parent_to_tensors[parent_name] = [] - parent_to_tensors[parent_name].append(tensor_name) - - for parent_name in parent_to_tensors: - # A file is expected for each parent that gathers orphaned tensors - expected_files.add(get_hashed_filename(parent_name)) - expected_files.add(get_hashed_filename(_GROUP_ID_LAZY_LEAF)) - - else: - raise ValueError(f"Unsupported offload_type: {offload_type}") - - return expected_files - - -def _check_safetensors_serialization( - module: torch.nn.Module, - offload_to_disk_path: str, - offload_type: str, - num_blocks_per_group: Optional[int] = None, -) -> bool: - if not os.path.isdir(offload_to_disk_path): - return False, None, None - - expected_files = _get_expected_safetensors_files(module, offload_to_disk_path, offload_type, num_blocks_per_group) - actual_files = set(glob.glob(os.path.join(offload_to_disk_path, "*.safetensors"))) - missing_files = expected_files - actual_files - extra_files = actual_files - expected_files - - is_correct = not missing_files and not extra_files - return is_correct, extra_files, missing_files diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index e5da39c1d865..2ff69d818a42 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -1,4 +1,5 @@ import functools +import glob import importlib import importlib.metadata import inspect @@ -18,7 +19,7 @@ from contextlib import contextmanager from io import BytesIO, StringIO from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Union import numpy as np import PIL.Image @@ -1377,6 +1378,103 @@ def get_device_properties() -> DeviceProperties: else: DevicePropertiesUserDict = UserDict +if is_torch_available(): + from diffusers.hooks.group_offloading import ( + _GROUP_ID_LAZY_LEAF, + _SUPPORTED_PYTORCH_LAYERS, + _compute_group_hash, + _find_parent_module_in_module_dict, + _gather_buffers_with_no_group_offloading_parent, + _gather_parameters_with_no_group_offloading_parent, + ) + + def _get_expected_safetensors_files( + module: torch.nn.Module, + offload_to_disk_path: str, + offload_type: str, + num_blocks_per_group: Optional[int] = None, + ) -> Set[str]: + expected_files = set() + + def get_hashed_filename(group_id: str) -> str: + short_hash = _compute_group_hash(group_id) + return os.path.join(offload_to_disk_path, f"group_{short_hash}.safetensors") + + if offload_type == "block_level": + if num_blocks_per_group is None: + raise ValueError("num_blocks_per_group must be provided for 'block_level' offloading.") + + # Handle groups of ModuleList and Sequential blocks + unmatched_modules = [] + for name, submodule in module.named_children(): + if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): + unmatched_modules.append(module) + continue + + for i in range(0, len(submodule), num_blocks_per_group): + current_modules = submodule[i : i + num_blocks_per_group] + if not current_modules: + continue + group_id = f"{name}_{i}_{i + len(current_modules) - 1}" + expected_files.add(get_hashed_filename(group_id)) + + # Handle the group for unmatched top-level modules and parameters + for module in unmatched_modules: + expected_files.add(get_hashed_filename(f"{module.__class__.__name__}_unmatched_group")) + + elif offload_type == "leaf_level": + # Handle leaf-level module groups + for name, submodule in module.named_modules(): + if isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS): + # These groups will always have parameters, so a file is expected + expected_files.add(get_hashed_filename(name)) + + # Handle groups for non-leaf parameters/buffers + modules_with_group_offloading = { + name for name, sm in module.named_modules() if isinstance(sm, _SUPPORTED_PYTORCH_LAYERS) + } + parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading) + buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading) + + all_orphans = parameters + buffers + if all_orphans: + parent_to_tensors = {} + module_dict = dict(module.named_modules()) + for tensor_name, _ in all_orphans: + parent_name = _find_parent_module_in_module_dict(tensor_name, module_dict) + if parent_name not in parent_to_tensors: + parent_to_tensors[parent_name] = [] + parent_to_tensors[parent_name].append(tensor_name) + + for parent_name in parent_to_tensors: + # A file is expected for each parent that gathers orphaned tensors + expected_files.add(get_hashed_filename(parent_name)) + expected_files.add(get_hashed_filename(_GROUP_ID_LAZY_LEAF)) + + else: + raise ValueError(f"Unsupported offload_type: {offload_type}") + + return expected_files + + def _check_safetensors_serialization( + module: torch.nn.Module, + offload_to_disk_path: str, + offload_type: str, + num_blocks_per_group: Optional[int] = None, + ) -> bool: + if not os.path.isdir(offload_to_disk_path): + return False, None, None + + expected_files = _get_expected_safetensors_files( + module, offload_to_disk_path, offload_type, num_blocks_per_group + ) + actual_files = set(glob.glob(os.path.join(offload_to_disk_path, "*.safetensors"))) + missing_files = expected_files - actual_files + extra_files = actual_files - expected_files + + is_correct = not missing_files and not extra_files + return is_correct, extra_files, missing_files + class Expectations(DevicePropertiesUserDict): def get_expectation(self) -> Any: diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 429d23a3c18c..8078fcba661b 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -41,7 +41,6 @@ from pytest import skip from requests.exceptions import HTTPError -from diffusers.hooks.group_offloading import _check_safetensors_serialization from diffusers.models import SD3Transformer2DModel, UNet2DConditionModel from diffusers.models.attention_processor import ( AttnProcessor, @@ -62,6 +61,7 @@ from diffusers.utils.hub_utils import _add_variant from diffusers.utils.testing_utils import ( CaptureLogger, + _check_safetensors_serialization, backend_empty_cache, backend_max_memory_allocated, backend_reset_peak_memory_stats, From 6901ef4c764a21af80959c3044939e1a116a2af9 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 24 Jun 2025 08:12:40 +0530 Subject: [PATCH 11/12] update --- src/diffusers/hooks/group_offloading.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 9bc2b30d9b0d..a981e6b74ed4 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -36,7 +36,7 @@ _GROUP_OFFLOADING = "group_offloading" _LAYER_EXECUTION_TRACKER = "layer_execution_tracker" _LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading" -GROUP_ID_LAZY_LEAF = "lazy_leafs" +_GROUP_ID_LAZY_LEAF = "lazy_leafs" _SUPPORTED_PYTORCH_LAYERS = ( torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d, @@ -792,7 +792,7 @@ def _apply_group_offloading_leaf_level( record_stream=False, low_cpu_mem_usage=low_cpu_mem_usage, onload_self=True, - group_id=GROUP_ID_LAZY_LEAF, + group_id=_GROUP_ID_LAZY_LEAF, ) _apply_lazy_group_offloading_hook(module, unmatched_group, None) From e75ef18f545b8d17187ff8b52fbfcd03b9f77244 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 24 Jun 2025 08:32:51 +0530 Subject: [PATCH 12/12] fix --- src/diffusers/hooks/group_offloading.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index a981e6b74ed4..e6d25bffc03b 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -82,7 +82,8 @@ def __init__( self._is_offloaded_to_disk = False if self.offload_to_disk_path: - self.group_id = group_id or id(self) + # Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well. + self.group_id = group_id if group_id is not None else str(id(self)) short_hash = _compute_group_hash(self.group_id) self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{short_hash}.safetensors")