@@ -70,7 +70,7 @@ def compute_v_per_channel(p: Tensor, dim: Optional[int] = None, ternary: bool =
7070 r = r .sub (v * binary_sign (r ))
7171
7272 # compute least squares error, then select the `v` minimizes it
73- costs = r . norm ( dim = dim )
73+ costs = torch . linalg . vector_norm ( r , dim = dim )
7474 indices = costs .argmin (dim = dim , keepdim = True )
7575 v_best = v_cands .gather (1 , indices )
7676 return v_best
@@ -196,10 +196,10 @@ def quantize_optimal_2bits(
196196 V1V2 .append ((v1 , v2 ))
197197 assert len (V1V2 ) > 0 , "LSBQ 2-bit optimal: No solution found."
198198 # find the best solution with least-square quantization error
199- min_error = p . norm ( )
199+ min_error = torch . linalg . vector_norm ( p )
200200 for v1v2 in V1V2 :
201201 r = binary_quant_residue (p , v1v2 )
202- error = r . norm ( )
202+ error = torch . linalg . vector_norm ( r )
203203 if error < min_error :
204204 min_error = error
205205 q = p - r
@@ -244,14 +244,14 @@ def quantize_optimal_ternary(
244244 v_feasible .append (v )
245245 assert len (v_feasible ) > 0 , "LSBQ ternary optimal: No solution found."
246246 # find the best solution with least-square quantization error
247- min_error = p . norm ( )
247+ min_error = torch . linalg . vector_norm ( p )
248248 q_best = torch .zeros_like (p )
249249 v_best = torch .zeros_like (v )
250250 for v in v_feasible :
251251 Q = v * torch .tensor ([- 1.0 , 0.0 , 1.0 ], device = p .device )
252252 boundaries = v * torch .tensor ([- 0.5 , 0.5 ], device = p .device )
253253 q = Q [torch .bucketize (p , boundaries )]
254- error = torch .linalg .norm (p - q )
254+ error = torch .linalg .vector_norm (p - q )
255255 if error < min_error :
256256 min_error = error
257257 q_best = q
0 commit comments