From 0a81e6a64f8b70d279f1d18c44711ced76a89a29 Mon Sep 17 00:00:00 2001
From: Corentin-Royer
Date: Sat, 2 Nov 2024 18:55:47 +0100
Subject: [PATCH] Fix hyperparameter search when optuna+deepspeed

---
 .../integrations/integration_utils.py        | 23 ++++++++-----------
 src/transformers/trainer.py                  |  2 +-
 2 files changed, 11 insertions(+), 14 deletions(-)

diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py
index be9a4aff3c7e7f..48928de8200eb1 100755
--- a/src/transformers/integrations/integration_utils.py
+++ b/src/transformers/integrations/integration_utils.py
@@ -208,7 +208,7 @@ def hp_params(trial):
     if is_optuna_available():
         import optuna
 
-        if isinstance(trial, optuna.Trial):
+        if isinstance(trial, optuna.trial.BaseTrial):
             return trial.params
     if is_ray_tune_available():
         if isinstance(trial, dict):
@@ -230,7 +230,7 @@ def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> Be
 
     if trainer.args.process_index == 0:
 
-        def _objective(trial, checkpoint_dir=None):
+        def _objective(trial: optuna.Trial, checkpoint_dir=None):
             checkpoint = None
             if checkpoint_dir:
                 for subdir in os.listdir(checkpoint_dir):
@@ -240,10 +240,11 @@ def _objective(trial, checkpoint_dir=None):
             if trainer.args.world_size > 1:
                 if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
                     raise RuntimeError("only support DDP optuna HPO for ParallelMode.DISTRIBUTED currently.")
-                trainer._hp_search_setup(trial)
-                args_main_rank_list = [pickle.dumps(trainer.args)]
-                torch.distributed.broadcast_object_list(args_main_rank_list, src=0)
-                trainer.train(resume_from_checkpoint=checkpoint)
+                trainer.hp_space(trial)
+                fixed_trial = optuna.trial.FixedTrial(trial.params, trial.number)
+                trial_main_rank_list = [fixed_trial]
+                torch.distributed.broadcast_object_list(trial_main_rank_list, src=0)
+                trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
             else:
                 trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
             # If there hasn't been any evaluation during the training loop.
@@ -268,15 +269,11 @@ def _objective(trial, checkpoint_dir=None):
     else:
         for i in range(n_trials):
             trainer.objective = None
-            args_main_rank_list = [None]
+            trial_main_rank_list = [None]
             if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
                 raise RuntimeError("only support DDP optuna HPO for ParallelMode.DISTRIBUTED currently.")
-            torch.distributed.broadcast_object_list(args_main_rank_list, src=0)
-            args = pickle.loads(bytes(args_main_rank_list[0]))
-            for key, value in asdict(args).items():
-                if key != "local_rank":
-                    setattr(trainer.args, key, value)
-            trainer.train(resume_from_checkpoint=None)
+            torch.distributed.broadcast_object_list(trial_main_rank_list, src=0)
+            trainer.train(resume_from_checkpoint=None, trial=trial_main_rank_list[0])
             # If there hasn't been any evaluation during the training loop.
             if getattr(trainer, "objective", None) is None:
                 metrics = trainer.evaluate()
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index d41b7181be6334..139bbb6a8b9441 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1748,7 +1748,7 @@ def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], ste
         if self.hp_search_backend == HPSearchBackend.OPTUNA:
             import optuna
 
-            if not trial.study._is_multi_objective():
+            if hasattr(trial, "study") and not trial.study._is_multi_objective():
                 trial.report(self.objective, step)
                 if trial.should_prune():
                     self.callback_handler.on_train_end(self.args, self.state, self.control)
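
For context, the flow this patch moves to works roughly as follows: rank 0 keeps the live optuna.Trial, calls trainer.hp_space(trial) so the hyperparameters are sampled into trial.params, wraps them in an optuna.trial.FixedTrial, and broadcasts that object to the other ranks; every rank then hands its trial (live or fixed) to trainer.train(trial=...), so each process resolves the same hyperparameters locally instead of receiving pickled TrainingArguments. The sketch below is illustration only, not part of the patch; hp_space, train_one_trial, objective and follower_loop are made-up stand-ins, and it assumes torch.distributed is already initialized (e.g. by torchrun).

# Minimal sketch of the rank-0-samples / broadcast / replay pattern the diff adopts.
# Only optuna.trial.FixedTrial, optuna.trial.BaseTrial and
# torch.distributed.broadcast_object_list are the real APIs the patch relies on;
# the functions below are hypothetical stand-ins for the Trainer pieces.
import optuna
import torch.distributed as dist


def hp_space(trial):
    # Stand-in for trainer.hp_space(trial): sampling populates trial.params.
    return {"learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True)}


def train_one_trial(trial):
    # Stand-in for trainer.train(trial=trial): every rank reads the same values,
    # because FixedTrial.suggest_* simply replays what was sampled on rank 0.
    params = hp_space(trial)
    print(f"rank {dist.get_rank()} training with {params}")


def objective(trial: optuna.Trial):
    # Rank 0: sample on the live Trial, freeze it, and share it with the other ranks.
    hp_space(trial)
    payload = [optuna.trial.FixedTrial(trial.params, trial.number)]
    dist.broadcast_object_list(payload, src=0)
    train_one_trial(trial)
    return 0.0  # real code returns trainer.objective


def follower_loop(n_trials):
    # Ranks != 0: receive the FixedTrial and run the same training call.
    for _ in range(n_trials):
        payload = [None]
        dist.broadcast_object_list(payload, src=0)
        train_one_trial(payload[0])


# hp_params() in the patch now checks isinstance(trial, optuna.trial.BaseTrial),
# which covers both optuna.Trial (rank 0) and FixedTrial (other ranks), while
# _report_to_hp_search skips reporting/pruning for FixedTrial, which has no .study.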