diff --git a/torchrec/optim/keyed.py b/torchrec/optim/keyed.py
index a55bf6893..2f6b75d5f 100644
--- a/torchrec/optim/keyed.py
+++ b/torchrec/optim/keyed.py
@@ -413,6 +413,12 @@ def set_optimizer_step(self, step: int) -> None:
             # pyre-ignore [16]: Undefined attribute [16]: `KeyedOptimizer` has no attribute `set_optimizer_step`.
             opt.set_optimizer_step(step)
 
+    def update_hyper_parameters(self, params_dict: Dict[str, Any]) -> None:
+        for _, opt in self._optims:
+            if hasattr(opt, "update_hyper_parameters"):
+                # pyre-ignore [16].
+                opt.update_hyper_parameters(params_dict)
+
 
 class KeyedOptimizerWrapper(KeyedOptimizer):
     """
@@ -441,6 +447,11 @@ def set_optimizer_step(self, step: int) -> None:
         # pyre-ignore [16].
         self._optimizer.set_optimizer_step(step)
 
+    def update_hyper_parameters(self, params_dict: Dict[str, Any]) -> None:
+        if hasattr(self._optimizer, "update_hyper_parameters"):
+            # pyre-ignore [16].
+            self._optimizer.update_hyper_parameters(params_dict)
+
 
 class OptimizerWrapper(KeyedOptimizer):
     """
@@ -493,3 +504,8 @@ def set_optimizer_step(self, step: int) -> None:
         if hasattr(self._optimizer, "set_optimizer_step"):
             # pyre-ignore [16].
             self._optimizer.set_optimizer_step(step)
+
+    def update_hyper_parameters(self, params_dict: Dict[str, Any]) -> None:
+        if hasattr(self._optimizer, "update_hyper_parameters"):
+            # pyre-ignore [16].
+            self._optimizer.update_hyper_parameters(params_dict)