 import torch.nn.functional as F
 from typing import Any, Callable, Union, Dict, Optional

-from torchao.dtypes import PlainLayoutType
+from torchao.dtypes.uintx.Uintx import UintxLayoutType
+from torchao.dtypes import (
+    to_affine_quantized,
+    TensorCoreTiledLayoutType,
+    PlainLayoutType,
+    AffineQuantizedTensor,
+    SemiSparseLayoutType
+)
 from torchao.utils import (
     TORCH_VERSION_AFTER_2_4,
     unwrap_tensor_subclass,
@@ -182,9 +189,6 @@ def _replace_with_custom_fn_if_matches_filter(


 def _is_linear(mod, *args):
-    # avoid circular dep
-    from torchao.dtypes import AffineQuantizedTensor
-
     # adding weight tensor subclass isinstance check to make sure the weight is only quantized once
     # when it is shared by multiple linear modules
     return (
@@ -328,9 +332,6 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
     )

 def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor:
-    # avoid circular dep
-    from torchao.dtypes import to_affine_quantized
-
     mapping_type = MappingType.ASYMMETRIC
     target_dtype = torch.int8
     return to_affine_quantized(x, mapping_type, _get_per_token_block_size(x), target_dtype)
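
For readers unfamiliar with the per-token granularity used above, here is an illustrative sketch of what the block size passed to `to_affine_quantized` amounts to; the helper name is a stand-in for the library's `_get_per_token_block_size`, not its actual implementation:

```python
# Illustrative sketch only: per-token quantization uses block size 1 on every
# dimension except the last, so each token (row) gets its own scale/zero_point.
import torch

def per_token_block_size(x: torch.Tensor):
    return (1,) * (x.dim() - 1) + (x.shape[-1],)

x = torch.randn(2, 16, 64)          # (batch, seq, hidden)
print(per_token_block_size(x))      # (1, 1, 64)
```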
@@ -339,9 +340,6 @@ def apply_int8_dynamic_activation_int4_weight_quant(weight, group_size=32):
     if weight.shape[-1] % group_size != 0:
         return weight

-    # avoid circular dep
-    from torchao.dtypes import to_affine_quantized
-
     # weight settings
     mapping_type = MappingType.SYMMETRIC
     block_size = (1, group_size)
@@ -373,7 +371,7 @@ def insert_subclass(lin):
     return insert_subclass


-def int4_weight_only(group_size=128, inner_k_tiles=8):
+def int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8)):
     """
     Applies uint4 weight-only asymmetric per-group quantization to linear layers, using
     "tensor_core_tiled" layout for speedup with tinygemm kernel
@@ -389,16 +387,12 @@ def int4_weight_only(group_size=128, inner_k_tiles=8):
     Args:
         `group_size`: parameter for quantization, controls the granularity of quantization, smaller
          size is more fine grained, choices are [256, 128, 64, 32]
-        `inner_k_tiles`: parameter for int4 mm kernel, choices are [8, 4, 2]
+        `layout_type`: layout type for quantized tensor, default is `TensorCoreTiledLayoutType(inner_k_tiles=8)`
     """
     def apply_int4_weight_only_quant(weight):
         if weight.shape[-1] % group_size != 0:
             return weight

-        # avoid circular dep
-        from torchao.dtypes import to_affine_quantized
-        from torchao.dtypes import TensorCoreTiledLayoutType
-
         mapping_type = MappingType.ASYMMETRIC
         block_size = (1, group_size)
         target_dtype = torch.int32
@@ -408,7 +402,6 @@ def apply_int4_weight_only_quant(weight):
         preserve_zero = False
         zero_point_dtype = torch.bfloat16
         zero_point_domain = ZeroPointDomain.FLOAT
-        layout_type = TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles)
         return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type)

     return _get_linear_subclass_inserter(apply_int4_weight_only_quant)
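
For context on the signature change above, a minimal usage sketch follows. It assumes the `quantize_` entry point exported by `torchao.quantization` (named `quantize` in some earlier releases) and a bf16 CUDA model, which the tinygemm int4 kernel expects; the module and shapes are arbitrary:

```python
# Minimal sketch, assuming quantize_ is exported from torchao.quantization.
import torch
from torchao.quantization import quantize_, int4_weight_only
from torchao.dtypes import TensorCoreTiledLayoutType

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()

# Before this change: int4_weight_only(group_size=128, inner_k_tiles=8)
# After: the kernel-packing detail is carried by the layout_type argument instead.
quantize_(model, int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=4)))
```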
@@ -419,9 +412,6 @@ def int8_weight_only():
     Applies int8 weight-only symmetric per-channel quantization to linear layers.
     """
     def apply_int8wo_quant(weight):
-        # avoid circular dep
-        from torchao.dtypes import to_affine_quantized
-
         mapping_type = MappingType.SYMMETRIC
         target_dtype = torch.int8
         eps = torch.finfo(torch.float32).eps
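
A short sketch of what the `_is_linear` guard earlier in the diff relies on: after quantization the linear weight becomes an `AffineQuantizedTensor` subclass, so a second quantization pass skips it. This assumes the same `quantize_` entry point as above; the module is arbitrary:

```python
# Sketch under the same assumptions as above (quantize_ from torchao.quantization).
import torch
from torchao.quantization import quantize_, int8_weight_only
from torchao.dtypes import AffineQuantizedTensor

m = torch.nn.Sequential(torch.nn.Linear(128, 128))
quantize_(m, int8_weight_only())
print(isinstance(m[0].weight, AffineQuantizedTensor))  # expected: True
```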
@@ -432,8 +422,6 @@ def apply_int8wo_quant(weight):
     return _get_linear_subclass_inserter(apply_int8wo_quant)

 def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor:
-    # avoid circular dep
-    from torchao.dtypes import to_affine_quantized
     mapping_type = MappingType.SYMMETRIC
     target_dtype = torch.int8
     eps = 1e-5
@@ -453,8 +441,6 @@ def apply_int8_dynamic_activation_int8_weight_quant(weight):
         if in_features <= 16:
             return weight

-        # avoid circular dep
-        from torchao.dtypes import to_affine_quantized
         # weight settings
         mapping_type = MappingType.SYMMETRIC
         def get_weight_block_size(x):
@@ -479,7 +465,6 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
     quantization + 2:4 sparsity to linear layers.
     """
-    from torchao.dtypes import SemiSparseLayoutType
     return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())


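
With `SemiSparseLayoutType` now imported at module level, the API above is simply `int8_dynamic_activation_int8_weight` with a different layout. A hedged usage sketch, assuming the `quantize_` entry point and a CUDA GPU with 2:4 semi-structured sparse support (dtype and shapes chosen to satisfy the sparse kernel's constraints):

```python
# Sketch only: 2:4 sparse acceleration needs CUDA and fp16/bf16 weights of compatible shapes.
import torch
from torchao.quantization import quantize_, int8_dynamic_activation_int8_semi_sparse_weight

model = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).half().cuda()
quantize_(model, int8_dynamic_activation_int8_semi_sparse_weight())
```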
@@ -495,8 +480,6 @@ def uintx_weight_only(bit_width, group_size=64, pack_dim=-1):
         quantize_affine,
         dequantize_affine,
     )
-    from torchao.dtypes.uintx.Uintx import UintxLayoutType
-    from torchao.dtypes import to_affine_quantized
     from torchao.quantization.quant_api import _get_linear_subclass_inserter
     def apply_uintx_weight_only_quant(weight):
