Commit 7a27f2a

Update Ln-norm logic for upcoming PyTorch update (#206)
Co-authored-by: namgyu-youn <[email protected]>
1 parent 8e3bfb5 commit 7a27f2a

4 files changed: 26 additions & 13 deletions
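
Context for the change: torch.norm has been marked deprecated in favor of the torch.linalg namespace, which is presumably the "upcoming PyTorch update" the commit title refers to, so every importance estimator below switches to torch.linalg.vector_norm. A minimal sketch (not part of this diff; the tensor shape is illustrative) showing that the two calls produce the same per-column L2 norm:

import torch

torch.manual_seed(0)
weight = torch.randn(8, 16)  # e.g. (out_features, in_features)

# Old call: L2 norm along dim=0, one value per column (input feature).
old = torch.norm(weight, dim=0)

# New call: identical reduction via the torch.linalg namespace.
new = torch.linalg.vector_norm(weight, ord=2, dim=0)

assert torch.allclose(old, new)  # both have shape (16,)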

modelopt/torch/nas/modules/conv.py

Lines changed: 4 additions & 2 deletions
@@ -139,7 +139,9 @@ def _estimate_importance(self) -> TracedHp.Importance:
             return None
         weight = self._parameters["weight"]  # retrieve full weight tensor
         c_in = weight.shape[1]
-        return torch.norm(torch.reshape(weight.detach().transpose(0, 1), (c_in, -1)), dim=1)
+        return torch.linalg.vector_norm(
+            torch.reshape(weight.detach().transpose(0, 1), (c_in, -1)), dim=1
+        )
 
     def _setup(self):
         # only support ungrouped conv or grouped conv with in_channels == out_channels
@@ -249,4 +251,4 @@ def _estimate_importance(self) -> TracedHp.Importance:
             return None
         weight = self._parameters["weight"]  # retrieve full weight tensor
         c_in = weight.shape[0]
-        return torch.norm(torch.reshape(weight.detach(), (c_in, -1)), dim=1)
+        return torch.linalg.vector_norm(torch.reshape(weight.detach(), (c_in, -1)), dim=1)
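
The conv change keeps the original reduction: the kernel is transposed so input channels lead, flattened to (c_in, -1), and a per-row L2 norm yields one importance score per input channel. A standalone sketch of that pattern (the Conv2d shape and the name importance are illustrative, not from the repo):

import torch

# A hypothetical ungrouped Conv2d weight: (out_channels, in_channels, kH, kW).
weight = torch.randn(32, 16, 3, 3)
c_in = weight.shape[1]

# Put in_channels first, flatten the rest, then take the per-row L2 norm
# -> one importance score per input channel, shape (c_in,) == (16,).
importance = torch.linalg.vector_norm(
    torch.reshape(weight.detach().transpose(0, 1), (c_in, -1)), dim=1
)
assert importance.shape == (c_in,)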

modelopt/torch/nas/modules/linear.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ def _get_bias(mod: "_DynamicLinear", bias: torch.Tensor | None) -> torch.Tensor
         return get_sliced_tensor(mod, bias, "out_features")
 
     def _estimate_importance(self) -> TracedHp.Importance:
-        return self._parameters["weight"].detach().norm(dim=0)
+        return torch.linalg.vector_norm(self._parameters["weight"].detach(), dim=0)
 
     def _setup(self):
         # register hyperparameters

modelopt/torch/nas/plugins/megatron.py

Lines changed: 18 additions & 9 deletions
@@ -613,20 +613,29 @@ def _linear_proj_forward_hook(self, module, input, output):
     def _estimate_all_head_importance(self) -> TracedHp.Importance:
         """Return the importance for num_attention_heads (num_heads_per_group * num_query_groups)."""
         assert self._activations is not None, "No activations collected for importance estimation."
-        attn_head_importance = self._activations.view(
-            self.get_hparam("num_heads_per_group").max * self.get_hparam("num_query_groups").max,
-            self.config.kv_channels,
-        ).norm(p=2, dim=1)
+        attn_head_importance = torch.linalg.vector_norm(
+            self._activations.view(
+                self.get_hparam("num_heads_per_group").max
+                * self.get_hparam("num_query_groups").max,
+                self.config.kv_channels,
+            ),
+            ord=2,
+            dim=1,
+        )
         return attn_head_importance
 
     def _estimate_query_group_importance(self) -> TracedHp.Importance:
         """Return the importance of the ``num_query_groups`` hparam."""
         assert self._activations is not None, "No activations collected for importance estimation."
-        group_importance = self._activations.view(
-            self.get_hparam("num_heads_per_group").max,
-            self.get_hparam("num_query_groups").max,
-            self.config.kv_channels,
-        ).norm(p=2, dim=(0, 2))
+        group_importance = torch.linalg.vector_norm(
+            self._activations.view(
+                self.get_hparam("num_heads_per_group").max,
+                self.get_hparam("num_query_groups").max,
+                self.config.kv_channels,
+            ),
+            ord=2,
+            dim=(0, 2),
+        )
         return group_importance
 
     def export(self) -> torch.nn.Module:
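
The Megatron hooks keep the same math but route it through torch.linalg.vector_norm, which accepts a tuple of dims directly. A sketch of both reductions with made-up sizes (num_heads_per_group=4, num_query_groups=8, kv_channels=128 are assumptions for illustration):

import torch

num_heads_per_group, num_query_groups, kv_channels = 4, 8, 128

# Stand-in for the activation scores collected by the forward hook.
activations = torch.randn(num_heads_per_group * num_query_groups * kv_channels)

# Per-head importance: L2 norm over kv_channels -> shape (32,).
head_imp = torch.linalg.vector_norm(
    activations.view(num_heads_per_group * num_query_groups, kv_channels),
    ord=2,
    dim=1,
)

# Per-query-group importance: reduce over the heads within each group and
# over kv_channels -> shape (8,).
group_imp = torch.linalg.vector_norm(
    activations.view(num_heads_per_group, num_query_groups, kv_channels),
    ord=2,
    dim=(0, 2),
)
assert head_imp.shape == (num_heads_per_group * num_query_groups,)
assert group_imp.shape == (num_query_groups,)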

modelopt/torch/nas/plugins/transformers.py

Lines changed: 3 additions & 1 deletion
@@ -122,7 +122,9 @@ def configure_qkv_out(self, q_name: str, k_name: str, v_name: str, out_name: str
         out.in_features = hp_hidden_dim
 
         assert isinstance(out, nn.Linear)
-        hp_hidden_dim.register_importance(lambda: out._parameters["weight"].detach().norm(dim=0))
+        hp_hidden_dim.register_importance(
+            lambda: torch.linalg.vector_norm(out._parameters["weight"].detach(), dim=0)
+        )
 
     def modify(
         self, *, n_heads_ratio: tuple[float, ...] | None = None, n_heads_divisor: int = 1
