Skip to content

Commit

Permalink
Fix hyperparameter search when optuna+deepseed
Browse files Browse the repository at this point in the history
  • Loading branch information
Corentin-Royer committed Nov 7, 2024
1 parent 7bbc624 commit 0a81e6a
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 14 deletions.
23 changes: 10 additions & 13 deletions src/transformers/integrations/integration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def hp_params(trial):
if is_optuna_available():
import optuna

if isinstance(trial, optuna.Trial):
if isinstance(trial, optuna.trial.BaseTrial):
return trial.params
if is_ray_tune_available():
if isinstance(trial, dict):
Expand All @@ -230,7 +230,7 @@ def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> Be

if trainer.args.process_index == 0:

def _objective(trial, checkpoint_dir=None):
def _objective(trial: optuna.Trial, checkpoint_dir=None):
checkpoint = None
if checkpoint_dir:
for subdir in os.listdir(checkpoint_dir):
Expand All @@ -240,10 +240,11 @@ def _objective(trial, checkpoint_dir=None):
if trainer.args.world_size > 1:
if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
raise RuntimeError("only support DDP optuna HPO for ParallelMode.DISTRIBUTED currently.")
trainer._hp_search_setup(trial)
args_main_rank_list = [pickle.dumps(trainer.args)]
torch.distributed.broadcast_object_list(args_main_rank_list, src=0)
trainer.train(resume_from_checkpoint=checkpoint)
trainer.hp_space(trial)
fixed_trial = optuna.trial.FixedTrial(trial.params, trial.number)
trial_main_rank_list = [fixed_trial]
torch.distributed.broadcast_object_list(trial_main_rank_list, src=0)
trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
else:
trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
# If there hasn't been any evaluation during the training loop.
Expand All @@ -268,15 +269,11 @@ def _objective(trial, checkpoint_dir=None):
else:
for i in range(n_trials):
trainer.objective = None
args_main_rank_list = [None]
trial_main_rank_list = [None]
if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
raise RuntimeError("only support DDP optuna HPO for ParallelMode.DISTRIBUTED currently.")
torch.distributed.broadcast_object_list(args_main_rank_list, src=0)
args = pickle.loads(bytes(args_main_rank_list[0]))
for key, value in asdict(args).items():
if key != "local_rank":
setattr(trainer.args, key, value)
trainer.train(resume_from_checkpoint=None)
torch.distributed.broadcast_object_list(trial_main_rank_list, src=0)
trainer.train(resume_from_checkpoint=None, trial=trial_main_rank_list[0])
# If there hasn't been any evaluation during the training loop.
if getattr(trainer, "objective", None) is None:
metrics = trainer.evaluate()
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1748,7 +1748,7 @@ def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], ste
if self.hp_search_backend == HPSearchBackend.OPTUNA:
import optuna

if not trial.study._is_multi_objective():
if hasattr(trial, "study") and not trial.study._is_multi_objective():
trial.report(self.objective, step)
if trial.should_prune():
self.callback_handler.on_train_end(self.args, self.state, self.control)
Expand Down

0 comments on commit 0a81e6a

Please sign in to comment.