From 5599a0f4f5c633f00fda34c6031a0861b3c26eeb Mon Sep 17 00:00:00 2001
From: CedricHwong <997630814@qq.com>
Date: Tue, 30 Dec 2025 07:30:02 +0000
Subject: [PATCH] feat/gradnas: add score averaging and convergence early stop

GradientDataManager now tracks the number of processed batches so that
the squared-gradient scores can be averaged instead of only summed
(config key "average_scores", on by default), and it monitors the mean
relative change of the running average to stop the gradient-estimation
loop early once all tracked hparams have converged. The early stop is
controlled by "score_convergence_tol", "score_convergence_patience",
and "score_convergence_min_updates"; setting the tolerance to None or
the patience to 0 disables it.

Signed-off-by: CedricHwong <997630814@qq.com>
---
 modelopt/torch/prune/gradnas.py        | 104 ++++++++++++++--
 tests/unit/torch/prune/test_gradnas.py | 161 ++++++++++++++++++++++++-
 2 files changed, 252 insertions(+), 13 deletions(-)

diff --git a/modelopt/torch/prune/gradnas.py b/modelopt/torch/prune/gradnas.py
index 50656f6df..ef3276720 100644
--- a/modelopt/torch/prune/gradnas.py
+++ b/modelopt/torch/prune/gradnas.py
@@ -79,37 +79,100 @@ class GradientDataManager:
     """Class for managing gradient data for an hparam."""
 
-    def __init__(self, shape, model, reduce_func=lambda x: x):
+    def __init__(
+        self,
+        shape,
+        model,
+        reduce_func=lambda x: x,
+        *,
+        average_scores: bool = True,
+        convergence_tol: float | None = 1e-3,
+        convergence_patience: int = 5,
+        convergence_min_updates: int = 10,
+    ):
         """Initialize GradientDataManager."""
         self.mask = torch.ones(shape, requires_grad=True, device=get_module_device(model))
-        self._score = torch.zeros_like(self.mask, requires_grad=False)
+        self._score_sum = torch.zeros_like(self.mask, requires_grad=False)
         self._reduce_func = reduce_func
+        self._average_scores = average_scores
+        self._num_updates = 0
+        self._convergence_tol = convergence_tol
+        self._convergence_patience = max(convergence_patience, 0)
+        self._convergence_min_updates = max(convergence_min_updates, 0)
+        self._convergence_count = 0
+        self._prev_avg = None
+        self._convergence_eps = 1e-12
 
-    # TODO: Implement score averaging and early stopping based on score convergence
     def process_gradient(self):
         """Process gradient of the mask."""
-        self._score += self.mask.grad.detach().pow(2)
+        self._score_sum += self.mask.grad.detach().pow(2)
+        self._num_updates += 1
+        self._update_convergence()
         self.mask.grad = None
 
+    def _update_convergence(self) -> None:
+        if self._convergence_tol is None or self._convergence_patience <= 0:
+            return
+        avg_score = self._score_sum / self._num_updates
+        if self._prev_avg is not None and self._num_updates >= self._convergence_min_updates:
+            rel_change = (avg_score - self._prev_avg).abs()
+            rel_change = rel_change / self._prev_avg.abs().clamp_min(self._convergence_eps)
+            if torch.mean(rel_change).item() < self._convergence_tol:
+                self._convergence_count += 1
+            else:
+                self._convergence_count = 0
+        self._prev_avg = avg_score.detach()
+
+    @property
+    def is_converged(self) -> bool:
+        """Whether the score has converged based on relative change."""
+        return (
+            self._convergence_patience > 0 and self._convergence_count >= self._convergence_patience
+        )
+
     @property
     def score(self):
         """The score of the hparam based on the stored gradients."""
-        return self._reduce_func(self._score)
+        # Average over processed batches if requested; otherwise return the raw
+        # sum (which also covers the zero-update case, where the sum is all zeros).
+        if self._average_scores and self._num_updates > 0:
+            score = self._score_sum / self._num_updates
+        else:
+            score = self._score_sum
+        return self._reduce_func(score)
 
 
 def _setup_grad_manager_linear(
     module: dnn._DynamicLinear,
+    *,
+    average_scores: bool,
+    convergence_tol: float | None,
+    convergence_patience: int,
+    convergence_min_updates: int,
 ) -> tuple[GradientDataManager, RemovableHandle]:
     def forward_hook(_modelopt_mask, module, input, output):
         return output * _modelopt_mask
 
-    grad_data = GradientDataManager(module.get_hparam("out_features").max, module)
+    grad_data = GradientDataManager(
+        module.get_hparam("out_features").max,
+        module,
+        average_scores=average_scores,
+        convergence_tol=convergence_tol,
+        convergence_patience=convergence_patience,
+        convergence_min_updates=convergence_min_updates,
+    )
     hook_handle = module.register_forward_hook(partial(forward_hook, grad_data.mask))
     return grad_data, hook_handle
 
 
 def _setup_grad_manager_hf_attention(
-    module: "_DynamicAttention", head_mask_idx: int
+    module: "_DynamicAttention",
+    head_mask_idx: int,
+    *,
+    average_scores: bool,
+    convergence_tol: float | None,
+    convergence_patience: int,
+    convergence_min_updates: int,
 ) -> tuple[GradientDataManager, RemovableHandle]:
     def forward_pre_hook(_modelopt_mask, module, args, kwargs):
         head_mark_in_args = False
@@ -140,6 +203,10 @@ def forward_pre_hook(_modelopt_mask, module, args, kwargs):
         (module.get_hparam("num_attention_heads").max, 1, 1),
         module,
         reduce_func=lambda x: x.squeeze(),
+        average_scores=average_scores,
+        convergence_tol=convergence_tol,
+        convergence_patience=convergence_patience,
+        convergence_min_updates=convergence_min_updates,
     )
     hook_handle = module.register_forward_pre_hook(
         partial(forward_pre_hook, grad_data.mask), with_kwargs=True
@@ -149,27 +216,29 @@ def forward_pre_hook(_modelopt_mask, module, args, kwargs):
 
 
 def _setup_grad_manager_bert_attention(
     module: "_DynamicBertAttention",
+    **kwargs,
 ) -> tuple[GradientDataManager, RemovableHandle]:
     # See forward signature here:
     # https://github.com/huggingface/transformers/blob/b86482/src/transformers/models/bert/modeling_bert.py#L415-L424
-    return _setup_grad_manager_hf_attention(module, head_mask_idx=2)
+    return _setup_grad_manager_hf_attention(module, head_mask_idx=2, **kwargs)
 
 
 def _setup_grad_manager_gptj_attention(
     module: "_DynamicGPTJAttention",
+    **kwargs,
 ) -> tuple[GradientDataManager, RemovableHandle]:
     # See forward signature here:
     # https://github.com/huggingface/transformers/blob/0ea42e/src/transformers/models/gptj/modeling_gptj.py#L194-L202
-    return _setup_grad_manager_hf_attention(module, head_mask_idx=4)
+    return _setup_grad_manager_hf_attention(module, head_mask_idx=4, **kwargs)
 
 
 class GradientBinarySearcher(BinarySearcher):
     """Binary searcher for gradient algorithm."""
 
     SETUP_GRADIENT_FUNC: dict[
-        type[DynamicModule], Callable[[DynamicModule], tuple[GradientDataManager, RemovableHandle]]
+        type[DynamicModule], Callable[..., tuple[GradientDataManager, RemovableHandle]]
     ]
 
     @property
@@ -177,6 +246,10 @@ def default_search_config(self) -> SearchConfig:
         """Get the default config for the searcher."""
         config = super().default_search_config
         config["max_iter_data_loader"] = 128  # Default 50 is not optimal for gradient estimation
+        config["average_scores"] = True
+        config["score_convergence_tol"] = 1e-3
+        config["score_convergence_patience"] = 5
+        config["score_convergence_min_updates"] = 10
         return config
 
     def before_search(self) -> None:
@@ -266,8 +339,13 @@ def collect_func(data):
         for hp_name in hps_for_grad_calc:
             module_name = hp_name.rpartition(".")[0]
             module = self.model.get_submodule(module_name)
-            hp_grad_data[hp_name], mod_to_hook[hp_name] = (
-                GradientBinarySearcher.SETUP_GRADIENT_FUNC[type(module)](module)
+            setup_func = GradientBinarySearcher.SETUP_GRADIENT_FUNC[type(module)]
+            hp_grad_data[hp_name], mod_to_hook[hp_name] = setup_func(
+                module,
+                average_scores=self.config["average_scores"],
+                convergence_tol=self.config["score_convergence_tol"],
+                convergence_patience=self.config["score_convergence_patience"],
+                convergence_min_updates=self.config["score_convergence_min_updates"],
             )
 
         device = get_module_device(self.model)
@@ -284,6 +362,8 @@ def collect_func(data):
 
             for grad_data in hp_grad_data.values():
                 grad_data.process_gradient()
+            if all(grad_data.is_converged for grad_data in hp_grad_data.values()):
+                break
             if idx >= max_iter_data_loader:
                 break
 
diff --git a/tests/unit/torch/prune/test_gradnas.py b/tests/unit/torch/prune/test_gradnas.py
index 5641449df..f6c6333e5 100644
--- a/tests/unit/torch/prune/test_gradnas.py
+++ b/tests/unit/torch/prune/test_gradnas.py
@@ -21,8 +21,9 @@
 import torch.nn.functional as F
 
 import modelopt.torch.nas as mtn
+from modelopt.torch.nas.registry import DMRegistry
 from modelopt.torch.opt.utils import named_hparams
-from modelopt.torch.prune.gradnas import GradientBinarySearcher
+from modelopt.torch.prune.gradnas import GradientBinarySearcher, _setup_grad_manager_linear
 
 try:
     from _test_utils.torch.deploy.runtime import FAKE_DEPLOYMENT, fake_latency
@@ -114,3 +115,161 @@ def loss_func(x, batch):
     assert torch.all(
         torch.sort(hparam.score_tensor, descending=True)[0] == hparam.score_tensor
     )
+
+
+class _CountingDataLoader:
+    def __init__(self, batch, max_batches):
+        self._batch = batch
+        self._max_batches = max_batches
+        self.num_batches = 0
+
+    def __iter__(self):
+        for _ in range(self._max_batches):
+            self.num_batches += 1
+            yield self._batch
+
+
+def _make_gradnas_model():
+    model = nn.Sequential(nn.Linear(4, 16, bias=False), nn.Linear(16, 8, bias=False))
+    with torch.no_grad():
+        for layer in model:
+            layer.weight.fill_(0.1)
+    return mtn.convert(model, "gradnas")
+
+
+def _estimate_gradnas_scores(
+    modelopt_model,
+    dummy_input,
+    *,
+    average_scores,
+    convergence_tol,
+    convergence_patience,
+    convergence_min_updates,
+    max_batches,
+    max_iter_data_loader=None,
+):
+    def loss_func(output, _batch):
+        return output.pow(2).mean()
+
+    data_loader = _CountingDataLoader((dummy_input,), max_batches=max_batches)
+    searcher = GradientBinarySearcher()
+    searcher.model = modelopt_model
+    searcher.config = {
+        **searcher.default_search_config,
+        "average_scores": average_scores,
+        "score_convergence_tol": convergence_tol,
+        "score_convergence_patience": convergence_patience,
+        "score_convergence_min_updates": convergence_min_updates,
+    }
+    had_setup = hasattr(GradientBinarySearcher, "SETUP_GRADIENT_FUNC")
+    prev_setup = getattr(GradientBinarySearcher, "SETUP_GRADIENT_FUNC", None)
+    GradientBinarySearcher.SETUP_GRADIENT_FUNC = {DMRegistry[nn.Linear]: _setup_grad_manager_linear}
+    try:
+        hps = searcher._estimate_gradient_scores(
+            data_loader,
+            loss_func,
+            max_iter_data_loader=max_batches
+            if max_iter_data_loader is None
+            else max_iter_data_loader,
+        )
+    finally:
+        if had_setup:
+            GradientBinarySearcher.SETUP_GRADIENT_FUNC = prev_setup
+        else:
+            delattr(GradientBinarySearcher, "SETUP_GRADIENT_FUNC")
+    return hps, data_loader.num_batches
+
+
+def test_gradnas_score_averaging_and_convergence(use_channel_div_4):
+    modelopt_model = _make_gradnas_model()
+    dummy_input = torch.ones(2, 4)
+
+    hps_avg, avg_batches = _estimate_gradnas_scores(
+        modelopt_model,
+        dummy_input,
+        average_scores=True,
+        convergence_tol=1e-6,
+        convergence_patience=1,
+        convergence_min_updates=1,
+        max_batches=20,
+    )
+    avg_scores = {hp_name: hparam.score_tensor.clone() for hp_name, hparam in hps_avg.items()}
+    hps_sum, sum_batches = _estimate_gradnas_scores(
+        modelopt_model,
+        dummy_input,
+        average_scores=False,
+        convergence_tol=1e-6,
+        convergence_patience=1,
+        convergence_min_updates=1,
+        max_batches=20,
+    )
+
+    assert avg_batches == 2
+    assert sum_batches == 2
+
+    for hp_name, hparam in hps_sum.items():
+        assert torch.allclose(
+            hparam.score_tensor,
+            avg_scores[hp_name] * sum_batches,
+        )
+
+
+@pytest.mark.parametrize(
+    ("convergence_patience", "convergence_min_updates"),
+    [
+        (1, 1),
+        (1, 3),
+        (2, 1),
+        (2, 3),
+    ],
+)
+def test_gradnas_convergence_patience_and_min_updates(
+    use_channel_div_4,
+    convergence_patience,
+    convergence_min_updates,
+):
+    modelopt_model = _make_gradnas_model()
+    dummy_input = torch.ones(2, 4)
+    max_batches = 20
+
+    _, num_batches = _estimate_gradnas_scores(
+        modelopt_model,
+        dummy_input,
+        average_scores=True,
+        convergence_tol=1e-6,
+        convergence_patience=convergence_patience,
+        convergence_min_updates=convergence_min_updates,
+        max_batches=max_batches,
+    )
+
+    expected_batches = max(convergence_min_updates, 2) + convergence_patience - 1
+    assert num_batches == expected_batches
+
+
+@pytest.mark.parametrize(
+    ("convergence_tol", "convergence_patience"),
+    [
+        (None, 1),
+        (1e-6, 0),
+    ],
+)
+def test_gradnas_convergence_disabled_runs_full_loader(
+    use_channel_div_4,
+    convergence_tol,
+    convergence_patience,
+):
+    modelopt_model = _make_gradnas_model()
+    dummy_input = torch.ones(2, 4)
+    max_batches = 5
+
+    _, num_batches = _estimate_gradnas_scores(
+        modelopt_model,
+        dummy_input,
+        average_scores=True,
+        convergence_tol=convergence_tol,
+        convergence_patience=convergence_patience,
+        convergence_min_updates=1,
+        max_batches=max_batches,
+    )
+
+    assert num_batches == max_batches
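
Usage sketch (not part of the diff): the new knobs are plain search-config
keys, so they can be overridden wherever the gradnas search config is
supplied. A minimal example follows, assuming the standard
modelopt.torch.prune entry point; `model`, `train_loader`, `loss_func`,
`dummy_input`, and the "params" constraint are hypothetical placeholders,
while the four new keys are exactly the ones this patch registers in
GradientBinarySearcher.default_search_config:

    import modelopt.torch.prune as mtp

    # model: an nn.Module supported by gradnas (hypothetical placeholder).
    pruned_model, _ = mtp.prune(
        model,
        mode="gradnas",
        constraints={"params": "80%"},  # hypothetical pruning constraint
        dummy_input=dummy_input,
        config={
            "data_loader": train_loader,  # batches used for gradient estimation
            "loss_func": loss_func,  # called as loss_func(output, batch)
            "average_scores": True,  # average summed g^2 scores over batches
            "score_convergence_tol": 1e-3,  # mean relative change threshold
            "score_convergence_patience": 5,  # consecutive sub-tol batches needed
            "score_convergence_min_updates": 10,  # warm-up before checking
        },
    )

With convergence enabled, the estimation loop exits as soon as every hparam's
GradientDataManager reports is_converged; with "score_convergence_tol" set to
None or "score_convergence_patience" set to 0, it runs for the full
"max_iter_data_loader" batches (default 128) as before.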