
Commit e3d9720

Replace torch.norm with torch.linalg.vector_norm (#2660)
Replace `torch.norm` with `torch.linalg.vector_norm` in preparation for a future PyTorch update that deprecates `torch.norm`.
1 parent 045c959 commit e3d9720

File tree

8 files changed, +25 -17 lines
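
Not part of the commit itself, but a minimal sketch of the mapping applied in every file below (the tensor shape is arbitrary): torch.norm with default arguments returns the 2-norm of the flattened input, which torch.linalg.vector_norm reproduces, and the p= keyword becomes ord=.

import torch

x = torch.randn(4, 8)

# Default behavior: 2-norm of the flattened tensor.
assert torch.allclose(torch.norm(x), torch.linalg.vector_norm(x))

# The order keyword is renamed: p= (torch.norm / Tensor.norm) -> ord= (vector_norm).
assert torch.allclose(
    torch.norm(x, p=1, dim=1),
    torch.linalg.vector_norm(x, ord=1, dim=1),
)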

test/prototype/test_blockwise_triton.py

Lines changed: 2 additions & 2 deletions
@@ -41,7 +41,7 @@ def test_blockwise_quant_dequant(_, N, K, dtype):
     x = torch.randn(N, K).cuda()
     qx, s = fp8_blockwise_weight_quant(x, dtype=dtype)
     x_reconstructed = fp8_blockwise_weight_dequant(qx, s)
-    error = torch.norm(x - x_reconstructed) / torch.norm(x)
+    error = torch.linalg.vector_norm(x - x_reconstructed) / torch.linalg.vector_norm(x)
     print(f"Relative Error: {error.item():.6f}")

     assert error < 0.1, "Quant-Dequant error is too high"

@@ -66,7 +66,7 @@ def test_blockwise_fp8_gemm(M, N, K, dtype):
     A_q, A_s = fp8_blockwise_act_quant(A, dtype=dtype)
     B_q, B_s = fp8_blockwise_weight_quant(B, dtype=dtype)
     C_q = blockwise_fp8_gemm(A_q, A_s, B_q, B_s)
-    error = torch.norm(C - C_q) / torch.norm(C)
+    error = torch.linalg.vector_norm(C - C_q) / torch.linalg.vector_norm(C)
     print(f"Relative Error: {error.item():.6f}")

     assert error < 0.1, "Quantize gemm error is too high"
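
Both tests gate on the same relative-error metric. A self-contained sketch of that check, with x_hat standing in for the quant-dequant (or GEMM) output rather than the actual torchao kernels:

import torch

x = torch.randn(256, 512)
x_hat = x + 0.01 * torch.randn_like(x)  # stand-in for a quantize/dequantize round trip

# Relative error as in the tests above: ||x - x_hat||_2 / ||x||_2 over the flattened tensors.
error = torch.linalg.vector_norm(x - x_hat) / torch.linalg.vector_norm(x)
assert error < 0.1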

test/prototype/test_quantized_training.py

Lines changed: 3 additions & 1 deletion
@@ -211,7 +211,9 @@ def test_int8_mixed_precision_training(self, compile, config, module_swap):

         def snr(ref, actual):
             error = actual - ref
-            return 20 * torch.log10(ref.norm() / error.norm())
+            return 20 * torch.log10(
+                torch.linalg.vector_norm(ref) / torch.linalg.vector_norm(error)
+            )

         assert snr(outputs_ref, outputs_int8mp) > 20
         assert snr(inputs_ref.grad, inputs_int8mp.grad) > 20
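
The snr helper is the usual signal-to-noise ratio in decibels. A standalone version of the same formula, with a made-up noise level for illustration:

import torch

def snr(ref, actual):
    # 20 * log10(||ref|| / ||actual - ref||); higher means `actual` tracks `ref` more closely.
    error = actual - ref
    return 20 * torch.log10(
        torch.linalg.vector_norm(ref) / torch.linalg.vector_norm(error)
    )

ref = torch.randn(1024)
approx = ref + 0.05 * torch.randn_like(ref)
print(snr(ref, approx))  # roughly 26 dB for 5% relative noise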

torchao/float8/float8_utils.py

Lines changed: 2 additions & 2 deletions
@@ -144,8 +144,8 @@ def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         x: The original tensor.
         y: The tensor to compare to the original tensor.
     """
-    Ps = torch.norm(x)
-    Pn = torch.norm(x - y)
+    Ps = torch.linalg.vector_norm(x)
+    Pn = torch.linalg.vector_norm(x - y)
     return 20 * torch.log10(Ps / Pn)
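
compute_error is the same dB-style metric as the test helper above. One detail worth noting: with dim left unset, vector_norm reduces over the flattened input, so the function behaves the same for tensors of any rank (the shape below is arbitrary):

import torch

x = torch.randn(16, 32, 8)
assert torch.allclose(
    torch.linalg.vector_norm(x),            # dim=None reduces over all elements
    torch.linalg.vector_norm(x.flatten()),
)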

torchao/prototype/parq/quant/lsbq.py

Lines changed: 5 additions & 5 deletions
@@ -70,7 +70,7 @@ def compute_v_per_channel(p: Tensor, dim: Optional[int] = None, ternary: bool =
         r = r.sub(v * binary_sign(r))

     # compute least squares error, then select the `v` minimizes it
-    costs = r.norm(dim=dim)
+    costs = torch.linalg.vector_norm(r, dim=dim)
     indices = costs.argmin(dim=dim, keepdim=True)
     v_best = v_cands.gather(1, indices)
     return v_best

@@ -196,10 +196,10 @@ def quantize_optimal_2bits(
         V1V2.append((v1, v2))
     assert len(V1V2) > 0, "LSBQ 2-bit optimal: No solution found."
     # find the best solution with least-square quantization error
-    min_error = p.norm()
+    min_error = torch.linalg.vector_norm(p)
     for v1v2 in V1V2:
         r = binary_quant_residue(p, v1v2)
-        error = r.norm()
+        error = torch.linalg.vector_norm(r)
         if error < min_error:
             min_error = error
             q = p - r

@@ -244,14 +244,14 @@ def quantize_optimal_ternary(
         v_feasible.append(v)
     assert len(v_feasible) > 0, "LSBQ ternary optimal: No solution found."
     # find the best solution with least-square quantization error
-    min_error = p.norm()
+    min_error = torch.linalg.vector_norm(p)
     q_best = torch.zeros_like(p)
     v_best = torch.zeros_like(v)
     for v in v_feasible:
         Q = v * torch.tensor([-1.0, 0.0, 1.0], device=p.device)
         boundaries = v * torch.tensor([-0.5, 0.5], device=p.device)
         q = Q[torch.bucketize(p, boundaries)]
-        error = torch.linalg.norm(p - q)
+        error = torch.linalg.vector_norm(p - q)
         if error < min_error:
             min_error = error
             q_best = q
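
The LSBQ code uses both flavors of the call: a per-channel reduction via dim= to score candidate v values, and a full reduction to compare whole residuals. A short sketch with made-up shapes (r and p here are illustrative, not the PARQ tensors):

import torch

# Per-channel least-squares costs: one norm per row.
r = torch.randn(8, 128)  # 8 channels x 128 values
costs = torch.linalg.vector_norm(r, dim=1)
print(costs.shape)  # torch.Size([8])

# Full reduction: with no dim and the default ord, torch.linalg.norm flattens and
# computes the 2-norm, so vector_norm is a drop-in replacement in the ternary path.
p = torch.randn(1000)
assert torch.allclose(torch.linalg.norm(p), torch.linalg.vector_norm(p))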

torchao/prototype/quantization/codebook/codebook_ops.py

Lines changed: 6 additions & 4 deletions
@@ -198,8 +198,8 @@ def choose_qparams_codebook(
             dim=(-1), keepdim=True
         ).values  # Shape: [*input_size[:-1], num_scale_blocks, 1]
     else:
-        scales = input.norm(
-            dim=(-1), keepdim=True
+        scales = torch.linalg.vector_norm(
+            input, dim=-1, keepdim=True
         )  # Shape: [*input_size[:-1], num_scale_blocks, 1]
     scales = torch.clamp(scales, min=1e-9)

@@ -228,12 +228,14 @@ def _kmeans_greedy_init(data: torch.Tensor, k: int) -> torch.Tensor:
     running_min_distances = torch.full(
         (data.shape[0],), torch.inf, device=data.device, dtype=data.dtype
     )
-    data_norm_squared = data.norm(p=2, dim=1).square()
+    data_norm_squared = torch.linalg.vector_norm(data, dim=1).square()

     for i in range(k):
         clusters[i] = data[running_min_distances.argmax()]
         distances_to_cluster_i = (
-            data_norm_squared - 2 * data @ clusters[i] + clusters[i].norm().square()
+            data_norm_squared
+            - 2 * data @ clusters[i]
+            + torch.linalg.vector_norm(clusters[i]).square()
         )
         running_min_distances = torch.minimum(
             running_min_distances, distances_to_cluster_i, out=running_min_distances
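
_kmeans_greedy_init expands the squared distance as ||x - c||^2 = ||x||^2 - 2 x·c + ||c||^2, so each new center only needs a matrix-vector product against precomputed row norms. A self-contained check with arbitrary shapes:

import torch

data = torch.randn(100, 16)  # hypothetical points
center = torch.randn(16)     # one candidate cluster center

data_norm_squared = torch.linalg.vector_norm(data, dim=1).square()
dist_sq = (
    data_norm_squared
    - 2 * data @ center
    + torch.linalg.vector_norm(center).square()
)

# Matches the direct form up to floating-point error.
direct = torch.linalg.vector_norm(data - center, dim=1).square()
assert torch.allclose(dist_sq, direct, atol=1e-4)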

torchao/prototype/sparsity/pruner/lstm_saliency_pruner.py

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ def update_mask(self, module, tensor_name, **kwargs):
         )
         # take norm over all but first dim
         dims = tuple(range(1, weights.dim()))
-        saliency = weights.norm(dim=dims, p=1)
+        saliency = torch.linalg.vector_norm(weights, dim=dims, ord=1)

         # handle weights in 4 groups
         split_size = len(mask) // 4
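
The saliency is an L1 norm reduced over every dimension except the first; dim accepts a tuple of axes, and the old p= keyword becomes ord=. With a made-up weight shape:

import torch

weights = torch.randn(12, 4, 7)  # stand-in for an LSTM weight chunk

dims = tuple(range(1, weights.dim()))
saliency = torch.linalg.vector_norm(weights, dim=dims, ord=1)
print(saliency.shape)  # torch.Size([12]): one score per output row

# Same values as the previous Tensor.norm spelling.
assert torch.allclose(saliency, weights.norm(dim=dims, p=1))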

torchao/prototype/sparsity/pruner/saliency_pruner.py

Lines changed: 5 additions & 1 deletion
@@ -3,6 +3,8 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
+import torch
+
 from .base_structured_sparsifier import BaseStructuredSparsifier


@@ -26,7 +28,9 @@ def update_mask(self, module, tensor_name, **kwargs):
             raise Exception(
                 "Structured pruning can only be applied to a 2+dim weight tensor!"
             )
-        saliency = -weights.norm(dim=tuple(range(1, weights.dim())), p=1)
+        saliency = -torch.linalg.vector_norm(
+            weights, dim=tuple(range(1, weights.dim())), ord=1
+        )
         assert saliency.shape == mask.shape

         num_to_pick = int(len(mask) * kwargs["sparsity_level"])
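
Same L1-over-trailing-dims saliency as the LSTM pruner, but negated. A plausible reading (sketch only, not the pruner's actual selection code) is that topk on the negated norm picks out the lowest-saliency rows to prune:

import torch

weights = torch.randn(8, 16)
saliency = -torch.linalg.vector_norm(weights, dim=1, ord=1)

# topk on the negated norm returns the rows with the smallest L1 mass,
# i.e. the rows a structured pruner would zero out first.
num_to_pick = 2
_, prune_idx = saliency.topk(num_to_pick)
print(prune_idx)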

torchao/sparsity/utils.py

Lines changed: 1 addition & 1 deletion
@@ -80,7 +80,7 @@ def forward(self, x_orig):
         new_axis_list[0], new_axis_list[-1] = new_axis_list[-1], new_axis_list[0]
         y = x.permute(new_axis_list)
         y = torch.flatten(y, start_dim=1)
-        norm = torch.norm(y, dim=1) ** 2
+        norm = torch.linalg.vector_norm(y, dim=1) ** 2

         if self.norm.numel() == 0:
             self.norm.resize_(norm.shape)
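
The forward above computes a per-channel squared L2 norm; squaring the vector norm is the same as summing squares along the row (the shape below is arbitrary):

import torch

y = torch.randn(32, 64)
norm = torch.linalg.vector_norm(y, dim=1) ** 2
assert torch.allclose(norm, y.square().sum(dim=1), atol=1e-4)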
