Skip to content

Commit d1e7777

Browse files
fix: remove mark-time checking for non-existence of the flag as DeepSpeedEngine propagates flag from the internal model
1 parent 238ba1f commit d1e7777

File tree

2 files changed

+31
-12
lines changed

2 files changed

+31
-12
lines changed

deepspeed/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,18 +68,18 @@ def _parse_version(version_str):
6868

6969
def _mark_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
7070
"""Mark a trainobj as initialized by setting the ds_is_inited attribute to True."""
71-
# we shouldn't hit the assert below, but just in case
72-
assert not hasattr(
73-
trainobj, 'ds_is_inited'
74-
), "Model has already been initialized, please make sure to only call deepspeed.initialize on a model once."
71+
if hasattr(trainobj, 'ds_is_inited'):
72+
assert trainobj.ds_is_inited, "Not expecting the training object has `ds_is_inited` to be False if it exists, make sure you didn't set it to False or called deepspeed.initialize on the model more than once."
73+
return
74+
7575
trainobj.ds_is_inited = True
7676

7777

7878
def _is_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
7979
"""Check if a trainobj has been initialized by checking the ds_is_inited attribute."""
8080
if hasattr(trainobj, 'ds_is_inited'):
8181
# we shouldn't hit the assert below, but just in case
82-
assert trainobj.ds_is_inited, "Not expecting the model has `ds_is_inited` to be False if it exists, make sure you didn't set it to False or called deepspeed.initialize on the model more than once."
82+
assert trainobj.ds_is_inited, "Not expecting the training object has `ds_is_inited` to be False if it exists, make sure you didn't set it to False or called deepspeed.initialize on the model more than once."
8383
return True
8484
return False
8585

tests/unit/runtime/test_ds_initialize.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -445,17 +445,14 @@ def test_no_repeated_init(self):
445445
hidden_dim = 10
446446
model = SimpleModel(hidden_dim)
447447
client_optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
448-
449-
model = SimpleModel()
450448
# Initialize DeepSpeed configurations for fp16
451449
config_dict = {'train_batch_size': 1}
452450

453-
client_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
454451
# Initialize DeepSpeed engine
455452
_assert_trainobjs_not_inited(model=model, optimizer=client_optimizer, lr_scheduler=None)
456-
model_engine, optim, dataloader, scheduler = deepspeed.initialize(model=model,
457-
optimizer=client_optimizer,
458-
config_params=config_dict)
453+
model_engine, optim, _, _ = deepspeed.initialize(model=model,
454+
optimizer=client_optimizer,
455+
config_params=config_dict)
459456

460457
# arguments should be marked as initialized now
461458
assert _is_initialized(model), "Client model should be marked as initialized"
@@ -464,7 +461,6 @@ def test_no_repeated_init(self):
464461
# return values should also be marked as initialized
465462
assert _is_initialized(model_engine), "Model engine should be marked as initialized"
466463
assert _is_initialized(optim), "Optimizer should be marked as initialized"
467-
assert _is_initialized(scheduler), "Scheduler should be marked as initialized"
468464

469465
exception_raised = False
470466
try:
@@ -473,3 +469,26 @@ def test_no_repeated_init(self):
473469
exception_raised = True
474470

475471
assert exception_raised, "Repeated initialization should raise an exception"
472+
473+
exception_raised = False
474+
try:
475+
deepspeed.initialize(model=model_engine, optimizer=client_optimizer, config_params=config_dict)
476+
except ValueError:
477+
exception_raised = True
478+
479+
assert exception_raised, "Initialization on ds types should raise an exception"
480+
481+
exception_raised = False
482+
try:
483+
deepspeed.initialize(model=model, optimizer=client_optimizer, config_params=config_dict)
484+
except ValueError:
485+
exception_raised = True
486+
487+
assert exception_raised, "Initialization on ds types should raise an exception"
488+
489+
exception_raised = False
490+
try:
491+
deepspeed.initialize(model=model_engine, optimizer=client_optimizer, config_params=config_dict)
492+
except ValueError:
493+
exception_raised = True
494+
assert exception_raised, "Initialization on ds types should raise an exception"

0 commit comments

Comments
 (0)