104 changes: 92 additions & 12 deletions modelopt/torch/prune/gradnas.py
@@ -79,37 +79,100 @@
class GradientDataManager:
"""Class for managing gradient data for an hparam."""

def __init__(self, shape, model, reduce_func=lambda x: x):
def __init__(
self,
shape,
model,
reduce_func=lambda x: x,
*,
average_scores: bool = True,
convergence_tol: float | None = 1e-3,
convergence_patience: int = 5,
convergence_min_updates: int = 10,
):
"""Initialize GradientDataManager."""
self.mask = torch.ones(shape, requires_grad=True, device=get_module_device(model))
self._score = torch.zeros_like(self.mask, requires_grad=False)
self._score_sum = torch.zeros_like(self.mask, requires_grad=False)
self._reduce_func = reduce_func
self._average_scores = average_scores
self._num_updates = 0
self._convergence_tol = convergence_tol
self._convergence_patience = max(convergence_patience, 0)
self._convergence_min_updates = max(convergence_min_updates, 0)
self._convergence_count = 0
self._prev_avg = None
self._convergence_eps = 1e-12

# TODO: Implement score averaging and early stopping based on score convergence
def process_gradient(self):
"""Process gradient of the mask."""
self._score += self.mask.grad.detach().pow(2)
self._score_sum += self.mask.grad.detach().pow(2)
self._num_updates += 1
self._update_convergence()
self.mask.grad = None

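    # Convergence criterion: after `convergence_min_updates` gradient updates, the mean
    # relative change of the running-average score must stay below `convergence_tol` for
    # `convergence_patience` consecutive updates (and early stopping is skipped entirely
    # when the tolerance is None or the patience is 0).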
def _update_convergence(self) -> None:
if self._convergence_tol is None or self._convergence_patience <= 0:
return
avg_score = self._score_sum / self._num_updates
if self._prev_avg is not None and self._num_updates >= self._convergence_min_updates:
rel_change = (avg_score - self._prev_avg).abs()
rel_change = rel_change / self._prev_avg.abs().clamp_min(self._convergence_eps)
if torch.mean(rel_change).item() < self._convergence_tol:
self._convergence_count += 1
else:
self._convergence_count = 0
self._prev_avg = avg_score.detach()

@property
def is_converged(self) -> bool:
"""Whether the score has converged based on relative change."""
return (
self._convergence_patience > 0 and self._convergence_count >= self._convergence_patience
)

@property
def score(self):
"""The score of the hparam based on the stored gradients."""
return self._reduce_func(self._score)
        if self._average_scores and self._num_updates > 0:
            score = self._score_sum / self._num_updates
        else:
            score = self._score_sum
return self._reduce_func(score)


def _setup_grad_manager_linear(
module: dnn._DynamicLinear,
*,
average_scores: bool,
convergence_tol: float | None,
convergence_patience: int,
convergence_min_updates: int,
) -> tuple[GradientDataManager, RemovableHandle]:
def forward_hook(_modelopt_mask, module, input, output):
return output * _modelopt_mask

grad_data = GradientDataManager(module.get_hparam("out_features").max, module)
grad_data = GradientDataManager(
module.get_hparam("out_features").max,
module,
average_scores=average_scores,
convergence_tol=convergence_tol,
convergence_patience=convergence_patience,
convergence_min_updates=convergence_min_updates,
)
hook_handle = module.register_forward_hook(partial(forward_hook, grad_data.mask))
return grad_data, hook_handle


def _setup_grad_manager_hf_attention(
module: "_DynamicAttention", head_mask_idx: int
module: "_DynamicAttention",
head_mask_idx: int,
*,
average_scores: bool,
convergence_tol: float | None,
convergence_patience: int,
convergence_min_updates: int,
) -> tuple[GradientDataManager, RemovableHandle]:
def forward_pre_hook(_modelopt_mask, module, args, kwargs):
head_mark_in_args = False
@@ -140,6 +203,10 @@ def forward_pre_hook(_modelopt_mask, module, args, kwargs):
(module.get_hparam("num_attention_heads").max, 1, 1),
module,
reduce_func=lambda x: x.squeeze(),
average_scores=average_scores,
convergence_tol=convergence_tol,
convergence_patience=convergence_patience,
convergence_min_updates=convergence_min_updates,
)
hook_handle = module.register_forward_pre_hook(
partial(forward_pre_hook, grad_data.mask), with_kwargs=True
@@ -149,34 +216,40 @@ def forward_pre_hook(_modelopt_mask, module, args, kwargs):

def _setup_grad_manager_bert_attention(
module: "_DynamicBertAttention",
**kwargs,
) -> tuple[GradientDataManager, RemovableHandle]:
# See forward signature here:
# https://github.com/huggingface/transformers/blob/b86482/src/transformers/models/bert/modeling_bert.py#L415-L424

return _setup_grad_manager_hf_attention(module, head_mask_idx=2)
return _setup_grad_manager_hf_attention(module, head_mask_idx=2, **kwargs)


def _setup_grad_manager_gptj_attention(
module: "_DynamicGPTJAttention",
**kwargs,
) -> tuple[GradientDataManager, RemovableHandle]:
# See forward signature here:
# https://github.com/huggingface/transformers/blob/0ea42e/src/transformers/models/gptj/modeling_gptj.py#L194-L202

return _setup_grad_manager_hf_attention(module, head_mask_idx=4)
return _setup_grad_manager_hf_attention(module, head_mask_idx=4, **kwargs)


class GradientBinarySearcher(BinarySearcher):
"""Binary searcher for gradient algorithm."""

SETUP_GRADIENT_FUNC: dict[
type[DynamicModule], Callable[[DynamicModule], tuple[GradientDataManager, RemovableHandle]]
type[DynamicModule], Callable[..., tuple[GradientDataManager, RemovableHandle]]
]

@property
def default_search_config(self) -> SearchConfig:
"""Get the default config for the searcher."""
config = super().default_search_config
config["max_iter_data_loader"] = 128 # Default 50 is not optimal for gradient estimation
config["average_scores"] = True
config["score_convergence_tol"] = 1e-3
config["score_convergence_patience"] = 5
config["score_convergence_min_updates"] = 10
return config

def before_search(self) -> None:
@@ -266,8 +339,13 @@ def collect_func(data):
for hp_name in hps_for_grad_calc:
module_name = hp_name.rpartition(".")[0]
module = self.model.get_submodule(module_name)
hp_grad_data[hp_name], mod_to_hook[hp_name] = (
GradientBinarySearcher.SETUP_GRADIENT_FUNC[type(module)](module)
setup_func = GradientBinarySearcher.SETUP_GRADIENT_FUNC[type(module)]
hp_grad_data[hp_name], mod_to_hook[hp_name] = setup_func(
module,
average_scores=self.config["average_scores"],
convergence_tol=self.config["score_convergence_tol"],
convergence_patience=self.config["score_convergence_patience"],
convergence_min_updates=self.config["score_convergence_min_updates"],
)

device = get_module_device(self.model)
@@ -284,6 +362,8 @@ def collect_func(data):
for grad_data in hp_grad_data.values():
grad_data.process_gradient()

if all(grad_data.is_converged for grad_data in hp_grad_data.values()):
break
if idx >= max_iter_data_loader:
break

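For reference, a minimal standalone usage sketch of the new accumulation and convergence machinery in GradientDataManager, assuming a plain nn.Linear as a stand-in for the dynamic modules that the setup functions above hook; the constant batch and the tolerance/patience values here are illustrative choices, not defaults:

import torch
from torch import nn

from modelopt.torch.prune.gradnas import GradientDataManager

model = nn.Linear(4, 8)
manager = GradientDataManager(
    (8,),  # one mask entry per output feature
    model,
    convergence_tol=1e-6,
    convergence_patience=3,
    convergence_min_updates=5,
)

batch = torch.ones(2, 4)  # constant batch, so the running-average score stabilizes immediately
for step in range(20):
    out = model(batch) * manager.mask  # mask scales the output, as in the linear forward hook
    out.pow(2).mean().backward()       # populates manager.mask.grad
    manager.process_gradient()         # accumulates grad**2 and updates the convergence state
    if manager.is_converged:
        break

# Expected to stop after max(convergence_min_updates, 2) + convergence_patience - 1 = 7
# updates; with average_scores=True (the default), score is the accumulated grad**2
# divided by the number of processed batches.
print(step + 1, manager.score.shape)  # 7, torch.Size([8])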
161 changes: 160 additions & 1 deletion tests/unit/torch/prune/test_gradnas.py
@@ -21,8 +21,9 @@
import torch.nn.functional as F

import modelopt.torch.nas as mtn
from modelopt.torch.nas.registry import DMRegistry
from modelopt.torch.opt.utils import named_hparams
from modelopt.torch.prune.gradnas import GradientBinarySearcher
from modelopt.torch.prune.gradnas import GradientBinarySearcher, _setup_grad_manager_linear

try:
from _test_utils.torch.deploy.runtime import FAKE_DEPLOYMENT, fake_latency
@@ -114,3 +115,161 @@ def loss_func(x, batch):
assert torch.all(
torch.sort(hparam.score_tensor, descending=True)[0] == hparam.score_tensor
)


class _CountingDataLoader:
def __init__(self, batch, max_batches):
self._batch = batch
self._max_batches = max_batches
self.num_batches = 0

def __iter__(self):
for _ in range(self._max_batches):
self.num_batches += 1
yield self._batch


def _make_gradnas_model():
model = nn.Sequential(nn.Linear(4, 16, bias=False), nn.Linear(16, 8, bias=False))
with torch.no_grad():
for layer in model:
layer.weight.fill_(0.1)
return mtn.convert(model, "gradnas")


def _estimate_gradnas_scores(
modelopt_model,
dummy_input,
*,
average_scores,
convergence_tol,
convergence_patience,
convergence_min_updates,
max_batches,
max_iter_data_loader=None,
):
def loss_func(output, _batch):
return output.pow(2).mean()

data_loader = _CountingDataLoader((dummy_input,), max_batches=max_batches)
searcher = GradientBinarySearcher()
searcher.model = modelopt_model
searcher.config = {
**searcher.default_search_config,
"average_scores": average_scores,
"score_convergence_tol": convergence_tol,
"score_convergence_patience": convergence_patience,
"score_convergence_min_updates": convergence_min_updates,
}
had_setup = hasattr(GradientBinarySearcher, "SETUP_GRADIENT_FUNC")
prev_setup = getattr(GradientBinarySearcher, "SETUP_GRADIENT_FUNC", None)
GradientBinarySearcher.SETUP_GRADIENT_FUNC = {DMRegistry[nn.Linear]: _setup_grad_manager_linear}
try:
hps = searcher._estimate_gradient_scores(
data_loader,
loss_func,
max_iter_data_loader=max_batches
if max_iter_data_loader is None
else max_iter_data_loader,
)
finally:
if had_setup:
GradientBinarySearcher.SETUP_GRADIENT_FUNC = prev_setup
else:
delattr(GradientBinarySearcher, "SETUP_GRADIENT_FUNC")
return hps, data_loader.num_batches


def test_gradnas_score_averaging_and_convergence(use_channel_div_4):
modelopt_model = _make_gradnas_model()
dummy_input = torch.ones(2, 4)

hps_avg, avg_batches = _estimate_gradnas_scores(
modelopt_model,
dummy_input,
average_scores=True,
convergence_tol=1e-6,
convergence_patience=1,
convergence_min_updates=1,
max_batches=20,
)
avg_scores = {hp_name: hparam.score_tensor.clone() for hp_name, hparam in hps_avg.items()}
hps_sum, sum_batches = _estimate_gradnas_scores(
modelopt_model,
dummy_input,
average_scores=False,
convergence_tol=1e-6,
convergence_patience=1,
convergence_min_updates=1,
max_batches=20,
)

assert avg_batches == 2
assert sum_batches == 2

for hp_name, hparam in hps_sum.items():
assert torch.allclose(
hparam.score_tensor,
avg_scores[hp_name] * sum_batches,
)


@pytest.mark.parametrize(
("convergence_patience", "convergence_min_updates"),
[
(1, 1),
(1, 3),
(2, 1),
(2, 3),
],
)
def test_gradnas_convergence_patience_and_min_updates(
use_channel_div_4,
convergence_patience,
convergence_min_updates,
):
modelopt_model = _make_gradnas_model()
dummy_input = torch.ones(2, 4)
max_batches = 20

_, num_batches = _estimate_gradnas_scores(
modelopt_model,
dummy_input,
average_scores=True,
convergence_tol=1e-6,
convergence_patience=convergence_patience,
convergence_min_updates=convergence_min_updates,
max_batches=max_batches,
)

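    # The first convergence check needs a previous running average (i.e. update 2) and at
    # least `convergence_min_updates` updates; with a constant batch every check after that
    # passes, so the loop stops once `convergence_patience` consecutive checks have passed.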
expected_batches = max(convergence_min_updates, 2) + convergence_patience - 1
assert num_batches == expected_batches


@pytest.mark.parametrize(
("convergence_tol", "convergence_patience"),
[
(None, 1),
(1e-6, 0),
],
)
def test_gradnas_convergence_disabled_runs_full_loader(
use_channel_div_4,
convergence_tol,
convergence_patience,
):
modelopt_model = _make_gradnas_model()
dummy_input = torch.ones(2, 4)
max_batches = 5

_, num_batches = _estimate_gradnas_scores(
modelopt_model,
dummy_input,
average_scores=True,
convergence_tol=convergence_tol,
convergence_patience=convergence_patience,
convergence_min_updates=1,
max_batches=max_batches,
)

assert num_batches == max_batches
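As a usage note, the new behavior is controlled entirely through the searcher config. A minimal sketch of overriding the new keys, following the same merge-into-default_search_config pattern as the test helper above; the values shown are simply the defaults introduced in this change, and setting score_convergence_tol to None or score_convergence_patience to 0 disables early stopping:

from modelopt.torch.prune.gradnas import GradientBinarySearcher

searcher = GradientBinarySearcher()
config = {
    **searcher.default_search_config,
    "average_scores": True,               # average accumulated grad**2 over the processed batches
    "score_convergence_tol": 1e-3,        # mean relative change of the running average treated as "stable"
    "score_convergence_patience": 5,      # consecutive stable updates required before stopping early
    "score_convergence_min_updates": 10,  # convergence is never evaluated before this many updates
}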