6 changes: 4 additions & 2 deletions modelopt/torch/nas/modules/conv.py
@@ -139,7 +139,9 @@ def _estimate_importance(self) -> TracedHp.Importance:
return None
weight = self._parameters["weight"] # retrieve full weight tensor
c_in = weight.shape[1]
-return torch.norm(torch.reshape(weight.detach().transpose(0, 1), (c_in, -1)), dim=1)
+return torch.linalg.vector_norm(
+    torch.reshape(weight.detach().transpose(0, 1), (c_in, -1)), dim=1
+)

def _setup(self):
# only support ungrouped conv or grouped conv with in_channels == out_channels
@@ -249,4 +249,4 @@ def _estimate_importance(self) -> TracedHp.Importance:
return None
weight = self._parameters["weight"] # retrieve full weight tensor
c_in = weight.shape[0]
-return torch.norm(torch.reshape(weight.detach(), (c_in, -1)), dim=1)
+return torch.linalg.vector_norm(torch.reshape(weight.detach(), (c_in, -1)), dim=1)
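For reference, a minimal standalone sketch of what the conv importance estimate above computes (tensor sizes are made up for illustration): the (C_out, C_in, kH, kW) weight is transposed and flattened to one row per input channel, and the per-row L2 norm scores each input channel. torch.linalg.vector_norm produces the same values as the torch.norm call it replaces.

import torch

# Made-up Conv2d weight of shape (C_out, C_in, kH, kW)
weight = torch.randn(8, 4, 3, 3)
c_in = weight.shape[1]
# One row per input channel, flattening every filter slice that reads from it
flat = torch.reshape(weight.detach().transpose(0, 1), (c_in, -1))

old_imp = torch.norm(flat, dim=1)                # deprecated call being replaced
new_imp = torch.linalg.vector_norm(flat, dim=1)  # replacement used in this diff
assert torch.allclose(old_imp, new_imp)
print(new_imp.shape)  # torch.Size([4]) -- one importance score per input channel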
2 changes: 1 addition & 1 deletion modelopt/torch/nas/modules/linear.py
@@ -41,7 +41,7 @@ def _get_bias(mod: "_DynamicLinear", bias: torch.Tensor | None) -> torch.Tensor
return get_sliced_tensor(mod, bias, "out_features")

def _estimate_importance(self) -> TracedHp.Importance:
-return self._parameters["weight"].detach().norm(dim=0)
+return torch.linalg.vector_norm(self._parameters["weight"].detach(), dim=0)

def _setup(self):
# register hyperparameters
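A small illustration of the linear importance estimate above, with toy sizes not taken from the repo: for an nn.Linear weight of shape (out_features, in_features), the norm over dim=0 yields one score per input feature, and torch.linalg.vector_norm matches the Tensor.norm(dim=0) call it replaces.

import torch
import torch.nn as nn

lin = nn.Linear(in_features=6, out_features=3)
w = lin.weight.detach()                   # shape (3, 6)
imp = torch.linalg.vector_norm(w, dim=0)  # column norms: one score per input feature
assert imp.shape == (6,) and torch.allclose(imp, w.norm(dim=0))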
16 changes: 10 additions & 6 deletions modelopt/torch/nas/plugins/megatron.py
@@ -616,17 +616,21 @@ def _estimate_all_head_importance(self) -> TracedHp.Importance:
attn_head_importance = self._activations.view(
self.get_hparam("num_heads_per_group").max * self.get_hparam("num_query_groups").max,
self.config.kv_channels,
-).norm(p=2, dim=1)
+).vector_norm(ord=2, dim=1)
return attn_head_importance

def _estimate_query_group_importance(self) -> TracedHp.Importance:
"""Return the importance of the ``num_query_groups`` hparam."""
assert self._activations is not None, "No activations collected for importance estimation."
-group_importance = self._activations.view(
-    self.get_hparam("num_heads_per_group").max,
-    self.get_hparam("num_query_groups").max,
-    self.config.kv_channels,
-).norm(p=2, dim=(0, 2))
+group_importance = torch.linalg.vector_norm(
+    self._activations.view(
+        self.get_hparam("num_heads_per_group").max,
+        self.get_hparam("num_query_groups").max,
+        self.config.kv_channels,
+    ),
+    ord=2,
+    dim=(0, 2),
+)
return group_importance

def export(self) -> torch.nn.Module:
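To make the query-group reshape concrete, here is a standalone sketch with hypothetical sizes (heads per group, query groups, and kv_channels are invented for illustration): viewing the flat activation scores as (heads_per_group, num_groups, kv_channels) and taking the L2 norm over dims 0 and 2 collapses every head and channel belonging to a group into a single per-group importance score.

import torch

heads_per_group, num_groups, kv_channels = 4, 2, 16  # hypothetical sizes
activations = torch.randn(heads_per_group * num_groups * kv_channels)

group_importance = torch.linalg.vector_norm(
    activations.view(heads_per_group, num_groups, kv_channels),
    ord=2,
    dim=(0, 2),
)
print(group_importance.shape)  # torch.Size([2]) -- one score per query group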
4 changes: 3 additions & 1 deletion modelopt/torch/nas/plugins/transformers.py
@@ -122,7 +122,9 @@ def configure_qkv_out(self, q_name: str, k_name: str, v_name: str, out_name: str
out.in_features = hp_hidden_dim

assert isinstance(out, nn.Linear)
-hp_hidden_dim.register_importance(lambda: out._parameters["weight"].detach().norm(dim=0))
+hp_hidden_dim.register_importance(
+    lambda: torch.linalg.vector_norm(out._parameters["weight"].detach(), dim=0)
+)

def modify(
self, *, n_heads_ratio: tuple[float, ...] | None = None, n_heads_divisor: int = 1
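As a rough sketch of why the importance above is registered as a lambda rather than as a precomputed tensor (the ToyHparam class below is a stand-in invented for illustration, not the modelopt API): the callback re-reads the out-projection weight whenever importance is queried, so the scores reflect the weights at that moment rather than at configuration time.

import torch
import torch.nn as nn

class ToyHparam:
    """Invented stand-in mimicking a hyperparameter with a deferred importance callback."""
    def __init__(self):
        self._importance_fn = None
    def register_importance(self, fn):
        self._importance_fn = fn
    def importance(self):
        return self._importance_fn()

out = nn.Linear(8, 8)
hp = ToyHparam()
hp.register_importance(lambda: torch.linalg.vector_norm(out.weight.detach(), dim=0))

before = hp.importance().clone()
with torch.no_grad():
    out.weight.mul_(2.0)  # pretend a training step updated the weights
assert torch.allclose(hp.importance(), 2.0 * before)  # callback sees the current weights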