@@ -511,7 +511,7 @@ def int8_dynamic_activation_int4_weight(group_size=32, mapping_type=MappingType.
     return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int4_weight_quant, group_size=group_size, mapping_type=mapping_type)


-def int4_weight_only(group_size=128, _layout=TensorCoreTiledLayout(inner_k_tiles=8), use_hqq=False):
+def int4_weight_only(group_size=128, layout=TensorCoreTiledLayout(inner_k_tiles=8), use_hqq=False):
     """
     Applies uint4 weight-only asymmetric per-group quantization to linear layers, using
     "tensor_core_tiled" layout for speedup with tinygemm kernel
@@ -527,7 +527,7 @@ def int4_weight_only(group_size=128, _layout=TensorCoreTiledLayout(inner_k_tiles
     Args:
         `group_size`: parameter for quantization, controls the granularity of quantization, smaller
          size is more fine grained, choices are [256, 128, 64, 32]
-        `_layout`: layout type for quantized tensor, default is `TensorCoreTiledLayout(inner_k_tiles=8)`
+        `layout`: layout type for quantized tensor, default is `TensorCoreTiledLayout(inner_k_tiles=8)`
         `use_hqq`: whether to use hqq or default quantization mode, default is False
     """
     def apply_int4_weight_only_quant(weight):
@@ -550,12 +550,12 @@ def apply_int4_weight_only_quant(weight):
         # Sparse Marlin only supports symmetric quantization.
         # NOTE: If we start having lots of layouts that require different configurations,
         # we should consider moving this logic somewhere else.
-        if isinstance(_layout, MarlinSparseLayout):
+        if isinstance(layout, MarlinSparseLayout):
             mapping_type = MappingType.SYMMETRIC
             preserve_zero = True
             zero_point_domain = ZeroPointDomain.INT

-        return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, _layout=_layout, use_hqq=use_hqq)
+        return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, _layout=layout, use_hqq=use_hqq)

     return _get_linear_subclass_inserter(apply_int4_weight_only_quant)

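For reference, a minimal usage sketch of the renamed kwarg, assuming torchao's public `quantize_` entry point and a bfloat16 CUDA model (the toy model is illustrative, not part of this diff):

```python
# Sketch only: assumes a CUDA build of torch and torchao's quantize_ API.
import torch
from torchao.quantization import quantize_, int4_weight_only
from torchao.dtypes import TensorCoreTiledLayout

# Any module containing nn.Linear layers works; this toy model is illustrative.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()

# New spelling of the kwarg: `layout=` rather than the old `_layout=`.
quantize_(model, int4_weight_only(group_size=128, layout=TensorCoreTiledLayout(inner_k_tiles=8)))

# Passing layout=MarlinSparseLayout() would force symmetric quantization,
# per the isinstance check in the hunk above.
```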
@@ -583,7 +583,7 @@ def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor:
     return to_affine_quantized_intx(x, mapping_type, _get_per_token_block_size(x), target_dtype, eps=eps, quant_min=quant_min, quant_max=quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)


-def int8_dynamic_activation_int8_weight(_layout=PlainLayout()):
+def int8_dynamic_activation_int8_weight(layout=PlainLayout()):
     """
     Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
     quantization to linear layers
@@ -609,7 +609,7 @@ def get_weight_block_size(x):
         input_quant_func = _int8_symm_per_token_reduced_range_quant

         block_size = get_weight_block_size(weight)
-        weight = to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, _layout=_layout)
+        weight = to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, _layout=layout)
         weight = to_linear_activation_quantized(weight, input_quant_func)
         return weight

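A corresponding sketch for the int8/int8 path, under the same assumptions; note that the value is still forwarded internally to `to_affine_quantized_intx` as `_layout`:

```python
# Sketch only: the public kwarg is now `layout`; PlainLayout() is the default.
import torch
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
from torchao.dtypes import PlainLayout

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()

quantize_(model, int8_dynamic_activation_int8_weight())                      # implicit PlainLayout()
quantize_(model, int8_dynamic_activation_int8_weight(layout=PlainLayout()))  # equivalent, explicit
```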
@@ -621,12 +621,12 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
     quantization + 2:4 sparsity to linear layers.
     """
-    warnings.warn("""int8_dynamic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the _layout kwarg in int8_dynamic_activation_int8_weight instead.
+    warnings.warn("""int8_dynamic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead.

     from torchao.dtypes import SemiSparseLayout
-    int8_dynamic_activation_int8_weight(_layout=SemiSparseLayout())""")
+    int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())""")

-    return int8_dynamic_activation_int8_weight(_layout=SemiSparseLayout())
+    return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())


 def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn):
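The deprecation message above amounts to the following migration; a sketch, assuming both spellings remain importable from `torchao.quantization` until the old helper is removed:

```python
# Sketch only: migrate from the deprecated helper to the layout kwarg.
import torch
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
from torchao.dtypes import SemiSparseLayout

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()

# Before (now emits the deprecation warning shown in the hunk above):
# quantize_(model, int8_dynamic_activation_int8_semi_sparse_weight())

# After: same int8 dynamic activation + 2:4 sparse weight configuration.
quantize_(model, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()))
```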