Skip to content

Commit bb630db

Browse files
committed
Replaces torch.linalg.norm with torch.linalg.vector_norm for vector-norm
`torch.linalg.norm` supports various calculations based on `dim` parameter: - If dim is an int, the vector norm will be computed. - If dim is a 2-tuple, the matrix norm will be computed. - If dim=None and ord=None, A will be flattened to 1D and the 2-norm of the resulting vector will be computed. - If dim=None and ord!=None, A must be 1D or 2D. Therefore, vector norm is not computed when `dim` is tuple. (Nit: `torch.linalg.vector_norm` is more explicit for vector norms.)
1 parent 0d02d18 commit bb630db

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

modelopt/torch/nas/plugins/megatron.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -622,7 +622,7 @@ def _estimate_all_head_importance(self) -> TracedHp.Importance:
622622
def _estimate_query_group_importance(self) -> TracedHp.Importance:
623623
"""Return the importance of the ``num_query_groups`` hparam."""
624624
assert self._activations is not None, "No activations collected for importance estimation."
625-
group_importance = torch.linalg.norm(
625+
group_importance = torch.linalg.vector_norm(
626626
self._activations.view(
627627
self.get_hparam("num_heads_per_group").max,
628628
self.get_hparam("num_query_groups").max,

0 commit comments

Comments
 (0)