Add Deepspeed Zero 3 MiCS support (Issues #20378) #20461

Open

wants to merge 13 commits into base: master

Changes from 6 commits
@@ -408,6 +408,7 @@ Here is some helpful information when setting up DeepSpeed ZeRO Stage 3 with Lightning
* Treat your GPU/CPU memory as one large pool. In some cases, you may not want to offload certain things (like activations) to provide even more space to offload model parameters
* When offloading to the CPU, make sure to bump up the batch size as GPU memory will be freed
* We also support sharded checkpointing. By passing ``save_full_weights=False`` to the ``DeepSpeedStrategy``, we'll save shards of the model which allows you to save extremely large models. However to load the model and run test/validation/predict you must use the Trainer object.
* DeepSpeed provides `MiCS support <https://deepspeed.readthedocs.io/en/latest/zero3.html#deepspeed.runtime.zero.config.DeepSpeedZeroConfig.mics_shard_size>`_, which lets you control how model parameters are sharded across GPUs by setting ``mics_shard_size`` in the ``zero_optimization`` section of the config. Sharding parameters within smaller groups of GPUs reduces cross-node communication overhead on large clusters (see the example config after this list).
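For a concrete illustration of the bullet above, here is a minimal sketch of enabling MiCS through the ``DeepSpeedStrategy``. It mirrors the pattern used by the tests in this PR; the ``mics_shard_size`` of 2 and the device count of 4 are illustrative values only, not recommendations.

from lightning.pytorch import Trainer
from lightning.pytorch.strategies import DeepSpeedStrategy

# Start from the ZeRO Stage 3 defaults, then enable MiCS by setting a positive shard size.
strategy = DeepSpeedStrategy(stage=3)
strategy.config["zero_optimization"]["mics_shard_size"] = 2  # shard parameters within groups of 2 GPUs
strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False

trainer = Trainer(accelerator="gpu", devices=4, precision="16-mixed", strategy=strategy)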

.. _deepspeed-zero-stage-3-single-file:

24 changes: 19 additions & 5 deletions src/lightning/fabric/strategies/deepspeed.py
@@ -373,11 +373,25 @@ def module_sharded_context(self) -> AbstractContextManager:
import deepspeed

assert self._config_initialized
assert self.config is not None

if (
"zero_optimization" in self.config
and "mics_shard_size" in self.config["zero_optimization"]
and self.config["zero_optimization"]["mics_shard_size"] > 0
and self.zero_stage_3
):
return deepspeed.zero.MiCS_Init(
enabled=self.zero_stage_3,
remote_device=self.remote_device,
config_dict_or_path=self.config,
)
else:
return deepspeed.zero.Init(
enabled=self.zero_stage_3,
remote_device=self.remote_device,
config_dict_or_path=self.config,
)

@override
def save_checkpoint(
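The same configuration path works from Fabric, since ``module_sharded_context`` above selects ``MiCS_Init`` whenever a positive ``mics_shard_size`` is present in the ZeRO config. A hedged sketch, with illustrative device count and shard size:

from lightning.fabric import Fabric
from lightning.fabric.strategies import DeepSpeedStrategy

strategy = DeepSpeedStrategy(stage=3)
strategy.config["zero_optimization"]["mics_shard_size"] = 2
strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False

fabric = Fabric(accelerator="cuda", devices=4, precision="16-mixed", strategy=strategy)
fabric.launch()

with fabric.init_module():  # enters the MiCS_Init context when mics_shard_size > 0
    ...  # construct your model here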
27 changes: 22 additions & 5 deletions src/lightning/pytorch/strategies/deepspeed.py
@@ -521,12 +521,29 @@ def model_sharded_context(self) -> Generator[None, None, None]:
import deepspeed

self._init_config_if_needed()
assert self.config is not None
# If 'mics_shard_size' > 0 is set under config['zero_optimization'], use deepspeed.zero.MiCS_Init()
# instead of deepspeed.zero.Init().
# https://deepspeed.readthedocs.io/en/latest/zero3.html#mics-configurations
# Note: the default deepspeed 0.9.0 is not compatible with this option.
if (
"zero_optimization" in self.config
and "mics_shard_size" in self.config["zero_optimization"]
and self.config["zero_optimization"]["mics_shard_size"] > 0
and self.zero_stage_3
):
with deepspeed.zero.MiCS_Init(
enabled=self.zero_stage_3,
remote_device=self.remote_device,
config_dict_or_path=self.config,
):
yield
else:
with deepspeed.zero.Init(
enabled=self.zero_stage_3,
remote_device=self.remote_device,
config_dict_or_path=self.config,
):
yield

def _set_deepspeed_activation_checkpointing(self) -> None:
import deepspeed
147 changes: 147 additions & 0 deletions tests/tests_fabric/strategies/test_deepspeed_integration.py
@@ -414,3 +414,150 @@ def test_deepspeed_init_module_with_stages_1_2(stage, empty_init):
zero_init_mock.assert_called_with(enabled=False, remote_device=None, config_dict_or_path=ANY)
assert init_mock.call_count == int(not empty_init)
assert model.layer.weight.dtype == torch.bfloat16


@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
def test_deepspeed_multigpu_stage_3_MiCS_support():
"""Test to ensure ZeRO Stage 3 MiCS works with a parallel model."""
strategy = DeepSpeedStrategy(stage=3)
strategy.config["zero_optimization"]["stage"] = 3
strategy.config["zero_optimization"]["mics_shard_size"] = 1
strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False

fabric = Fabric(
strategy=strategy,
accelerator="cuda",
devices=2,
precision="16-mixed",
)
fabric.launch()

def _make_block():
return nn.Sequential(nn.Linear(32, 32, bias=False), nn.ReLU())

with fabric.init_module():
model = nn.Sequential(*(_make_block() for _ in range(5)), nn.Linear(32, 3))

optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)

x = torch.rand(2, 32, device=fabric.device)
y = torch.ones(x.size(0), device=x.device, dtype=torch.long)
x = model(x)
x = x.float() # Ensure output is in float32 for softmax operation
logits = F.softmax(x, dim=1)
loss = F.cross_entropy(logits, y)
fabric.backward(loss)
optimizer.step()
optimizer.zero_grad()


@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
def test_deepspeed_multigpu_stage_3_MiCS_offload_param_support():
"""Test to ensure we can use DeepSpeed with ZeRO Stage param offload 3 MiCS Support."""
strategy = DeepSpeedStrategy(stage=3, offload_params_device="cpu")
strategy.config["zero_optimization"]["stage"] = 3
strategy.config["zero_optimization"]["mics_shard_size"] = 1
strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False

fabric = Fabric(
strategy=strategy,
accelerator="cuda",
devices=2,
precision="16-mixed",
)
fabric.launch()

def _make_block():
return nn.Sequential(nn.Linear(32, 32, bias=False), nn.ReLU())

with fabric.init_module():
model = nn.Sequential(*(_make_block() for _ in range(5)), nn.Linear(32, 3))

optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)

x = torch.rand(2, 32, device=fabric.device)
y = torch.ones(x.size(0), device=x.device, dtype=torch.long)
x = model(x)
x = x.float() # Ensure output is in float32 for softmax operation
logits = F.softmax(x, dim=1)
loss = F.cross_entropy(logits, y)
fabric.backward(loss)
optimizer.step()
optimizer.zero_grad()


@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
def test_deepspeed_multigpu_stage_3_MiCS_offload_param_optimizer_support():
"""Test to ensure we can use DeepSpeed with ZeRO Stage param & optimizer offload 3 MiCS Support."""
strategy = DeepSpeedStrategy(stage=3, offload_params_device="cpu", offload_optimizer_device="cpu")
strategy.config["zero_optimization"]["stage"] = 3
strategy.config["zero_optimization"]["mics_shard_size"] = 1
strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False

fabric = Fabric(
strategy=strategy,
accelerator="cuda",
devices=2,
precision="16-mixed",
)
fabric.launch()

def _make_block():
return nn.Sequential(nn.Linear(32, 32, bias=False), nn.ReLU())

with fabric.init_module():
model = nn.Sequential(*(_make_block() for _ in range(5)), nn.Linear(32, 3))

optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)

x = torch.rand(2, 32, device=fabric.device)
y = torch.ones(x.size(0), device=x.device, dtype=torch.long)
x = model(x)
x = x.float() # Ensure output is in float32 for softmax operation
logits = F.softmax(x, dim=1)
loss = F.cross_entropy(logits, y)
fabric.backward(loss)
optimizer.step()
optimizer.zero_grad()


@RunIf(min_cuda_gpus=4, standalone=True, deepspeed=True)
def test_deepspeed_multigpu_stage_3_hierarchical_MiCS_support():
"""Test to ensure we can use DeepSpeed with ZeRO Stage 3 MiCS Support ('mics_hierarchical_params_gather' =
True)."""
strategy = DeepSpeedStrategy(stage=3)
strategy.config["zero_optimization"]["stage"] = 3
strategy.config["zero_optimization"]["mics_shard_size"] = 2
strategy.config["zero_optimization"]["offload_param"] = {}
strategy.config["zero_optimization"]["offload_optimizer"] = {}
strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = True

fabric = Fabric(
strategy=strategy,
accelerator="cuda",
devices=4,  # 4 GPUs so that mics_shard_size=2 forms a 2 x 2 hierarchy
precision="16-mixed",
)
fabric.launch()

def _make_block():
return nn.Sequential(nn.Linear(32, 32, bias=False), nn.ReLU())

with fabric.init_module():
model = nn.Sequential(*(_make_block() for _ in range(5)), nn.Linear(32, 3))

optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)

x = torch.rand(2, 32, device=fabric.device)
y = torch.ones(x.size(0), device=x.device, dtype=torch.long)
x = model(x)
x = x.float() # Ensure output is in float32 for softmax operation
logits = F.softmax(x, dim=1)
loss = F.cross_entropy(logits, y)
fabric.backward(loss)
optimizer.step()
optimizer.zero_grad()
123 changes: 123 additions & 0 deletions tests/tests_pytorch/strategies/test_deepspeed.py
@@ -1279,3 +1279,126 @@ def test_deepspeed_load_checkpoint_validate_path(tmp_path):
checkpoint_path.touch()
with pytest.raises(FileNotFoundError, match=f"Try to load using this parent directory instead: {tmp_path}"):
strategy.load_checkpoint(checkpoint_path=checkpoint_path)


@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
def test_deepspeed_multigpu_stage_3_MiCS_support(tmp_path):
"""Test to ensure we can use DeepSpeed with basic ZeRO Stage 3 MiCS Support."""
model = ModelParallelBoringModel()
strategy = DeepSpeedStrategy(stage=3)
strategy.config["zero_optimization"]["stage"] = 3
strategy.config["zero_optimization"]["mics_shard_size"] = 1
strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False

trainer = Trainer(
default_root_dir=tmp_path,
strategy=strategy,
accelerator="gpu",
devices=2,
fast_dev_run=True,
precision="16-mixed",
enable_progress_bar=False,
enable_model_summary=False,
)
trainer.test(model)
trainer.fit(model)

_assert_save_model_is_equal(model, tmp_path, trainer)
assert isinstance(trainer.strategy, DeepSpeedStrategy)
assert "zero_optimization" in trainer.strategy.config
assert trainer.strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] is False
assert trainer.strategy.config["zero_optimization"]["mics_shard_size"] == 1
assert trainer.strategy.config["zero_optimization"]["stage"] == 3


@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
def test_deepspeed_multigpu_stage_3_MiCS_offload_param_support(tmp_path):
"""Test to ensure we can use DeepSpeed with ZeRO Stage param offload 3 MiCS Support \
However, in some past pratice, offload param + mics + torchrun will cause inner exception in multi-node environment. \
Probably this exception is caused by torchrun, not deepspeed. """
model = ModelParallelBoringModel()
strategy = DeepSpeedStrategy(stage=3, offload_params_device="cpu")
strategy.config["zero_optimization"]["stage"] = 3
strategy.config["zero_optimization"]["mics_shard_size"] = 1
strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False
trainer = Trainer(
default_root_dir=tmp_path,
strategy=strategy,
accelerator="gpu",
devices=2,
fast_dev_run=True,
precision="16-mixed",
enable_progress_bar=False,
enable_model_summary=False,
)
trainer.test(model)
trainer.fit(model)

_assert_save_model_is_equal(model, tmp_path, trainer)
assert isinstance(trainer.strategy, DeepSpeedStrategy)
assert "zero_optimization" in trainer.strategy.config
assert trainer.strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] is False
assert trainer.strategy.config["zero_optimization"]["mics_shard_size"] == 1
assert trainer.strategy.config["zero_optimization"]["stage"] == 3


@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
def test_deepspeed_multigpu_stage_3_MiCS_offload_param_optimizer_support(tmp_path):
"""Test to ensure we can use DeepSpeed with ZeRO Stage param & optimizer offload 3 MiCS Support."""
model = ModelParallelBoringModel()
strategy = DeepSpeedStrategy(stage=3, offload_params_device="cpu", offload_optimizer_device="cpu")
strategy.config["zero_optimization"]["stage"] = 3
strategy.config["zero_optimization"]["mics_shard_size"] = 1
strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False
trainer = Trainer(
default_root_dir=tmp_path,
strategy=strategy,
accelerator="gpu",
devices=2,
fast_dev_run=True,
precision="16-mixed",
enable_progress_bar=False,
enable_model_summary=False,
)
trainer.test(model)
trainer.fit(model)

_assert_save_model_is_equal(model, tmp_path, trainer)
assert isinstance(trainer.strategy, DeepSpeedStrategy)
assert "zero_optimization" in trainer.strategy.config
assert trainer.strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] is False
assert trainer.strategy.config["zero_optimization"]["mics_shard_size"] == 1
assert trainer.strategy.config["zero_optimization"]["stage"] == 3


@RunIf(min_cuda_gpus=4, standalone=True, deepspeed=True)
def test_deepspeed_multigpu_stage_3_hierarchical_MiCS_support(tmp_path):
"""Test to ensure we can use DeepSpeed with ZeRO Stage 3 MiCS Support ('mics_hierarchical_params_gather' =
True)."""
model = ModelParallelBoringModel()
strategy = DeepSpeedStrategy(stage=3)
strategy.config["zero_optimization"]["stage"] = 3
strategy.config["zero_optimization"]["mics_shard_size"] = 2
strategy.config["zero_optimization"]["offload_param"] = {}
strategy.config["zero_optimization"]["offload_optimizer"] = {}
strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = True
# Forming a 2 x 2 hierarchy
trainer = Trainer(
default_root_dir=tmp_path,
strategy=strategy,
accelerator="gpu",
devices=4,
fast_dev_run=True,
precision="16-mixed",
enable_progress_bar=False,
enable_model_summary=False,
)
trainer.test(model)
trainer.fit(model)

_assert_save_model_is_equal(model, tmp_path, trainer)
assert isinstance(trainer.strategy, DeepSpeedStrategy)
assert "zero_optimization" in trainer.strategy.config
assert trainer.strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] is True
assert trainer.strategy.config["zero_optimization"]["mics_shard_size"] == 2
assert trainer.strategy.config["zero_optimization"]["stage"] == 3