From 5255984119c4a90c2d1ed0e7c88e0761d8835ddc Mon Sep 17 00:00:00 2001
From: Wang Zhou
Date: Mon, 6 Jan 2025 06:32:02 -0800
Subject: [PATCH] explicitly call update_hyper_parameters (#2663)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/2663

Due to the various wrapper classes, hyperparameters are not properly updated in the underlying optimizer, especially the ones that are not saved in `param_groups`. Therefore, we need to explicitly call the `update_hyper_parameters` method so that scheduled values are passed through the wrappers and the actual values used by the optimizer are changed.

Reviewed By: xinzhang-nac

Differential Revision: D67804589

fbshipit-source-id: 4b4023b5187dd4783012747059a50fa70b0bda7a
---
 torchrec/optim/keyed.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/torchrec/optim/keyed.py b/torchrec/optim/keyed.py
index a55bf6893..2f6b75d5f 100644
--- a/torchrec/optim/keyed.py
+++ b/torchrec/optim/keyed.py
@@ -413,6 +413,12 @@ def set_optimizer_step(self, step: int) -> None:
             # pyre-ignore [16]: Undefined attribute [16]: `KeyedOptimizer` has no attribute `set_optimizer_step`.
             opt.set_optimizer_step(step)
 
+    def update_hyper_parameters(self, params_dict: Dict[str, Any]) -> None:
+        for _, opt in self._optims:
+            if hasattr(opt, "update_hyper_parameters"):
+                # pyre-ignore [16].
+                opt.update_hyper_parameters(params_dict)
+
 
 class KeyedOptimizerWrapper(KeyedOptimizer):
     """
@@ -441,6 +447,11 @@ def set_optimizer_step(self, step: int) -> None:
         # pyre-ignore [16].
         self._optimizer.set_optimizer_step(step)
 
+    def update_hyper_parameters(self, params_dict: Dict[str, Any]) -> None:
+        if hasattr(self._optimizer, "update_hyper_parameters"):
+            # pyre-ignore [16].
+            self._optimizer.update_hyper_parameters(params_dict)
+
 
 class OptimizerWrapper(KeyedOptimizer):
     """
@@ -493,3 +504,8 @@ def set_optimizer_step(self, step: int) -> None:
         if hasattr(self._optimizer, "set_optimizer_step"):
             # pyre-ignore [16].
             self._optimizer.set_optimizer_step(step)
+
+    def update_hyper_parameters(self, params_dict: Dict[str, Any]) -> None:
+        if hasattr(self._optimizer, "update_hyper_parameters"):
+            # pyre-ignore [16].
+            self._optimizer.update_hyper_parameters(params_dict)
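
For illustration only (not part of the patch): a minimal sketch of the forwarding pattern the diff adds. `ToyFusedOptimizer`, its `_lr` attribute, `CombinedOptimizerSketch`, and the `"lr"` key are hypothetical stand-ins for torchrec's `CombinedOptimizer` and a fused optimizer whose hyperparameters live outside `param_groups`; the real classes carry much more state.

```python
# Standalone sketch (hypothetical classes, not torchrec code) showing why the
# wrappers must forward `update_hyper_parameters` to the optimizers they hold.
from typing import Any, Dict, List, Tuple


class ToyFusedOptimizer:
    """Hypothetical optimizer whose learning rate is NOT stored in param_groups."""

    def __init__(self, lr: float) -> None:
        self._lr = lr  # instance attribute, invisible to param_groups edits

    def update_hyper_parameters(self, params_dict: Dict[str, Any]) -> None:
        if "lr" in params_dict:
            self._lr = params_dict["lr"]


class CombinedOptimizerSketch:
    """Mirrors the forwarding pattern added in this patch."""

    def __init__(self, optims: List[Tuple[str, Any]]) -> None:
        self._optims = optims

    def update_hyper_parameters(self, params_dict: Dict[str, Any]) -> None:
        # Forward to every wrapped optimizer that supports the hook; without
        # this, a scheduler updating the outer wrapper never reaches `_lr`.
        for _, opt in self._optims:
            if hasattr(opt, "update_hyper_parameters"):
                opt.update_hyper_parameters(params_dict)


if __name__ == "__main__":
    inner = ToyFusedOptimizer(lr=0.1)
    combined = CombinedOptimizerSketch([("sparse", inner)])
    combined.update_hyper_parameters({"lr": 0.01})
    assert inner._lr == 0.01  # the new value reaches the underlying optimizer
```

As in the patch itself, the wrapper forwards only when the wrapped optimizer actually implements the hook (the `hasattr` check), so optimizers without `update_hyper_parameters` are skipped rather than raising.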