diff --git a/modelopt/torch/nas/modules/conv.py b/modelopt/torch/nas/modules/conv.py
index 4769a46c9..e4b1eebe7 100644
--- a/modelopt/torch/nas/modules/conv.py
+++ b/modelopt/torch/nas/modules/conv.py
@@ -139,7 +139,9 @@ def _estimate_importance(self) -> TracedHp.Importance:
             return None
         weight = self._parameters["weight"]  # retrieve full weight tensor
         c_in = weight.shape[1]
-        return torch.norm(torch.reshape(weight.detach().transpose(0, 1), (c_in, -1)), dim=1)
+        return torch.linalg.vector_norm(
+            torch.reshape(weight.detach().transpose(0, 1), (c_in, -1)), dim=1
+        )

     def _setup(self):
         # only support ungrouped conv or grouped conv with in_channels == out_channels
@@ -249,4 +251,4 @@ def _estimate_importance(self) -> TracedHp.Importance:
             return None
         weight = self._parameters["weight"]  # retrieve full weight tensor
         c_in = weight.shape[0]
-        return torch.norm(torch.reshape(weight.detach(), (c_in, -1)), dim=1)
+        return torch.linalg.vector_norm(torch.reshape(weight.detach(), (c_in, -1)), dim=1)
diff --git a/modelopt/torch/nas/modules/linear.py b/modelopt/torch/nas/modules/linear.py
index b82bed68e..b8c171a63 100644
--- a/modelopt/torch/nas/modules/linear.py
+++ b/modelopt/torch/nas/modules/linear.py
@@ -41,7 +41,7 @@ def _get_bias(mod: "_DynamicLinear", bias: torch.Tensor | None) -> torch.Tensor
         return get_sliced_tensor(mod, bias, "out_features")

     def _estimate_importance(self) -> TracedHp.Importance:
-        return self._parameters["weight"].detach().norm(dim=0)
+        return torch.linalg.vector_norm(self._parameters["weight"].detach(), dim=0)

     def _setup(self):
         # register hyperparameters
diff --git a/modelopt/torch/nas/plugins/megatron.py b/modelopt/torch/nas/plugins/megatron.py
index 30eb01507..e65cdcb9b 100644
--- a/modelopt/torch/nas/plugins/megatron.py
+++ b/modelopt/torch/nas/plugins/megatron.py
@@ -616,17 +616,26 @@ def _estimate_all_head_importance(self) -> TracedHp.Importance:
-        attn_head_importance = self._activations.view(
-            self.get_hparam("num_heads_per_group").max * self.get_hparam("num_query_groups").max,
-            self.config.kv_channels,
-        ).norm(p=2, dim=1)
+        attn_head_importance = torch.linalg.vector_norm(
+            self._activations.view(
+                self.get_hparam("num_heads_per_group").max
+                * self.get_hparam("num_query_groups").max,
+                self.config.kv_channels,
+            ),
+            ord=2,
+            dim=1,
+        )
         return attn_head_importance

     def _estimate_query_group_importance(self) -> TracedHp.Importance:
         """Return the importance of the ``num_query_groups`` hparam."""
         assert self._activations is not None, "No activations collected for importance estimation."
-        group_importance = self._activations.view(
-            self.get_hparam("num_heads_per_group").max,
-            self.get_hparam("num_query_groups").max,
-            self.config.kv_channels,
-        ).norm(p=2, dim=(0, 2))
+        group_importance = torch.linalg.vector_norm(
+            self._activations.view(
+                self.get_hparam("num_heads_per_group").max,
+                self.get_hparam("num_query_groups").max,
+                self.config.kv_channels,
+            ),
+            ord=2,
+            dim=(0, 2),
+        )
         return group_importance

     def export(self) -> torch.nn.Module:
diff --git a/modelopt/torch/nas/plugins/transformers.py b/modelopt/torch/nas/plugins/transformers.py
index ad8dcebec..61d5cd5d6 100644
--- a/modelopt/torch/nas/plugins/transformers.py
+++ b/modelopt/torch/nas/plugins/transformers.py
@@ -122,7 +122,9 @@ def configure_qkv_out(self, q_name: str, k_name: str, v_name: str, out_name: str
         out.in_features = hp_hidden_dim

         assert isinstance(out, nn.Linear)
-        hp_hidden_dim.register_importance(lambda: out._parameters["weight"].detach().norm(dim=0))
+        hp_hidden_dim.register_importance(
+            lambda: torch.linalg.vector_norm(out._parameters["weight"].detach(), dim=0)
+        )

     def modify(
         self, *, n_heads_ratio: tuple[float, ...] | None = None, n_heads_divisor: int = 1
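
Not part of the patch: a minimal sketch checking that the removed `torch.norm` calls and the `torch.linalg.vector_norm` calls added in their place yield the same per-channel importance scores. The `(8, 4, 3, 3)` conv weight is an arbitrary illustrative stand-in for the real module weights; only the reshape/norm pattern mirrors `_estimate_importance` above.

```python
import torch

# Hypothetical conv weight of shape (c_out, c_in, kH, kW); shape chosen only for illustration.
weight = torch.randn(8, 4, 3, 3)
c_in = weight.shape[1]

# Flatten to (c_in, -1) the same way _estimate_importance does.
flat = torch.reshape(weight.detach().transpose(0, 1), (c_in, -1))

old = torch.norm(flat, dim=1)                # older torch.norm API removed in this diff
new = torch.linalg.vector_norm(flat, dim=1)  # replacement; ord=2 is the default

assert torch.allclose(old, new)
```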