From e8f1fb1690d444891256b9e221e1ae4666e449a0 Mon Sep 17 00:00:00 2001
From: Jesse Cai
Date: Thu, 11 Jul 2024 09:48:39 -0700
Subject: [PATCH 01/20] wip

---
 scripts/sam/benchmark.sh                  |   9 +-
 scripts/sam/eval_combo.py                 |  24 ++--
 scripts/sam/results.csv                   |   2 +
 torchao/dtypes/affine_quantized_tensor.py | 116 +++++++++++++++++-
 .../prototype/dynamic_quant_sparse.py      |  44 ++++++-
 5 files changed, 175 insertions(+), 20 deletions(-)

diff --git a/scripts/sam/benchmark.sh b/scripts/sam/benchmark.sh
index 5c1262f9cc..22abd147bb 100755
--- a/scripts/sam/benchmark.sh
+++ b/scripts/sam/benchmark.sh
@@ -1,11 +1,10 @@
 # baseline
-python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --print_header True
+# python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --print_header True
 # int8 dynamic quant (all)
-python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant
+# python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant
 # 2:4 sparsity (all)
-python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse_mlp_only
+# python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse_mlp_only
 # 2:4 sparsity (mlp only)
-python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse
+# python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse
 # int8 dynamic quant + 2:4 sparsity (attn: int8, mlp lin1: int8+2:4 fuse mul, mlp lin2: 2:4 sparse)
 python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant_sparse
-
diff --git a/scripts/sam/eval_combo.py b/scripts/sam/eval_combo.py
index e83ec25300..e65dbf3065 100644
--- a/scripts/sam/eval_combo.py
+++ b/scripts/sam/eval_combo.py
@@ -282,7 +282,7 @@ def run(
         from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
         from torchao.utils import unwrap_tensor_subclass
         quantize_(predictor.model.image_encoder, int8_dynamic_activation_int8_weight())
-        predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)
+        unwrap_tensor_subclass(predictor.model.image_encoder)
     elif compress == "sparse_mlp_only":
         def mlp_only(mod, name):
             return isinstance(mod, torch.nn.Linear) and 'mlp' in name
@@ -316,20 +316,18 @@ def mlp_only(mod, name):
         apply_fake_sparsity(predictor.model.image_encoder,
                             filter_fn=mlp_only)
-        quantize_(
-            predictor.model.image_encoder,
-            int8_dynamic_activation_int8_weight(),
-            attn_only
-        )
-        predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)
+        quantize_(predictor.model.image_encoder,
+                  int8_dynamic_activation_int8_weight(),
+                  attn_only)
+        quantize_(predictor.model.image_encoder,
+                  int8_dynamic_activation_int8_2x4_sparse_weight(),
+                  mlp_lin1_only)
 
-        predictor.model.image_encoder = sparsify(predictor.model.image_encoder,
-                                                 int8_dynamic_activation_int8_2x4_sparse_weight(),
-                                                 mlp_lin1_only, prune=False)
+        predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)
         predictor.model.image_encoder = sparsify(predictor.model.image_encoder,
                                                  to_sparse_semi_structured,
-                                                 mlp_lin2_only, prune=False)
+                                                 mlp_lin2_only)
     else:
         assert compress is None, f"Unsupported compress mode {compress}"
@@ -382,7 +380,7 @@ def mlp_only(mod, name):
             batch_size,
             use_compile,
             use_compile_decoder,
-            pad_input_image_batch,
+            pad_input_image_batch,
             compress)
 
     results = [[r[0], r[1], r[2], r[3].item()] for r in results]
@@ -413,6 +411,6 @@ def mlp_only(mod, name):
         vals = ",".join(map(str, [device, sam_model_type, batch_size, max_memory_allocated_bytes, max_memory_allocated_percentage, img_s, batch_ms_batch_size, mIoU, use_compile, use_half, compress, use_compile_decoder, use_rel_pos, pad_input_image_batch, num_workers, num_batches, num_images, profile_path, memory_path]))
         f.write(vals+"\n")
-
+
 if __name__ == '__main__':
     fire.Fire(run)
diff --git a/scripts/sam/results.csv b/scripts/sam/results.csv
index 01aad5c022..7be3a26355 100644
--- a/scripts/sam/results.csv
+++ b/scripts/sam/results.csv
@@ -4,3 +4,5 @@ cuda,vit_h,32,15154,18,24.908711866303545,40.14659631407106,0.5822020528694204,m
 cuda,vit_h,32,15632,19,24.806623549763994,40.311814221468836,0.5671732654673084,max-autotune,torch.bfloat16,sparse_mlp_only,False,True,True,32,154,4928,None,None
 cuda,vit_h,32,13429,16,24.299052218005198,41.15386851422198,0.5305645705002248,max-autotune,torch.bfloat16,sparse,False,True,True,32,154,4928,None,None
 cuda,vit_h,32,14865,18,26.46342281926203,37.7880067453756,0.5668329259098808,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,15203,18,24.104394829823185,41.48621058773685,0.567190438968895,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,15203,18,24.329631814054935,41.10214275508732,0.567190438968895,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 3cde983e9c..2c0684f15c 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -208,7 +208,6 @@ def from_float(
             input_float,
             (0, in_features - orig_in_features, 0, out_features - orig_out_features),
         )
-
         scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
         int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
@@ -216,6 +215,8 @@ def from_float(
         # TODO: this is temporary, need to come up with the proper UX
         if extended_layout == "tensor_core_tiled":
             layout_tensor = layout_cls_ctr(int_data, scale, zero_point, inner_k_tiles)
+        elif extended_layout == "semi_sparse_cusparselt":
+            layout_tensor = layout_cls_ctr(torch._cslt_compress(int_data), scale, zero_point)
         else:
             layout_tensor = layout_cls_ctr(int_data, scale, zero_point)
         return cls(
@@ -410,6 +411,94 @@ def from_plain(
     ):
         return cls(int_data, scale, zero_point)
 
+@register_layout_cls("semi_sparse_cusparselt")
+class SparseAQTLayout(AQTLayout):
+    """
+    Layout storage class for semi_sparse_cusparselt layout for affine quantized tensor
+
+    It stores int_data in compressed form
+    """
+    def __new__(
+        cls,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+    ):
+        kwargs = {}
+        kwargs["device"] = int_data.device
+        kwargs["layout"] = (
+            kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
+        )
+        kwargs["dtype"] = int_data.dtype
+        kwargs["requires_grad"] = False
+        shape = torch.Size([zero_point.shape[0],
+                            int_data.numel() * 16 // (10 * zero_point.shape[0])])
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(
+        self,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+    ):
+        self.int_data = int_data
+        self.scale = scale
+        self.zero_point = zero_point
+
+    def __tensor_flatten__(self):
+        return ["int_data", "scale", "zero_point"], []
+
+    @classmethod
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
+    ):
+        int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"]
+        return cls(int_data, scale, zero_point)
+
+    def to(self, *args, **kwargs):
+        kwargs = self._get_to_kwargs(*args, **kwargs)
+        return self.__class__(
+            self.int_data.to(kwargs["device"]),
+            self.scale.to(kwargs["device"]),
+            self.zero_point.to(kwargs["device"]),
+        )
+
+    def _apply_fn_to_data(self, fn):
+        return self.__class__(
+            fn(self.int_data),
+            fn(self.scale),
+            fn(self.zero_point),
+        )
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        kwargs = {} if kwargs is None else kwargs
+
+        if func is aten.detach.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+            )
+
+        raise NotImplementedError(
+            f"PlainAQTLayout dispatch: attempting to run {func}, this is not supported"
+        )
+
+    def get_plain(self):
+        int_data_expanded = torch._cslt_sparse_mm(self.int_data,
+                                                  torch.eye(self.shape[1],
+                                                            dtype=self.int_data.dtype,
+                                                            device=self.int_data.device).t())
+        return int_data_expanded, self.scale, self.zero_point
+
+    @classmethod
+    def from_plain(
+        cls,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+    ):
+        return cls(int_data, scale, zero_point)
+
 @register_layout_cls("tensor_core_tiled")
 class TensorCoreTiledAQTLayout(AQTLayout):
     """
@@ -594,6 +683,31 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
             if bias is not None:
                 y += bias
             return y
+
+        # handle int8 + semi_structured_sparse
+        elif(
+            is_cuda and
+            input_is_int8 and
+            input_tensor.dtype == weight_qtensor.dtype and
+            input_tensor.extended_layout == "plain" and
+            weight_qtensor.extended_layout == "semi_sparse_cusparselt"
+        ):
+            x_vals_int8 = input_tensor.layout_tensor.int_data
+            x_scales = input_tensor.layout_tensor.scale
+            w_vals_int8 = weight_qtensor.layout_tensor.int_data
+            w_scales = weight_qtensor.layout_tensor.scale
+            tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
+            y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm(
+                w_vals_int8, tmp.t(), alpha=w_scales, out_dtype=torch.bfloat16
+            ).t()
+            y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape(
+                *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1]
+            )
+            # downcast at the end
+            y = y.to(output_dtype)
+            if bias is not None:
+                y += bias
+            return y
         else:
             input_tensor = input_tensor.dequantize()
diff --git a/torchao/sparsity/prototype/dynamic_quant_sparse.py b/torchao/sparsity/prototype/dynamic_quant_sparse.py
index 2f2a198278..37d0c42076 100644
--- a/torchao/sparsity/prototype/dynamic_quant_sparse.py
+++ b/torchao/sparsity/prototype/dynamic_quant_sparse.py
@@ -310,5 +310,47 @@ def from_float(cls, input_float, qmin=-128, qmax=127):
             dtype=input_float.dtype,
         )
 
+from torchao.dtypes import to_affine_quantized
+from torchao.quantization.quant_api import MappingType, ZeroPointDomain, to_linear_act_quantized
+
 def int8_dynamic_activation_int8_2x4_sparse_weight():
-    return Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight.from_float
+    """
+    Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
+    quantization to linear layers
+    """
+    def apply_int8_dynamic_activation_int8_2x4_sparse_weight_quant(weight):
+        in_features = weight.shape[1]
+        # int8 dynamic quantization only has benefit when in_feature > 16
+        if in_features <= 16:
+            return weight
+
+        # avoid circular dep
+        from torchao.dtypes import to_affine_quantized
+        # weight settings
+        mapping_type = MappingType.SYMMETRIC
+        def get_weight_block_size(x):
+            return (1, x.shape[1])
+        target_dtype = torch.int8
+        eps = torch.finfo(torch.float32).eps
+        zero_point_dtype = torch.int64
+
+        # input settings
+        def get_per_token_block_size(x):
+            block_size = list(x.shape)
+            for i in range(len(block_size)-1):
+                block_size[i] = 1
+            return block_size
+
+        input_mapping_type = MappingType.SYMMETRIC
+        input_target_dtype = torch.int8
+        input_eps = 1e-5
+        input_quant_min = -127
+        input_quant_max = 127
+        input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
+
+        block_size = get_weight_block_size(weight)
+        weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, scale_dtype=torch.float32, zero_point_dtype=zero_point_dtype, extended_layout="semi_sparse_cusparselt")
+        weight = to_linear_act_quantized(weight, input_quant_func)
+        return weight
+
+    return apply_int8_dynamic_activation_int8_2x4_sparse_weight_quant
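The new `semi_sparse_cusparselt` layout above packs the int8 weight with `torch._cslt_compress` and dispatches linear to `torch._cslt_sparse_mm`, folding the per-channel weight scale into the kernel via `alpha`. A minimal standalone sketch of that call sequence, assuming an sm80+ GPU and a PyTorch build with cuSPARSELt enabled (all tensor names here are illustrative, not torchao API):

    import torch

    M, K, N = 32, 128, 256
    # fake-prune a weight to the 2:4 pattern: zero the two smallest-magnitude
    # entries in every group of four along the input dimension
    w = torch.randn(N, K, device="cuda")
    idx = w.reshape(-1, 4).abs().argsort(dim=1)[:, :2]
    w = w.reshape(-1, 4).scatter(1, idx, 0).reshape(N, K)

    w_int8 = w.clamp(-127, 127).round().to(torch.int8)  # stand-in for real quantization
    w_scales = torch.rand(N, device="cuda", dtype=torch.float32)
    x_int8 = torch.randint(-128, 127, (M, K), dtype=torch.int8, device="cuda")

    # the compressed buffer holds the kept values plus 2:4 metadata, roughly
    # 10/16 of the dense int8 bytes, which is the ratio the shape arithmetic
    # in SparseAQTLayout.__new__ above inverts to recover the logical shape
    w_packed = torch._cslt_compress(w_int8)

    # alpha fuses the per-channel weight scale into the matmul epilogue, so
    # only the per-token activation scale remains to be applied outside
    y = torch._cslt_sparse_mm(w_packed, x_int8.t(), alpha=w_scales,
                              out_dtype=torch.bfloat16).t()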
weight_tensor, bias) + return _quantized_linear_op(input_tensor, weight_tensor, bias) + # try: + # except: + # pass + # # if isinstance(input_tensor, AffineQuantizedTensor): + # # input_tensor = input_tensor.dequantize() + # # if isinstance(weight_tensor, AffineQuantizedTensor): + # # weight_tensor = weight_tensor.dequantize() + # # return torch.nn.functional.linear(input_tensor, weight_tensor, bias) @implements([aten.mm.default, aten.addmm.default]) def aten_mm(func, *args, **kwargs): From f88873c9c208a09f6dc7a85ed5dbd6cb6b183607 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 11 Jul 2024 13:26:23 -0700 Subject: [PATCH 03/20] working --- scripts/sam/benchmark.sh | 1 + scripts/sam/eval_combo.py | 1 - scripts/sam/results.csv | 13 +++++ torchao/dtypes/affine_quantized_tensor.py | 56 ++----------------- .../prototype/dynamic_quant_sparse.py | 2 +- 5 files changed, 20 insertions(+), 53 deletions(-) diff --git a/scripts/sam/benchmark.sh b/scripts/sam/benchmark.sh index 22abd147bb..54bc8ef0c0 100755 --- a/scripts/sam/benchmark.sh +++ b/scripts/sam/benchmark.sh @@ -7,4 +7,5 @@ # 2:4 sparsity (mlp only) # python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse # int8 dynamic quant + 2:4 sparsity (attn: int8, mlp lin1: int8+2:4 fuse mul, mlp lin2: 2:4 sparse) +# python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant_sparse_other python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant_sparse diff --git a/scripts/sam/eval_combo.py b/scripts/sam/eval_combo.py index e65dbf3065..df10b10c73 100644 --- a/scripts/sam/eval_combo.py +++ b/scripts/sam/eval_combo.py @@ -297,7 +297,6 @@ def mlp_only(mod, name): predictor.model.image_encoder = sparsify(predictor.model.image_encoder, to_sparse_semi_structured) elif compress == "int8_dynamic_quant_sparse": from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor - SparseSemiStructuredTensor._FORCE_CUTLASS = False from torchao.sparsity import sparsify, apply_fake_sparsity from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight diff --git a/scripts/sam/results.csv b/scripts/sam/results.csv index 460392c005..21b29d1dba 100644 --- a/scripts/sam/results.csv +++ b/scripts/sam/results.csv @@ -10,3 +10,16 @@ cuda,vit_h,32,15203,18,24.329631814054935,41.10214275508732,0.567190438968895,ma cuda,vit_h,32,15211,18,24.56229782147879,40.71280330806584,0.5671033117430888,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None 
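Two small changes land here: the flattened activation view is made `.contiguous()` before transposing (presumably so the transposed view maps onto a layout the sparse kernel accepts), and `output_dtype` is derived from the input tensor instead of an outer variable. The epilogue's shape bookkeeping, in isolation, looks roughly like this sketch (shapes and names are illustrative):

    import torch

    B, S, K, N = 2, 16, 128, 256          # batch, tokens, in/out features
    x_int8 = torch.randint(-128, 127, (B, S, K), dtype=torch.int8)
    x_scales = torch.rand(B, S)            # one dynamic-quant scale per token
    y_dot = torch.randn(B * S, N, dtype=torch.bfloat16)  # stand-in for the sparse mm result

    # per-token rescale broadcasts over output features, then the token
    # dimension is folded back into the original batch shape
    y = (y_dot * x_scales.reshape(-1, 1)).reshape(*x_int8.shape[:-1], N)
    assert y.shape == (B, S, N)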
From f88873c9c208a09f6dc7a85ed5dbd6cb6b183607 Mon Sep 17 00:00:00 2001
From: Jesse Cai
Date: Thu, 11 Jul 2024 13:26:23 -0700
Subject: [PATCH 03/20] working

---
 scripts/sam/benchmark.sh                  |  1 +
 scripts/sam/eval_combo.py                 |  1 -
 scripts/sam/results.csv                   | 13 +++++
 torchao/dtypes/affine_quantized_tensor.py | 56 ++-----------------
 .../prototype/dynamic_quant_sparse.py      |  2 +-
 5 files changed, 20 insertions(+), 53 deletions(-)

diff --git a/scripts/sam/benchmark.sh b/scripts/sam/benchmark.sh
index 22abd147bb..54bc8ef0c0 100755
--- a/scripts/sam/benchmark.sh
+++ b/scripts/sam/benchmark.sh
@@ -7,4 +7,5 @@
 # 2:4 sparsity (mlp only)
 # python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse
 # int8 dynamic quant + 2:4 sparsity (attn: int8, mlp lin1: int8+2:4 fuse mul, mlp lin2: 2:4 sparse)
+# python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant_sparse_other
 python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant_sparse
diff --git a/scripts/sam/eval_combo.py b/scripts/sam/eval_combo.py
index e65dbf3065..df10b10c73 100644
--- a/scripts/sam/eval_combo.py
+++ b/scripts/sam/eval_combo.py
@@ -297,7 +297,6 @@ def mlp_only(mod, name):
         predictor.model.image_encoder = sparsify(predictor.model.image_encoder, to_sparse_semi_structured)
     elif compress == "int8_dynamic_quant_sparse":
         from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
-        SparseSemiStructuredTensor._FORCE_CUTLASS = False
         from torchao.sparsity import sparsify, apply_fake_sparsity
         from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight
         from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
diff --git a/scripts/sam/results.csv b/scripts/sam/results.csv
index 460392c005..21b29d1dba 100644
--- a/scripts/sam/results.csv
+++ b/scripts/sam/results.csv
@@ -10,3 +10,16 @@ cuda,vit_h,32,15211,18,24.56229782147879,40.71280330806584,0.5671033117430888,m
 cuda,vit_h,32,15212,18,22.001439043708203,45.45157241821289,0.5719013796228187,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,2,64,None,None
 cuda,vit_h,32,14870,18,23.586484684427163,42.39716148376465,0.5728743877698744,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,2,64,None,None
+cuda,vit_h,32,14870,18,26.080066556091367,38.343460429798476,0.567100022987887,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,14870,18,24.105935759314356,41.483558654785156,0.5728743877698744,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,2,64,None,None
+cuda,vit_h,32,14870,18,26.077327811425214,38.347487412489855,0.567100022987887,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,14869,18,25.878415158662072,38.64224272888976,0.567177072199177,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,14868,18,26.405483383382357,37.870921939998475,0.567177072199177,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse_other,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,14869,18,26.405490913715926,37.87091113994648,0.5669944616184625,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,14869,18,26.638093985525533,37.54022343127758,0.06075761447241241,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,14869,18,26.308018244576076,38.011224969640956,0.5669944616184625,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,12441,15,26.475001207469894,37.77148080801035,0.567177072199177,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,14869,18,25.998585036559202,38.463631716641515,0.567177072199177,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,14869,18,25.5705826577316,39.10743894205465,0.567177072199177,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,14869,18,22.687728699962957,44.076690673828125,0.5769663374200948,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,4,128,None,None
+cuda,vit_h,32,14869,18,26.326282787092705,37.98485369496531,0.567177072199177,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 7a4fdd198e..aa4a845e2e 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -412,11 +412,9 @@ def from_plain(
     return cls(int_data, scale, zero_point)
 
 @register_layout_cls("semi_sparse_cusparselt")
-class SparseAQTLayout(AQTLayout):
+class SparseAQTLayout(PlainAQTLayout):
     """
     Layout storage class for semi_sparse_cusparselt layout for affine quantized tensor
-
-    It stores int_data in compressed form
     """
     def __new__(
         cls,
         int_data: torch.Tensor,
         scale: torch.Tensor,
         zero_point: torch.Tensor,
     ):
         kwargs = {}
         kwargs["device"] = int_data.device
         kwargs["layout"] = (
             kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
         )
         kwargs["dtype"] = int_data.dtype
         kwargs["requires_grad"] = False
         shape = torch.Size([zero_point.shape[0],
                             int_data.numel() * 16 // (10 * zero_point.shape[0])])
         return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
 
-    def __init__(
-        self,
-        int_data: torch.Tensor,
-        scale: torch.Tensor,
-        zero_point: torch.Tensor,
-    ):
-        self.int_data = int_data
-        self.scale = scale
-        self.zero_point = zero_point
-
-    def __tensor_flatten__(self):
-        return ["int_data", "scale", "zero_point"], []
-
-    @classmethod
-    def __tensor_unflatten__(
-        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
-    ):
-        int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"]
-        return cls(int_data, scale, zero_point)
-
-    def to(self, *args, **kwargs):
-        kwargs = self._get_to_kwargs(*args, **kwargs)
-        return self.__class__(
-            self.int_data.to(kwargs["device"]),
-            self.scale.to(kwargs["device"]),
-            self.zero_point.to(kwargs["device"]),
-        )
-
-    def _apply_fn_to_data(self, fn):
-        return self.__class__(
-            fn(self.int_data),
-            fn(self.scale),
-            fn(self.zero_point),
-        )
-
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs):
         kwargs = {} if kwargs is None else kwargs
@@ -490,15 +453,6 @@ def get_plain(self):
             device=self.int_data.device).t())
         return int_data_expanded, self.scale, self.zero_point
 
-    @classmethod
-    def from_plain(
-        cls,
-        int_data: torch.Tensor,
-        scale: torch.Tensor,
-        zero_point: torch.Tensor,
-    ):
-        return cls(int_data, scale, zero_point)
-
 @register_layout_cls("tensor_core_tiled")
 class TensorCoreTiledAQTLayout(AQTLayout):
     """
@@ -684,7 +638,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
             if bias is not None:
                 y += bias
             return y
-        # handle int8 + semi_structured_sparse
+        # handle int8 dynamic_quant + semi_structured_sparse
         elif(
             is_cuda and
             input_is_int8 and
             input_tensor.dtype == weight_qtensor.dtype and
             input_tensor.extended_layout == "plain" and
             weight_qtensor.extended_layout == "semi_sparse_cusparselt"
         ):
             x_vals_int8 = input_tensor.layout_tensor.int_data
             x_scales = input_tensor.layout_tensor.scale
             w_vals_int8 = weight_qtensor.layout_tensor.int_data
             w_scales = weight_qtensor.layout_tensor.scale
-            tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]).contiguous()
+            tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
+            # we fuse one of the scalar matrix multiplications (w_scales) into the sparse mm
             y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm(
-                w_vals_int8, tmp.t(), alpha=w_scales, out_dtype=torch.bfloat16
+                w_vals_int8, tmp.t(), alpha=w_scales.to(torch.float32), out_dtype=torch.bfloat16
             ).t()
             y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape(
                 *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1]
             )
             output_dtype = input_tensor.dtype
             y = y.to(output_dtype)
diff --git a/torchao/sparsity/prototype/dynamic_quant_sparse.py b/torchao/sparsity/prototype/dynamic_quant_sparse.py
index 37d0c42076..d0d80a6b79 100644
--- a/torchao/sparsity/prototype/dynamic_quant_sparse.py
+++ b/torchao/sparsity/prototype/dynamic_quant_sparse.py
@@ -349,7 +349,7 @@ def get_per_token_block_size(x):
         input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
 
         block_size = get_weight_block_size(weight)
-        weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, scale_dtype=torch.float32, zero_point_dtype=zero_point_dtype, extended_layout="semi_sparse_cusparselt")
+        weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, extended_layout="semi_sparse_cusparselt")
         weight = to_linear_act_quantized(weight, input_quant_func)
         return weight
From 44dadfcb0e8841f821bdcb3ffe888ff995918629 Mon Sep 17 00:00:00 2001
From: Jesse Cai
Date: Thu, 11 Jul 2024 13:44:25 -0700
Subject: [PATCH 04/20] move from prototype -> quant api

---
 scripts/sam/eval_combo.py             |  13 +-
 scripts/sam/results.csv               |   2 +
 torchao/quantization/__init__.py      |   1 +
 torchao/quantization/quant_api.py     |  75 ++--
 torchao/sparsity/__init__.py          |   3 +-
 .../prototype/dynamic_quant_sparse.py | 356 ------------------
 torchao/sparsity/sparse_api.py        |   1 +
 7 files changed, 54 insertions(+), 397 deletions(-)
 delete mode 100644 torchao/sparsity/prototype/dynamic_quant_sparse.py

diff --git a/scripts/sam/eval_combo.py b/scripts/sam/eval_combo.py
index df10b10c73..7c7b25457b 100644
--- a/scripts/sam/eval_combo.py
+++ b/scripts/sam/eval_combo.py
@@ -282,24 +282,27 @@ def run(
         from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
         from torchao.utils import unwrap_tensor_subclass
         quantize_(predictor.model.image_encoder, int8_dynamic_activation_int8_weight())
-        unwrap_tensor_subclass(predictor.model.image_encoder)
+        predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)
     elif compress == "sparse_mlp_only":
         def mlp_only(mod, name):
             return isinstance(mod, torch.nn.Linear) and 'mlp' in name
         from torchao.sparsity import sparsify
         from torch.sparse import to_sparse_semi_structured, apply_fake_sparsity
         apply_fake_sparsity(predictor.model.image_encoder, filter_fn=mlp_only)
-        predictor.model.image_encoder = sparsify(predictor.model.image_encoder, to_sparse_semi_structured, filter_fn=mlp_only)
+        sparsify(predictor.model.image_encoder, to_sparse_semi_structured, filter_fn=mlp_only)
     elif compress == "sparse":
         from torchao.sparsity import sparsify
         from torch.sparse import to_sparse_semi_structured, apply_fake_sparsity
         apply_fake_sparsity(predictor.model.image_encoder)
-        predictor.model.image_encoder = sparsify(predictor.model.image_encoder, to_sparse_semi_structured)
+        sparsify(predictor.model.image_encoder, to_sparse_semi_structured)
     elif compress == "int8_dynamic_quant_sparse":
         from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
         from torchao.sparsity import sparsify, apply_fake_sparsity
-        from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight
-        from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
+        from torchao.quantization import (
+            quantize_,
+            int8_dynamic_activation_int8_weight,
+            int8_dynamic_activation_int8_2x4_sparse_weight,
+        )
         from torchao.utils import unwrap_tensor_subclass
 
         def attn_only(mod, name):
diff --git a/scripts/sam/results.csv b/scripts/sam/results.csv
index 21b29d1dba..92f20b948d 100644
--- a/scripts/sam/results.csv
+++ b/scripts/sam/results.csv
@@ -23,3 +23,5 @@ cuda,vit_h,32,14869,18,25.998585036559202,38.463631716641515,0.567177072199177,m
 cuda,vit_h,32,14869,18,25.5705826577316,39.10743894205465,0.567177072199177,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
 cuda,vit_h,32,14869,18,22.687728699962957,44.076690673828125,0.5769663374200948,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,4,128,None,None
 cuda,vit_h,32,14869,18,26.326282787092705,37.98485369496531,0.567177072199177,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,14869,18,26.364732130607624,37.929457998895025,0.567177072199177,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,14869,18,26.488956903059197,37.75158091953823,0.567177072199177,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index a1cf1bf034..46436f6777 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -32,6 +32,7 @@
     "quantize_",
     "int8_dynamic_activation_int4_weight",
     "int8_dynamic_activation_int8_weight",
+    "int8_dynamic_activation_int8_2x4_sparse_weight",
     "int4_weight_only",
     "int8_weight_only",
 ]
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index d6c142476b..1b69912b4e 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -14,6 +14,7 @@
 come along with it and because that is how we access the intended quantized
 and mixed GEMM kernels
 """
+from functools import partial
 
 import torch
 import torchao
@@ -412,44 +413,48 @@ def apply_int8wo_quant(weight):
 
     return apply_int8wo_quant
 
+
+def apply_int8_dynamic_activation_int8_weight_quant(weight, extended_layout="plain"):
+    in_features = weight.shape[1]
+    # int8 dynamic quantization only has benefit when in_feature > 16
+    if in_features <= 16:
+        return weight
+
+    # avoid circular dep
+    from torchao.dtypes import to_affine_quantized
+    # weight settings
+    mapping_type = MappingType.SYMMETRIC
+    def get_weight_block_size(x):
+        return (1, x.shape[1])
+    target_dtype = torch.int8
+    eps = torch.finfo(torch.float32).eps
+    zero_point_dtype = torch.int64
+
+    # input settings
+    def get_per_token_block_size(x):
+        block_size = list(x.shape)
+        for i in range(len(block_size)-1):
+            block_size[i] = 1
+        return block_size
+
+    input_mapping_type = MappingType.SYMMETRIC
+    input_target_dtype = torch.int8
+    input_eps = 1e-5
+    input_quant_min = -127
+    input_quant_max = 127
+    input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
+
+    block_size = get_weight_block_size(weight)
+    weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, extended_layout=extended_layout)
+    weight = to_linear_act_quantized(weight, input_quant_func)
+    return weight
+
 def int8_dynamic_activation_int8_weight():
     """
     Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
     quantization to linear layers
     """
-    def apply_int8_dynamic_activation_int8_weight_quant(weight):
-        in_features = weight.shape[1]
-        # int8 dynamic quantization only has benefit when in_feature > 16
-        if in_features <= 16:
-            return weight
-
-        # avoid circular dep
-        from torchao.dtypes import to_affine_quantized
-        # weight settings
-        mapping_type = MappingType.SYMMETRIC
-        def get_weight_block_size(x):
-            return (1, x.shape[1])
-        target_dtype = torch.int8
-        eps = torch.finfo(torch.float32).eps
-        zero_point_dtype = torch.int64
-
-        # input settings
-        def get_per_token_block_size(x):
-            block_size = list(x.shape)
-            for i in range(len(block_size)-1):
-                block_size[i] = 1
-            return block_size
-
-        input_mapping_type = MappingType.SYMMETRIC
-        input_target_dtype = torch.int8
-        input_eps = 1e-5
-        input_quant_min = -127
-        input_quant_max = 127
-        input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
-
-        block_size = get_weight_block_size(weight)
-        weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
-        weight = to_linear_act_quantized(weight, input_quant_func)
-        return weight
-
     return apply_int8_dynamic_activation_int8_weight_quant
+
+def int8_dynamic_activation_int8_2x4_sparse_weight():
+    return partial(apply_int8_dynamic_activation_int8_weight_quant, extended_layout="semi_sparse_cusparselt")
diff --git a/torchao/sparsity/__init__.py b/torchao/sparsity/__init__.py
index 9b288c07f9..8acb712ed9 100644
--- a/torchao/sparsity/__init__.py
+++ b/torchao/sparsity/__init__.py
@@ -6,11 +6,12 @@
 from .wanda import WandaSparsifier  # noqa: F403
 from .utils import PerChannelNormObserver  # noqa: F403
-from .sparse_api import apply_fake_sparsity, sparsify
+from .sparse_api import apply_fake_sparsity, sparsify, int8_dynamic_activation_int8_2x4_sparse_weight
 
 __all__ = [
     "WandaSparsifier",
     "PerChannelNormObserver",
     "apply_fake_sparsity",
-    "sparsify"
+    "sparsify"
+    "int8_dynamic_activation_int8_2x4_sparse_weight"
 ]
diff --git a/torchao/sparsity/prototype/dynamic_quant_sparse.py b/torchao/sparsity/prototype/dynamic_quant_sparse.py
deleted file mode 100644
index d0d80a6b79..0000000000
--- a/torchao/sparsity/prototype/dynamic_quant_sparse.py
+++ /dev/null
@@ -1,356 +0,0 @@
-import torch
-import torch.nn as nn
-from typing import Tuple, Optional
-
-from torchao.quantization.utils import (
-    dynamically_quantize_per_channel,
-    quant_int8_dynamic_per_token_linear,
-    quantize_activation_per_token_absmax,
-    dequantize_per_channel,
-)
-
-from torchao.quantization.subclass import (
-    Int8DynamicallyQuantizedLinearWeight,
-    QuantizedLinearWeightBase,
-)
-
-from torch.sparse import to_sparse_semi_structured
-
-# Quant + Sparse helper functinos
-def sparse_quant_int8_dynamic_linear(
-    x : torch.Tensor,
-    w_vals_int8_packed : torch.Tensor,
-    w_meta_int32 : Optional[torch.Tensor],
-    w_scales : torch.Tensor,
-    bias : Optional[torch.Tensor],
-    out_dtype : torch.dtype,
-    fuse_mul=False,
-):
-    x_vals_int8, x_scales = quantize_activation_per_token_absmax(x)
-    # w_meta_int32 is either None or meta tensor
-    if w_meta_int32 is None:
-        if fuse_mul:
-            mm_out = sparse_quant_int8_cslt_matmul_fuse_mul(
-                x_vals_int8, x_scales, w_vals_int8_packed, w_scales, out_dtype,
-            )
-        else:
-            mm_out = sparse_quant_int8_cslt_matmul(
-                x_vals_int8, x_scales, w_vals_int8_packed, w_scales, out_dtype,
-            )
-    else:
-        mm_out = sparse_quant_int8_cutlass_matmul(
-            x_vals_int8, x_scales, w_vals_int8_packed, w_meta_int32, w_scales, out_dtype,
-        )
-
-    if bias is not None:
-        mm_out += bias
-    return mm_out
-
-def sparse_quant_int8_cslt_matmul_fuse_mul(
-    x_vals_int8,
-    x_scales,
-    w_vals_int8,
-    w_scales,
-    out_dtype,
-):
-
-    assert (
-        x_vals_int8.dtype == torch.int8
-    ), f"x dtype {x_vals_int8.dtype} not yet supported"
-    assert (
-        w_vals_int8.dtype == torch.int8
-    ), f"w dtype {w_vals_int8.dtype} not yet supported"
-    # assert w_scales.dtype == out_dtype, f'{w_scales.dtype} does not match {out_dtype}'
-
-    tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]).contiguous()
-
-    assert x_scales.dtype in [
-        torch.float,
-        torch.bfloat16,
-    ], f"x_scales needs to be a torch.float32 or torch.bfloat16 but got {x_scales.dtype}"
-
-    y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm(
-        w_vals_int8, tmp.t(), alpha=w_scales.to(torch.float32), out_dtype=torch.bfloat16
-    ).t()
-    y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape(
-        *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1]
-    )
-    y = y.to(out_dtype)
-
-    return y
-
-def sparse_quant_int8_cslt_matmul(
-    x_vals_int8,
-    x_scales,
-    w_vals_int8,
-    w_scales,
-    out_dtype,
-):
-
-    assert (
-        x_vals_int8.dtype == torch.int8
-    ), f"x dtype {x_vals_int8.dtype} not yet supported"
-    assert (
-        w_vals_int8.dtype == torch.int8
-    ), f"w dtype {w_vals_int8.dtype} not yet supported"
-    # assert w_scales.dtype == out_dtype, f'{w_scales.dtype} does not match {out_dtype}'
-
-    tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]).contiguous()
-
-    assert x_scales.dtype in [
-        torch.float,
-        torch.bfloat16,
-    ], f"x_scales needs to be a torch.float32 or torch.bfloat16 but got {x_scales.dtype}"
-
-    y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm(
-        w_vals_int8, tmp.t(), out_dtype=torch.bfloat16
-    ).t()
-    y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1) * w_scales).reshape(
-        *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1]
-    )
-    y = y.to(out_dtype)
-
-    return y
-
-
-def sparse_quant_int8_cutlass_matmul(
-    x_vals_int8,
-    x_scales,
-    w_vals_int8,
-    w_meta_int32,
-    w_scales,
-    out_dtype,
-):
-    assert (
-        x_vals_int8.dtype == torch.int8
-    ), f"x dtype {x_vals_int8.dtype} not yet supported"
-    assert (
-        w_vals_int8.dtype == torch.int8
-    ), f"w dtype {w_vals_int8.dtype} not yet supported"
-    assert w_scales.dtype == out_dtype, f"{w_scales.dtype} does not match {out_dtype}"
-    assert w_meta_int32.dtype == torch.int32, f"{w_meta_int32.dtype} not yet supported"
-
-    tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]).contiguous()
-
-    assert x_scales.dtype in [
-        torch.float,
-        torch.bfloat16,
-    ], f"x_scales needs to be a torch.float32 or torch.bfloat16 but got {x_scales.dtype}"
-
-    y_dot_int32 = torch._sparse_semi_structured_linear(
-        tmp, w_vals_int8, w_meta_int32.view(torch.int32), out_dtype=torch.int32
-    )
-    y = (y_dot_int32 * x_scales.reshape(-1, 1) * w_scales).reshape(
-        *x_vals_int8.shape[:-1], y_dot_int32.shape[-1]
-    )
-    y = y.to(out_dtype)
-    return y
-
-class Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight(
-    Int8DynamicallyQuantizedLinearWeight
-):
-    def dequantize(self, dtype=None):
-        # overload dequantize op for __repr__
-        zero_points = torch.zeros(self.q_scales.shape, device=self.q_scales.device, dtype=self.q_scales.dtype)
-        int_data_expanded = torch._cslt_sparse_mm(self.int_data, torch.eye(self.shape[1],
-                                                  dtype=self.int_data.dtype,
-                                                  device=self.int_data.device))
-        dq_t = dequantize_per_channel(
-            int_data_expanded, self.q_scales, zero_points, self.dtype if dtype is None else dtype
-        ).to(self.dtype)
-
-        return dq_t if not self.transposed else dq_t.t()
-
-    @staticmethod
-    def _quantized_op(act_mat, w_qtensor, bias):
-        return sparse_quant_int8_dynamic_linear(
-            act_mat, w_qtensor.int_data, None, w_qtensor.q_scales, bias, act_mat.dtype,
-            fuse_mul=True
-        )
-
-    @classmethod
-    def from_float(cls, input_float, qmin=-128, qmax=127):
-
-        assert input_float.is_cuda
-
-        w_int_repr, w_scales, _ = dynamically_quantize_per_channel(
-            input_float, qmin, qmax, torch.int8
-        )
-
-        int_data = w_int_repr.contiguous()
-        int_data = torch._cslt_compress(int_data)
-
-        return cls(
-            int_data,
-            w_scales,
-            False,
-            input_float.shape,
-            dtype=input_float.dtype,
-        )
-
-
-class Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight(QuantizedLinearWeightBase):
-
-    @staticmethod
-    def __new__(cls, int_data, mask_meta, q_scales, transposed, shape, **kwargs):
-        kwargs["dtype"] = kwargs.get("dtype", q_scales.dtype)
-        return super().__new__(cls, int_data, transposed, shape, **kwargs)  # type: ignore[attr-defined]
-
-    def __init__(self, int_data, mask_meta, q_scales, transposed, shape, **kwargs):
-        self.q_scales = q_scales
-        self.mask_meta = mask_meta
-        super().__init__(int_data, transposed)
-
-    def dequantize(self, dtype=None):
-        """
-        Obtain the dequantized version of the quantized tensor subclass
-        """
-        dq_t = dequantize_per_channel(
-            self.int_data, self.q_scales, 0, self.dtype if dtype is None else dtype
-        ).to(self.dtype)
-        # data was transposed to dequantize so make sure shape is correct
-        return dq_t if not self.transposed else dq_t.t()
-
-    def int_repr(self):
-        """
-        Get the internal integer representation of the quantized tensor
-        """
-        return self.int_data if self.transposed else self.int_data.t()
-
-    def q_params(self):
-        """
-        Get the quantization scales for the quantized tensor
-        """
-        return {"q_scales": self.q_scales}
-
-    def to(self, *args, **kwargs):
-        kwargs = self._get_to_kwargs(*args, **kwargs)
-        return self.__class__(
-            self.int_data.to(kwargs["device"]),
-            self.mask_meta.to(kwargs["device"]),
-            self.q_scales.to(kwargs["device"]),
-            self.transposed,
-            self.shape,
-            **kwargs,
-        )
-
-    def _apply_fn_to_data(self, fn):
-        return self.__class__(
-            fn(self.int_data),
-            fn(self.mask_meta),
-            fn(self.q_scales),
-            self.transposed,
-            self.shape,
-            dtype=self.dtype,
-        )
-
-    def _change_shape(self, shape):
-        return self.__class__(
-            self.int_data,
-            self.mask_meta,
-            self.q_scales,
-            self.transposed,
-            shape,
-            dtype=self.dtype,
-        )
-
-    def __tensor_flatten__(self):
-        return ["int_data", "mask_meta", "q_scales"], [
-            self.transposed,
-            self.dtype,
-            self.shape,
-        ]
-
-    @classmethod
-    def __tensor_unflatten__(
-        cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None
-    ):
-        int_data, q_scales = tensor_data_dict["int_data"], tensor_data_dict["q_scales"]
-        mask_meta = tensor_data_dict["mask_meta"]
-        transposed, dtype, shape = tensor_attributes
-        return cls(
-            int_data,
-            mask_meta,
-            q_scales,
-            transposed,
-            shape if outer_size is None else outer_size,
-            dtype=dtype,
-            strides=outer_stride,
-        )
-
-    @staticmethod
-    def _quantized_op(act_mat, w_qtensor, bias):
-        return sparse_quant_int8_dynamic_linear(
-            act_mat,
-            w_qtensor.int_data,
-            w_qtensor.mask_meta,
-            w_qtensor.q_scales,
-            bias,
-            act_mat.dtype,
-        )
-
-    @classmethod
-    def from_float(cls, input_float, qmin=-128, qmax=127):
-
-        assert input_float.is_cuda
-
-        w_int_repr, w_scales, _ = dynamically_quantize_per_channel(
-            input_float, qmin, qmax, torch.int8
-        )
-
-        int_data = w_int_repr.contiguous()
-        sparse_tensor = to_sparse_semi_structured(int_data)
-
-        return cls(
-            sparse_tensor.packed,
-            sparse_tensor.meta,
-            w_scales,
-            False,
-            input_float.shape,
-            dtype=input_float.dtype,
-        )
-
-from torchao.dtypes import to_affine_quantized
-from torchao.quantization.quant_api import MappingType, ZeroPointDomain, to_linear_act_quantized
-
-def int8_dynamic_activation_int8_2x4_sparse_weight():
-    """
-    Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
-    quantization to linear layers
-    """
-    def apply_int8_dynamic_activation_int8_2x4_sparse_weight_quant(weight):
-        in_features = weight.shape[1]
-        # int8 dynamic quantization only has benefit when in_feature > 16
-        if in_features <= 16:
-            return weight
-
-        # avoid circular dep
-        from torchao.dtypes import to_affine_quantized
-        # weight settings
-        mapping_type = MappingType.SYMMETRIC
-        def get_weight_block_size(x):
-            return (1, x.shape[1])
-        target_dtype = torch.int8
-        eps = torch.finfo(torch.float32).eps
-        zero_point_dtype = torch.int64
-
-        # input settings
-        def get_per_token_block_size(x):
-            block_size = list(x.shape)
-            for i in range(len(block_size)-1):
-                block_size[i] = 1
-            return block_size
-
-        input_mapping_type = MappingType.SYMMETRIC
-        input_target_dtype = torch.int8
-        input_eps = 1e-5
-        input_quant_min = -127
-        input_quant_max = 127
-        input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
-
-        block_size = get_weight_block_size(weight)
-        weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, extended_layout="semi_sparse_cusparselt")
-        weight = to_linear_act_quantized(weight, input_quant_func)
-        return weight
-
-    return apply_int8_dynamic_activation_int8_2x4_sparse_weight_quant
diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py
index 8f8ca24a39..cc5172ac2f 100644
--- a/torchao/sparsity/sparse_api.py
+++ b/torchao/sparsity/sparse_api.py
@@ -7,6 +7,7 @@
     _is_linear,
     _replace_with_custom_fn_if_matches_filter,
     _get_linear_subclass_inserter,
+    int8_dynamic_activation_int8_2x4_sparse_weight,
 )
 
 # Sparsity helper functions
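After this move the sparse variant is just the dense int8-dynamic path with a different layout argument, bound via `functools.partial`. Usage then mirrors the eval script above; a sketch, assuming a CUDA model whose linear weights have already been pruned to a 2:4 pattern (e.g. with `apply_fake_sparsity`):

    import torch
    from torchao.quantization import quantize_, int8_dynamic_activation_int8_2x4_sparse_weight
    from torchao.sparsity import apply_fake_sparsity

    model = torch.nn.Sequential(torch.nn.Linear(128, 256)).cuda().to(torch.bfloat16)
    apply_fake_sparsity(model)   # prune weights to a 2:4 pattern first
    quantize_(model, int8_dynamic_activation_int8_2x4_sparse_weight())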
From fd262100a2232335ef49a9ab062872cf586f6193 Mon Sep 17 00:00:00 2001
From: Jesse Cai
Date: Thu, 11 Jul 2024 15:09:07 -0700
Subject: [PATCH 05/20] update results

---
 scripts/sam/benchmark.sh                  |  3 +--
 scripts/sam/eval_combo.py                 | 10 ++++----
 scripts/sam/results.csv                   | 31 ++++-------------------
 torchao/dtypes/affine_quantized_tensor.py | 18 ++++++-------
 torchao/quantization/quant_api.py         |  1 +
 5 files changed, 20 insertions(+), 43 deletions(-)

diff --git a/scripts/sam/benchmark.sh b/scripts/sam/benchmark.sh
index 54bc8ef0c0..512ecbb085 100755
--- a/scripts/sam/benchmark.sh
+++ b/scripts/sam/benchmark.sh
@@ -1,11 +1,10 @@
 # baseline
 # python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --print_header True
-# int8 dynamic quant (all)
+# # int8 dynamic quant (all)
 # python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant
 # 2:4 sparsity (all)
 # python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse_mlp_only
 # 2:4 sparsity (mlp only)
 # python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse
 # int8 dynamic quant + 2:4 sparsity (attn: int8, mlp lin1: int8+2:4 fuse mul, mlp lin2: 2:4 sparse)
-# python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant_sparse_other
 python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant_sparse
diff --git a/scripts/sam/eval_combo.py b/scripts/sam/eval_combo.py
index 7c7b25457b..b8eb258729 100644
--- a/scripts/sam/eval_combo.py
+++ b/scripts/sam/eval_combo.py
@@ -286,17 +286,17 @@ def run(
     elif compress == "sparse_mlp_only":
         def mlp_only(mod, name):
             return isinstance(mod, torch.nn.Linear) and 'mlp' in name
-        from torchao.sparsity import sparsify
-        from torch.sparse import to_sparse_semi_structured, apply_fake_sparsity
+        from torchao.sparsity import sparsify, apply_fake_sparsity
+        from torch.sparse import to_sparse_semi_structured
         apply_fake_sparsity(predictor.model.image_encoder, filter_fn=mlp_only)
         sparsify(predictor.model.image_encoder, to_sparse_semi_structured, filter_fn=mlp_only)
     elif compress == "sparse":
-        from torchao.sparsity import sparsify
-        from torch.sparse import to_sparse_semi_structured, apply_fake_sparsity
+        from torchao.sparsity import sparsify, apply_fake_sparsity
+        from torch.sparse import to_sparse_semi_structured
         apply_fake_sparsity(predictor.model.image_encoder)
         sparsify(predictor.model.image_encoder, to_sparse_semi_structured)
     elif compress == "int8_dynamic_quant_sparse":
-        from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
+        from torch.sparse import to_sparse_semi_structured
         from torchao.sparsity import sparsify, apply_fake_sparsity
         from torchao.quantization import (
             quantize_,
diff --git a/scripts/sam/results.csv b/scripts/sam/results.csv
index 92f20b948d..0be02c7f37 100644
--- a/scripts/sam/results.csv
+++ b/scripts/sam/results.csv
@@ -1,27 +1,6 @@
 device,sam_model_type,batch_size,memory(MiB),memory(%),img_s(avg),batch_ms(avg)/batch_size,mIoU,use_compile,use_half,compress,use_compile_decoder,use_rel_pos,pad_input_image_batch,num_workers,num_batches,num_images,profile_path,memory_path
-cuda,vit_h,32,15172,18,22.74609667033727,43.96358700541707,0.5811068585673369,max-autotune,torch.bfloat16,None,False,True,True,32,154,4928,None,None
-cuda,vit_h,32,15154,18,24.908711866303545,40.14659631407106,0.5822020528694204,max-autotune,torch.bfloat16,int8_dynamic_quant,False,True,True,32,154,4928,None,None
-cuda,vit_h,32,15632,19,24.806623549763994,40.311814221468836,0.5671732654673084,max-autotune,torch.bfloat16,sparse_mlp_only,False,True,True,32,154,4928,None,None
-cuda,vit_h,32,13429,16,24.299052218005198,41.15386851422198,0.5305645705002248,max-autotune,torch.bfloat16,sparse,False,True,True,32,154,4928,None,None
-cuda,vit_h,32,14865,18,26.46342281926203,37.7880067453756,0.5668329259098808,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
-cuda,vit_h,32,15203,18,24.104394829823185,41.48621058773685,0.567190438968895,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
-cuda,vit_h,32,15203,18,24.329631814054935,41.10214275508732,0.567190438968895,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
-
-cuda,vit_h,32,15211,18,24.56229782147879,40.71280330806584,0.5671033117430888,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
-cuda,vit_h,32,15212,18,22.001439043708203,45.45157241821289,0.5719013796228187,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,2,64,None,None
-cuda,vit_h,32,14870,18,23.586484684427163,42.39716148376465,0.5728743877698744,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,2,64,None,None
-cuda,vit_h,32,14870,18,26.080066556091367,38.343460429798476,0.567100022987887,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
-cuda,vit_h,32,14870,18,24.105935759314356,41.483558654785156,0.5728743877698744,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,2,64,None,None
-cuda,vit_h,32,14870,18,26.077327811425214,38.347487412489855,0.567100022987887,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
-cuda,vit_h,32,14869,18,25.878415158662072,38.64224272888976,0.567177072199177,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
-cuda,vit_h,32,14868,18,26.405483383382357,37.870921939998475,0.567177072199177,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse_other,False,True,True,32,154,4928,None,None
-cuda,vit_h,32,14869,18,26.405490913715926,37.87091113994648,0.5669944616184625,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
-cuda,vit_h,32,14869,18,26.638093985525533,37.54022343127758,0.06075761447241241,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
-cuda,vit_h,32,14869,18,26.308018244576076,38.011224969640956,0.5669944616184625,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
-cuda,vit_h,32,12441,15,26.475001207469894,37.77148080801035,0.567177072199177,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
-cuda,vit_h,32,14869,18,25.998585036559202,38.463631716641515,0.567177072199177,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
-cuda,vit_h,32,14869,18,25.5705826577316,39.10743894205465,0.567177072199177,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
-cuda,vit_h,32,14869,18,22.687728699962957,44.076690673828125,0.5769663374200948,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,4,128,None,None
-cuda,vit_h,32,14869,18,26.326282787092705,37.98485369496531,0.567177072199177,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
-cuda,vit_h,32,14869,18,26.364732130607624,37.929457998895025,0.567177072199177,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
-cuda,vit_h,32,14869,18,26.488956903059197,37.75158091953823,0.567177072199177,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,15172,18,22.533401716616083,44.37856354651513,0.5812715827356921,max-autotune,torch.bfloat16,None,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,15154,18,25.16516896830006,39.73746416166231,0.5818834536577897,max-autotune,torch.bfloat16,int8_dynamic_quant,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,15632,19,24.824717871078573,40.282431614863405,0.5675837487618974,max-autotune,torch.bfloat16,sparse_mlp_only,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,13429,16,24.589577947798148,40.66763578142439,0.5306639662569573,max-autotune,torch.bfloat16,sparse,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,14869,18,26.597207143088742,37.597932543073384,0.5669944616184625,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index aa4a845e2e..429de8ee6e 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -637,7 +637,6 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
             if bias is not None:
                 y += bias
             return y
-
         # handle int8 dynamic_quant + semi_structured_sparse
         elif(
             is_cuda and
@@ -753,15 +752,14 @@ def functional_linear(*args, **kwargs):
     # using try/except here so that we can have a general fallback when input_tensor/weight_tensor
     # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to
     # make the branches easier to understand in `_quantized_linear_op`
-    return _quantized_linear_op(input_tensor, weight_tensor, bias)
-    # try:
-    # except:
-    #     pass
-    # # if isinstance(input_tensor, AffineQuantizedTensor):
-    # #     input_tensor = input_tensor.dequantize()
-    # # if isinstance(weight_tensor, AffineQuantizedTensor):
-    # #     weight_tensor = weight_tensor.dequantize()
-    # # return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
+    try:
+        return _quantized_linear_op(input_tensor, weight_tensor, bias)
+    except:
+        if isinstance(input_tensor, AffineQuantizedTensor):
+            input_tensor = input_tensor.dequantize()
+        if isinstance(weight_tensor, AffineQuantizedTensor):
+            weight_tensor = weight_tensor.dequantize()
+        return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 1b69912b4e..349b3568d7 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -58,6 +58,7 @@
     "quantize_",
     "int8_dynamic_activation_int4_weight",
     "int8_dynamic_activation_int8_weight",
+    "int8_dynamic_activation_int8_2x4_sparse_weight",
     "int4_weight_only",
     "int8_weight_only",
 ]
From ae8b20692e831324bcae04c952c4c0f8f0ac1b49 Mon Sep 17 00:00:00 2001
From: Jesse Cai
Date: Thu, 11 Jul 2024 15:14:41 -0700
Subject: [PATCH 06/20] rename sparsify_

---
 test/sparsity/test_sparse_api.py | 7 +++----
 torchao/sparsity/__init__.py     | 4 ++--
 torchao/sparsity/sparse_api.py   | 8 ++++----
 3 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py
index 3e566732bb..fcd0384805 100644
--- a/test/sparsity/test_sparse_api.py
+++ b/test/sparsity/test_sparse_api.py
@@ -5,8 +5,7 @@
 from torch import nn
 from torch.sparse import to_sparse_semi_structured
 
-from torchao.sparsity import apply_fake_sparsity, sparsify
-from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight
+from torchao.sparsity import apply_fake_sparsity, sparsify_, int8_dynamic_activation_int8_2x4_sparse_weight
 from torchao.quantization.quant_api import (
     _replace_with_custom_fn_if_matches_filter,
     _get_subclass_inserter,
@@ -38,7 +37,7 @@ def test_sparse(self):
         apply_fake_sparsity(model)
         dense_result = model(input)
 
-        model = sparsify(model, to_sparse_semi_structured)
+        model = sparsify_(model, to_sparse_semi_structured)
         sparse_result = model(input)
 
         assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
@@ -62,7 +61,7 @@ def test_quant_semi_sparse(self):
         apply_fake_sparsity(model)
         dense_result = model(input)
 
-        sparsify(model, int8_dynamic_activation_int8_2x4_sparse_weight())
+        sparsify_(model, int8_dynamic_activation_int8_2x4_sparse_weight())
         sparse_result = model(input)
 
         assert torch.allclose(dense_result, sparse_result, rtol=1e-1, atol=1e-1)
diff --git a/torchao/sparsity/__init__.py b/torchao/sparsity/__init__.py
index 8acb712ed9..caecb1d55c 100644
--- a/torchao/sparsity/__init__.py
+++ b/torchao/sparsity/__init__.py
@@ -6,12 +6,12 @@
 from .wanda import WandaSparsifier  # noqa: F403
 from .utils import PerChannelNormObserver  # noqa: F403
-from .sparse_api import apply_fake_sparsity, sparsify, int8_dynamic_activation_int8_2x4_sparse_weight
+from .sparse_api import apply_fake_sparsity, sparsify_, int8_dynamic_activation_int8_2x4_sparse_weight
 
 __all__ = [
     "WandaSparsifier",
     "PerChannelNormObserver",
     "apply_fake_sparsity",
-    "sparsify"
+    "sparsify_"
     "int8_dynamic_activation_int8_2x4_sparse_weight"
 ]
diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py
index cc5172ac2f..f6d1195cbe 100644
--- a/torchao/sparsity/sparse_api.py
+++ b/torchao/sparsity/sparse_api.py
@@ -31,7 +31,7 @@ def apply_fake_sparsity(model, **kwargs):
     sparsifier.squash_mask()
 
 
-def sparsify(model: torch.nn.Module,
+def sparsify_(model: torch.nn.Module,
              apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor],
              filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None) -> torch.nn.Module:
     """Convert the weight of linear modules in the model with `apply_tensor_subclass`
@@ -50,7 +50,7 @@ def sparsify(model: torch.nn.Module,
     Example::
         import torch
         import torch.nn as nn
-        from torchao.sparsity import sparsify
+        from torchao.sparsity import sparsify_
 
         def filter_fn(module: nn.Module, fqn: str) -> bool:
             return isinstance(module, nn.Linear)
 
         m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
 
         # for 2:4 sparsity
         from torch.sparse import to_sparse_semi_structured
-        m = sparsify(m, to_sparse_semi_structured, filter_fn)
+        m = sparsify_(m, to_sparse_semi_structured, filter_fn)
 
         # for int8 dynamic quantization + 2:4 sparsity
         from torchao.sparsity.prototype import int8_dynamic_activation_int8_2x4_sparse_weight
-        m = sparsify(m, int8_dynamic_activation_int8_2x4_sparse_weight(), filter_fn)
+        m = sparsify_(m, int8_dynamic_activation_int8_2x4_sparse_weight(), filter_fn)
     """
     _replace_with_custom_fn_if_matches_filter(
         model,
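The trailing underscore follows the `quantize_` convention and signals that the model is mutated in place. Post-rename usage, following the updated test above; this sketch assumes a CUDA device, since the semi-structured kernels are GPU-only:

    import torch
    from torch.sparse import to_sparse_semi_structured
    from torchao.sparsity import apply_fake_sparsity, sparsify_

    model = torch.nn.Sequential(torch.nn.Linear(128, 256)).half().cuda()
    apply_fake_sparsity(model)                   # induce a 2:4 pattern first
    sparsify_(model, to_sparse_semi_structured)  # swap weights for sparse tensors
    out = model(torch.rand(64, 128, device="cuda").half())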
torchao.quantization.quant_api import ( _replace_with_custom_fn_if_matches_filter, _get_subclass_inserter, @@ -38,7 +37,7 @@ def test_sparse(self): apply_fake_sparsity(model) dense_result = model(input) -        model = sparsify(model, to_sparse_semi_structured) +        model = sparsify_(model, to_sparse_semi_structured) sparse_result = model(input) assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3) @@ -62,7 +61,7 @@ def test_quant_semi_sparse(self): apply_fake_sparsity(model) dense_result = model(input) -        sparsify(model, int8_dynamic_activation_int8_2x4_sparse_weight()) +        sparsify_(model, int8_dynamic_activation_int8_2x4_sparse_weight()) sparse_result = model(input) assert torch.allclose(dense_result, sparse_result, rtol=1e-1, atol=1e-1) diff --git a/torchao/sparsity/__init__.py b/torchao/sparsity/__init__.py index 8acb712ed9..caecb1d55c 100644 --- a/torchao/sparsity/__init__.py +++ b/torchao/sparsity/__init__.py @@ -6,12 +6,12 @@ from .wanda import WandaSparsifier # noqa: F403 from .utils import PerChannelNormObserver # noqa: F403 -from .sparse_api import apply_fake_sparsity, sparsify, int8_dynamic_activation_int8_2x4_sparse_weight +from .sparse_api import apply_fake_sparsity, sparsify_, int8_dynamic_activation_int8_2x4_sparse_weight __all__ = [ "WandaSparsifier", "PerChannelNormObserver", "apply_fake_sparsity", -    "sparsify" +    "sparsify_", "int8_dynamic_activation_int8_2x4_sparse_weight" ] diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py index cc5172ac2f..f6d1195cbe 100644 --- a/torchao/sparsity/sparse_api.py +++ b/torchao/sparsity/sparse_api.py @@ -31,7 +31,7 @@ def apply_fake_sparsity(model, **kwargs): sparsifier.squash_mask() -def sparsify(model: torch.nn.Module, +def sparsify_(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None) -> torch.nn.Module: """Convert the weight of linear modules in the model with `apply_tensor_subclass` @@ -50,7 +50,7 @@ def sparsify_(model: torch.nn.Module, Example:: import torch import torch.nn as nn - from torchao.sparsity import sparsify + from torchao.sparsity import sparsify_ def filter_fn(module: nn.Module, fqn: str) -> bool: return isinstance(module, nn.Linear) @@ -59,11 +59,11 @@ def filter_fn(module: nn.Module, fqn: str) -> bool: # for 2:4 sparsity from torch.sparse import to_sparse_semi_structured - m = sparsify(m, to_sparse_semi_structured, filter_fn) + m = sparsify_(m, to_sparse_semi_structured, filter_fn) # for int8 dynamic quantization + 2:4 sparsity from torchao.sparsity.prototype import int8_dynamic_activation_int8_2x4_sparse_weight - m = sparsify(m, int8_dynamic_activation_int8_2x4_sparse_weight(), filter_fn) + m = sparsify_(m, int8_dynamic_activation_int8_2x4_sparse_weight(), filter_fn) """ _replace_with_custom_fn_if_matches_filter( model, From b13c6f4effa50d306ee002d92840732579977c7b Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 11 Jul 2024 15:24:04 -0700 Subject: [PATCH 07/20] undo benchmark script --- scripts/sam/benchmark.sh | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/scripts/sam/benchmark.sh b/scripts/sam/benchmark.sh index 512ecbb085..c52ce33151 100755 --- a/scripts/sam/benchmark.sh +++ b/scripts/sam/benchmark.sh @@ -1,10 +1,10 @@ # baseline -# python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir 
tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --print_header True +python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --print_header True # int8 dynamic quant (all) -# python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant +python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant # 2:4 sparsity (mlp only) -# python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse_mlp_only +python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse_mlp_only # 2:4 sparsity (all) -# python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse +python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse # int8 dynamic quant + 2:4 sparsity (attn: int8, mlp lin1: int8+2:4 fuse mul, mlp lin2: 2:4 sparse) python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant_sparse From 514b74cd436143fd962ae14403fd1b67b09ede88 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 11 Jul 2024 15:28:01 -0700 Subject: [PATCH 08/20] update eval_combo.py --- scripts/sam/eval_combo.py | 28 +++++++++------------------- 1 file changed, 9 insertions(+), 19 
deletions(-) diff --git a/scripts/sam/eval_combo.py b/scripts/sam/eval_combo.py index b8eb258729..be9857f629 100644 --- a/scripts/sam/eval_combo.py +++ b/scripts/sam/eval_combo.py @@ -9,6 +9,11 @@ import time import resource +from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight +from torchao.sparsity import sparsify_, apply_fake_sparsity, int8_dynamic_activation_int8_2x4_sparse_weight +from torchao.utils import unwrap_tensor_subclass +from torch.sparse import to_sparse_semi_structured + torch._dynamo.config.cache_size_limit = 50000 def unbind_jagged(device, data, sizes, offsets): @@ -279,32 +284,17 @@ def run( block.attn.use_rel_pos = use_rel_pos if compress == "int8_dynamic_quant": - from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight - from torchao.utils import unwrap_tensor_subclass quantize_(predictor.model.image_encoder, int8_dynamic_activation_int8_weight()) predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder) elif compress == "sparse_mlp_only": def mlp_only(mod, name): return isinstance(mod, torch.nn.Linear) and 'mlp' in name - from torchao.sparsity import sparsify, apply_fake_sparsity - from torch.sparse import to_sparse_semi_structured apply_fake_sparsity(predictor.model.image_encoder, filter_fn=mlp_only) sparsify(predictor.model.image_encoder, to_sparse_semi_structured, filter_fn=mlp_only) elif compress == "sparse": - from torchao.sparsity import sparsify, apply_fake_sparsity - from torch.sparse import to_sparse_semi_structured apply_fake_sparsity(predictor.model.image_encoder) - sparsify(predictor.model.image_encoder, to_sparse_semi_structured) + sparsify_(predictor.model.image_encoder, to_sparse_semi_structured) elif compress == "int8_dynamic_quant_sparse": - from torch.sparse import to_sparse_semi_structured - from torchao.sparsity import sparsify, apply_fake_sparsity - from torchao.quantization import ( - quantize_, - int8_dynamic_activation_int8_weight, - int8_dynamic_activation_int8_2x4_sparse_weight, - ) - from torchao.utils import unwrap_tensor_subclass - def attn_only(mod, name): return isinstance(mod, torch.nn.Linear) and 'attn' in name def mlp_lin1_only(mod, name): @@ -327,9 +317,9 @@ def mlp_only(mod, name): predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder) - predictor.model.image_encoder = sparsify(predictor.model.image_encoder, - to_sparse_semi_structured, - mlp_lin2_only) + predictor.model.image_encoder = sparsify_(predictor.model.image_encoder, + to_sparse_semi_structured, + mlp_lin2_only) else: assert compress is None, f"Unsupported compress mode {compress}" From f8fb6aac75c030013771821996573dd77b35e83c Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 11 Jul 2024 15:36:58 -0700 Subject: [PATCH 09/20] fix eval --- scripts/sam/eval_combo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sam/eval_combo.py b/scripts/sam/eval_combo.py index be9857f629..5dc14b6e60 100644 --- a/scripts/sam/eval_combo.py +++ b/scripts/sam/eval_combo.py @@ -290,7 +290,7 @@ def run( def mlp_only(mod, name): return isinstance(mod, torch.nn.Linear) and 'mlp' in name apply_fake_sparsity(predictor.model.image_encoder, filter_fn=mlp_only) - sparsify(predictor.model.image_encoder, to_sparse_semi_structured, filter_fn=mlp_only) + sparsify_(predictor.model.image_encoder, to_sparse_semi_structured, filter_fn=mlp_only) elif compress == "sparse": apply_fake_sparsity(predictor.model.image_encoder) sparsify_(predictor.model.image_encoder, 
to_sparse_semi_structured) From 2086394c770fba5cdb688522ae7f8ac3436ae48d Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 18 Jul 2024 12:18:53 -0700 Subject: [PATCH 10/20] update --- README.md | 11 +++++------ test/sparsity/test_sparse_api.py | 23 +++++++++++++++-------- torchao/dtypes/affine_quantized_tensor.py | 6 ++++-- torchao/sparsity/sparse_api.py | 15 +++++++++------ 4 files changed, 33 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index 1dd2a72340..f60f4c5dc7 100644 --- a/README.md +++ b/README.md @@ -49,20 +49,19 @@ And a quick crash course on inference quantization to help parse the above table Sparsifying your model is also a 1 liner that should work on any model with an `nn.Linear`. We find that sparsity works best on compute bound models like SAM, specifically the MLP layers. ```python -from torchao.sparsity import sparsify -from torch.sparse import to_sparse_semi_structured +from torchao.sparsity import sparsify, semi_sparse_weight -m = sparsify(m, to_sparse_semi_structured) +m = sparsify(m, semi_sparse_weight) ``` Sparsity can also be composed with int8 dynamic quantization for further speedups: ```python from torchao.sparsity import sparsify -from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight +from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_semi_sparse_weight -m = sparsify(m, int8_dynamic_activation_int8_2x4_sparse_weight()) +m = sparsify(m, int8_dynamic_activation_int8_semi_sparse_weight()) ``` -We found that applying int8 dynamic quantization to the attention layers, int8 dynamic quantization + 2:4 sparsity to mlp layer 1 and 2:4 sparsity to mlp layer 2 yielded the best configuration. +We found that applying int8 dynamic quantization to the attention layers, int8 dynamic quantization + semi sparse (2:4) sparsity to mlp layer 1 and 2:4 sparsity to mlp layer 2 yielded the best configuration. We were able to provide a **1.16x (22.7 -> 26.5 img/s) speedup over our dense baseline, while maintaining 97.5% (0.581 -> 0.567) of the evaluation accuracy (mIOU)**. 
The following benchmarks were run for [segment-anything-fast](https://github.com/pytorch-labs/segment-anything-fast) ViT-h on an NVIDIA-A100-80GB, with batch_size=32 and `bfloat16` dtype, with `torch.compile="max_autotune"`: diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py index fcd0384805..d831411fb1 100644 --- a/test/sparsity/test_sparse_api.py +++ b/test/sparsity/test_sparse_api.py @@ -1,17 +1,25 @@ import logging import unittest +import copy import torch from torch import nn -from torch.sparse import to_sparse_semi_structured from torchao.sparsity import apply_fake_sparsity, sparsify_, int8_dynamic_activation_int8_2x4_sparse_weight +from torchao.sparsity.sparse_api import semi_sparse_weight +from torchao.utils import unwrap_tensor_subclass from torchao.quantization.quant_api import ( _replace_with_custom_fn_if_matches_filter, _get_subclass_inserter, _is_linear, +    int8_dynamic_activation_int8_weight, +    quantize_, ) -from torchao.utils import TORCH_VERSION_AFTER_2_3 +from torchao.quantization.subclass import ( +    LinearActQuantizedTensor, +) +from torchao.dtypes import AffineQuantizedTensor +from torchao.utils import TORCH_VERSION_AFTER_2_3, unwrap_tensor_subclass from torch.testing._internal.common_utils import TestCase @@ -38,12 +45,11 @@ def test_sparse(self): apply_fake_sparsity(model) dense_result = model(input) -        model = sparsify_(model, to_sparse_semi_structured) +        sparsify_(model, semi_sparse_weight()) sparse_result = model(input) assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3) - class TestQuantSemiSparse(TestCase): @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "pytorch 2.3+ feature") @@ -57,15 +63,15 @@ def test_quant_semi_sparse(self): .half() .cuda() ) -        apply_fake_sparsity(model) -        dense_result = model(input) +        model_copy = copy.deepcopy(model) +        quantize_(model_copy, int8_dynamic_activation_int8_weight()) +        dense_result = model_copy(input) sparsify_(model, int8_dynamic_activation_int8_2x4_sparse_weight()) sparse_result = model(input) -        assert torch.allclose(dense_result, sparse_result, rtol=1e-1, atol=1e-1) - +        assert torch.allclose(dense_result, sparse_result, rtol=1e-2, atol=1e-2) if __name__ == "__main__": unittest.main() diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 4e78bb0d26..4f51004d1a 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -34,6 +34,7 @@ class PlainLayoutType(LayoutType): @dataclass(frozen=True) class SparseLayoutType(LayoutType): + def post_process(self, input: torch.Tensor) -> torch.Tensor: return torch._cslt_compress(input) @@ -498,7 +499,8 @@ def __new__( ) kwargs["dtype"] = int_data.dtype kwargs["requires_grad"] = False -        shape = torch.Size([zero_point.shape[0], int_data.numel() * 16 // (10 * zero_point.shape[0])]) +        shape = torch.Size([zero_point.shape[0], +                            int_data.numel() * 16 // (10 * zero_point.shape[0])]) return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] @classmethod @@ -511,7 +513,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) raise NotImplementedError( -            f"PlainAQTLayout dispatch: attempting to run {func}, this is not supported" +            f"SparseAQTLayout dispatch: attempting to run {func}, this is not supported" ) def get_plain(self): diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py index f6d1195cbe..4088b0ad08 100644 --- a/torchao/sparsity/sparse_api.py +++ 
b/torchao/sparsity/sparse_api.py @@ -30,6 +30,11 @@ def apply_fake_sparsity(model, **kwargs): sparsifier.step() sparsifier.squash_mask() +def semi_sparse_weight(): +    """ +    Convert the weight of linear modules to semi-structured (2:4) sparsity +    """ +    return _get_linear_subclass_inserter(to_sparse_semi_structured) def sparsify_(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], @@ -38,7 +43,7 @@ def sparsify_(model: torch.nn.Module, This function is essentially the same as quantize_, but for sparsity subclasses. Currently, we support two options for sparsity: -        - semi-structured (2:4) sparsity with `to_sparse_semi_structured` +        - semi-structured (2:4) sparsity with `semi_sparse_weight` - int8 dynamic quantization + 2:4 sparsity with `int8_dynamic_activation_int8_2x4_sparse_weight`, which is also available via the quantize API Args: @@ -58,8 +63,8 @@ def filter_fn(module: nn.Module, fqn: str) -> bool: m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32)) # for 2:4 sparsity -        from torch.sparse import to_sparse_semi_structured -        m = sparsify_(m, to_sparse_semi_structured, filter_fn) +        from torchao.sparsity.sparse_api import semi_sparse_weight +        m = sparsify_(m, semi_sparse_weight(), filter_fn) # for int8 dynamic quantization + 2:4 sparsity from torchao.sparsity.prototype import int8_dynamic_activation_int8_2x4_sparse_weight @@ -67,8 +72,6 @@ def filter_fn(module: nn.Module, fqn: str) -> bool: """ _replace_with_custom_fn_if_matches_filter( model, -        _get_linear_subclass_inserter(apply_tensor_subclass), +        apply_tensor_subclass, _is_linear if filter_fn is None else filter_fn, ) - - return model From 5f97b8894d8b6aebb9bd3b330b8a2c1998c8b342 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 18 Jul 2024 12:23:15 -0700 Subject: [PATCH 11/20] update README --- README.md | 7 +++---- scripts/sam/eval_combo.py | 7 +------ 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index f60f4c5dc7..ff3f03294e 100644 --- a/README.md +++ b/README.md @@ -51,15 +51,14 @@ Sparsifying your model is also a 1 liner that should work on any model with an ` ```python from torchao.sparsity import sparsify, semi_sparse_weight -m = sparsify(m, semi_sparse_weight) +m = sparsify_(m, semi_sparse_weight) ``` Sparsity can also be composed with int8 dynamic quantization for further speedups: ```python -from torchao.sparsity import sparsify -from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_semi_sparse_weight +from torchao.sparsity import sparsify_, int8_dynamic_activation_int8_semi_sparse_weight -m = sparsify(m, int8_dynamic_activation_int8_semi_sparse_weight()) +m = sparsify_(m, int8_dynamic_activation_int8_semi_sparse_weight()) ``` We found that applying int8 dynamic quantization to the attention layers, int8 dynamic quantization + semi sparse (2:4) sparsity to mlp layer 1 and 2:4 sparsity to mlp layer 2 yielded the best configuration. We were able to provide a **1.16x (22.7 -> 26.5 img/s) speedup over our dense baseline, while maintaining 97.5% (0.581 -> 0.567) of the evaluation accuracy (mIOU)**. 
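For reference, the mixed attention/MLP recipe described in the README hunk above maps directly onto the filter-function pattern this series uses in eval_combo.py. A minimal sketch, assuming the APIs as they are exported by the end of this series (`quantize_`, `sparsify_`, `semi_sparse_weight`, `int8_dynamic_activation_int8_semi_sparse_weight`) and with `model` standing in for the SAM image encoder:

```python
import torch
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
from torchao.sparsity import sparsify_, semi_sparse_weight, int8_dynamic_activation_int8_semi_sparse_weight

# Filter functions select layers by module type and fully-qualified name,
# mirroring attn_only/mlp_lin1_only/mlp_lin2_only in eval_combo.py.
def attn_only(mod, name):
    return isinstance(mod, torch.nn.Linear) and 'attn' in name

def mlp_lin1_only(mod, name):
    return isinstance(mod, torch.nn.Linear) and 'mlp' in name and 'lin1' in name

def mlp_lin2_only(mod, name):
    return isinstance(mod, torch.nn.Linear) and 'mlp' in name and 'lin2' in name

# attention: int8 dynamic quant only
quantize_(model, int8_dynamic_activation_int8_weight(), attn_only)
# mlp lin1: int8 dynamic quant fused with 2:4 sparsity
quantize_(model, int8_dynamic_activation_int8_semi_sparse_weight(), mlp_lin1_only)
# mlp lin2: 2:4 sparsity only
sparsify_(model, semi_sparse_weight(), mlp_lin2_only)
```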
diff --git a/scripts/sam/eval_combo.py b/scripts/sam/eval_combo.py index 0c692cfb59..f178f7b6fc 100644 --- a/scripts/sam/eval_combo.py +++ b/scripts/sam/eval_combo.py @@ -320,11 +320,6 @@ def mlp_only(mod, name): predictor.model.image_encoder = sparsify_(predictor.model.image_encoder, to_sparse_semi_structured, mlp_lin2_only) - elif compress == "int4_weight_only_quant_sparse": - apply_fake_sparsity(predictor.model.image_encoder) - quantize_(predictor.model.image_encoder, int8_dynamic_activation_int8_weight()) - predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder) - else: assert compress is None, f"Unsupported compress mode {compress}" @@ -377,7 +372,7 @@ def mlp_only(mod, name): batch_size, use_compile, use_compile_decoder, - pad_input_image_batch, + pad_input_image_batch, compress) results = [[r[0], r[1], r[2], r[3].item()] for r in results] From 0971d92ab945d9aa80a37d35472e1ad9a50f3750 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 18 Jul 2024 13:28:04 -0700 Subject: [PATCH 12/20] update READMEs --- README.md | 4 ++-- test/sparsity/test_sparse_api.py | 4 ++-- torchao/dtypes/__init__.py | 4 ++-- torchao/dtypes/affine_quantized_tensor.py | 19 ++++++++++++------- torchao/quantization/quant_api.py | 9 +++++---- torchao/sparsity/__init__.py | 4 ++-- torchao/sparsity/sparse_api.py | 8 ++++---- 7 files changed, 29 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index ff3f03294e..e31dc63a8f 100644 --- a/README.md +++ b/README.md @@ -49,9 +49,9 @@ And a quick crash course on inference quantization to help parse the above table Sparsifying your model is also a 1 liner that should work on any model with an `nn.Linear`. We find that sparsity works best on compute bound models like SAM, specifically the MLP layers. 
```python -from torchao.sparsity import sparsify, semi_sparse_weight +from torchao.sparsity import sparsify_, semi_sparse_weight -m = sparsify_(m, semi_sparse_weight) +m = sparsify_(m, semi_sparse_weight()) ``` Sparsity can also be composed with int8 dynamic quantization for further speedups: diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py index d831411fb1..1329a03e1b 100644 --- a/test/sparsity/test_sparse_api.py +++ b/test/sparsity/test_sparse_api.py @@ -5,7 +5,7 @@ import torch from torch import nn -from torchao.sparsity import apply_fake_sparsity, sparsify_, int8_dynamic_activation_int8_2x4_sparse_weight +from torchao.sparsity import apply_fake_sparsity, sparsify_, int8_dynamic_activation_int8_semi_sparse_weight from torchao.sparsity.sparse_api import semi_sparse_weight from torchao.utils import unwrap_tensor_subclass from torchao.quantization.quant_api import ( @@ -69,7 +69,7 @@ def test_quant_semi_sparse(self): quantize_(model_copy, int8_dynamic_activation_int8_weight()) dense_result = model_copy(input) -        sparsify_(model, int8_dynamic_activation_int8_2x4_sparse_weight()) +        quantize_(model, int8_dynamic_activation_int8_semi_sparse_weight()) sparse_result = model(input) assert torch.allclose(dense_result, sparse_result, rtol=1e-2, atol=1e-2) diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 6c3e1c930e..e4b47b8229 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -7,8 +7,8 @@ to_affine_quantized_static, LayoutType, PlainLayoutType, +    SemiSparseLayoutType, TensorCoreTiledLayoutType, -    SparseAQTLayout ) __all__ = [ @@ -20,6 +20,6 @@ "to_affine_quantized_static", "LayoutType", "PlainLayoutType", +    "SemiSparseLayoutType", "TensorCoreTiledLayoutType", -    "SparseAQTLayout", ] diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 4f51004d1a..ace11c84cf 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -31,14 +31,19 @@ class PlainLayoutType(LayoutType): pass - @dataclass(frozen=True) -class SparseLayoutType(LayoutType): +class SemiSparseLayoutType(LayoutType): + +    def pre_process(self, input: torch.Tensor) -> torch.Tensor: +        # prune to 2:4 if not already +        temp = input.detach() +        pruning_inds = temp.abs().view(-1, 4).argsort(dim=1)[:, :2] +        temp.view(-1, 4).scatter_(1, pruning_inds, value=0) +        return temp def post_process(self, input: torch.Tensor) -> torch.Tensor: return torch._cslt_compress(input) - @dataclass(frozen=True) class TensorCoreTiledLayoutType(LayoutType): inner_k_tiles: int = 8 @@ -480,8 +485,8 @@ def from_plain( assert isinstance(layout_type, PlainLayoutType) return cls(int_data, scale, zero_point, layout_type) -@register_layout_cls(SparseLayoutType) -class SparseAQTLayout(PlainAQTLayout): +@register_layout_cls(SemiSparseLayoutType) +class SemiSparseAQTLayout(PlainAQTLayout): """ Layout storage class for semi_sparse_cusparselt layout for affine quantized tensor """ @@ -531,7 +536,7 @@ def from_plain( zero_point: torch.Tensor, layout_type: LayoutType, ): -        assert isinstance(layout_type, SparseLayoutType) +        assert isinstance(layout_type, SemiSparseLayoutType) return cls(int_data, scale, zero_point, layout_type) @@ -737,7 +742,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias): input_is_int8 and input_tensor.dtype == weight_qtensor.dtype and isinstance(input_tensor.layout_type, PlainLayoutType) and -        isinstance(weight_qtensor.layout_type, SparseLayoutType) + 
isinstance(weight_qtensor.layout_type, SemiSparseLayoutType) ): x_vals_int8 = input_tensor.layout_tensor.int_data x_scales = input_tensor.layout_tensor.scale diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index f39fd84a60..953b105dee 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -58,7 +58,7 @@ "quantize_", "int8_dynamic_activation_int4_weight", "int8_dynamic_activation_int8_weight", -    "int8_dynamic_activation_int8_2x4_sparse_weight", +    "int8_dynamic_activation_int8_semi_sparse_weight", "int4_weight_only", "int8_weight_only", ] @@ -464,7 +464,8 @@ def get_per_token_block_size(x): return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int8_weight_quant) -def int8_dynamic_activation_int8_2x4_sparse_weight(): + +def int8_dynamic_activation_int8_semi_sparse_weight(): """ Applies int8 dynamic symmetric per-token activation and int8 per-channel weight quantization + 2:4 sparsity to linear layers. @@ -477,7 +478,7 @@ def apply_int8_dynamic_activation_int8_weight_quant(weight): # avoid circular dep from torchao.dtypes import to_affine_quantized -        from torchao.dtypes.affine_quantized_tensor import SparseLayoutType +        from torchao.dtypes.affine_quantized_tensor import SemiSparseLayoutType # weight settings mapping_type = MappingType.SYMMETRIC def get_weight_block_size(x): @@ -501,7 +502,7 @@ def get_per_token_block_size(x): input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None) block_size = get_weight_block_size(weight) -        weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, layout_type=SparseLayoutType()) +        weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, layout_type=SemiSparseLayoutType()) weight = to_linear_act_quantized(weight, input_quant_func) return weight diff --git a/torchao/sparsity/__init__.py b/torchao/sparsity/__init__.py index caecb1d55c..4540c8b7c3 100644 --- a/torchao/sparsity/__init__.py +++ b/torchao/sparsity/__init__.py @@ -6,12 +6,12 @@ from .wanda import WandaSparsifier # noqa: F403 from .utils import PerChannelNormObserver # noqa: F403 -from .sparse_api import apply_fake_sparsity, sparsify_, int8_dynamic_activation_int8_2x4_sparse_weight +from .sparse_api import apply_fake_sparsity, sparsify_, int8_dynamic_activation_int8_semi_sparse_weight __all__ = [ "WandaSparsifier", "PerChannelNormObserver", "apply_fake_sparsity", "sparsify_", -    "int8_dynamic_activation_int8_2x4_sparse_weight" +    "int8_dynamic_activation_int8_semi_sparse_weight" ] diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py index 4088b0ad08..a12d954422 100644 --- a/torchao/sparsity/sparse_api.py +++ b/torchao/sparsity/sparse_api.py @@ -7,7 +7,7 @@ _is_linear, _replace_with_custom_fn_if_matches_filter, _get_linear_subclass_inserter, -    int8_dynamic_activation_int8_2x4_sparse_weight, +    int8_dynamic_activation_int8_semi_sparse_weight, ) # Sparsity helper functions @@ -44,7 +44,7 @@ def sparsify_(model: torch.nn.Module, Currently, we support two options for sparsity: - semi-structured (2:4) sparsity with `semi_sparse_weight` -    - int8 
dynamic quantization + 2:4 sparsity with `int8_dynamic_activation_int8_semi_sparse_weight`, which is also available via the quantize API Args: model (torch.nn.Module): input model @@ -67,8 +67,8 @@ def filter_fn(module: nn.Module, fqn: str) -> bool: m = sparsify_(m, semi_sparse_weight(), filter_fn) # for int8 dynamic quantization + 2:4 sparsity - from torchao.sparsity.prototype import int8_dynamic_activation_int8_2x4_sparse_weight - m = sparsify_(m, int8_dynamic_activation_int8_2x4_sparse_weight(), filter_fn) + from torchao.sparsity.prototype import int8_dynamic_activation_int8_semi_sparse_weight + m = sparsify_(m, int8_dynamic_activation_int8_semi_sparse_weight(), filter_fn) """ _replace_with_custom_fn_if_matches_filter( model, From e7608cf1efc333dd3dda7ddfc14a329ffd1b708f Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 18 Jul 2024 13:29:55 -0700 Subject: [PATCH 13/20] update --- torchao/quantization/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 46436f6777..6bf37f0080 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -32,7 +32,7 @@ "quantize_", "int8_dynamic_activation_int4_weight", "int8_dynamic_activation_int8_weight", - "int8_dynamic_activation_int8_2x4_sparse_weight", + "int8_dynamic_activation_int8_semi_sparse_weight", "int4_weight_only", "int8_weight_only", ] From 1d3b2cd117718daed0bdd32f58fb894a96b9d0a9 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 18 Jul 2024 13:42:18 -0700 Subject: [PATCH 14/20] update eval_combo --- scripts/sam/eval_combo.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/scripts/sam/eval_combo.py b/scripts/sam/eval_combo.py index f178f7b6fc..4be0109670 100644 --- a/scripts/sam/eval_combo.py +++ b/scripts/sam/eval_combo.py @@ -10,9 +10,8 @@ import resource from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight, int4_weight_only -from torchao.sparsity import sparsify_, apply_fake_sparsity, int8_dynamic_activation_int8_2x4_sparse_weight +from torchao.sparsity import sparsify_, apply_fake_sparsity, int8_dynamic_activation_int8_semi_sparse_weight, semi_sparse_weight from torchao.utils import unwrap_tensor_subclass -from torch.sparse import to_sparse_semi_structured torch._dynamo.config.cache_size_limit = 50000 @@ -290,10 +289,10 @@ def run( def mlp_only(mod, name): return isinstance(mod, torch.nn.Linear) and 'mlp' in name apply_fake_sparsity(predictor.model.image_encoder, filter_fn=mlp_only) - sparsify_(predictor.model.image_encoder, to_sparse_semi_structured, filter_fn=mlp_only) + sparsify_(predictor.model.image_encoder, semi_sparse_weight(), filter_fn=mlp_only) elif compress == "sparse": apply_fake_sparsity(predictor.model.image_encoder) - sparsify_(predictor.model.image_encoder, to_sparse_semi_structured) + sparsify_(predictor.model.image_encoder, semi_sparse_weight()) elif compress == "int8_dynamic_quant_sparse": def attn_only(mod, name): return isinstance(mod, torch.nn.Linear) and 'attn' in name @@ -318,7 +317,7 @@ def mlp_only(mod, name): predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder) predictor.model.image_encoder = sparsify_(predictor.model.image_encoder, - to_sparse_semi_structured, + semi_sparse_weight(), mlp_lin2_only) else: assert compress is None, f"Unsupported compress mode {compress}" From 0d5907c2c81b08ea93c23104a88ef0f9c0848bf7 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 18 Jul 2024 13:52:05 -0700 
Subject: [PATCH 15/20] update --- torchao/quantization/quant_api.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 953b105dee..92cefd8f25 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -14,8 +14,6 @@ come along with it and because that is how we access the intended quantized and mixed GEMM kernels """ -from functools import partial - import torch import torchao import torch.nn as nn @@ -421,7 +419,6 @@ def apply_int8wo_quant(weight): return _get_linear_subclass_inserter(apply_int8wo_quant) - def int8_dynamic_activation_int8_weight(): """ Applies int8 dynamic symmetric per-token activation and int8 per-channel weight @@ -470,15 +467,14 @@ def int8_dynamic_activation_int8_semi_sparse_weight(): Applies int8 dnynamic symmetric per-token activation and int8 per-channel weight quantization + 2:4 sparsity to linear layers. """ - def apply_int8_dynamic_activation_int8_weight_quant(weight): + def apply_int8_dynamic_activation_int8_semi_sparse_weight_quant(weight): in_features = weight.shape[1] # int8 dynamic quantization only has benefit when in_feature > 16 if in_features <= 16: return weight # avoid circular dep - from torchao.dtypes import to_affine_quantized - from torchao.dtypes.affine_quantized_tensor import SemiSparseLayoutType + from torchao.dtypes import to_affine_quantized, SemiSparseLayoutType # weight settings mapping_type = MappingType.SYMMETRIC def get_weight_block_size(x): @@ -502,8 +498,9 @@ def get_per_token_block_size(x): input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None) block_size = get_weight_block_size(weight) - weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, layout_type=SemiSparseLayoutType()) + layout_type = SemiSparseLayoutType() + weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, layout_type=layout_type) weight = to_linear_act_quantized(weight, input_quant_func) return weight - return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int8_weight_quant) + return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int8_semi_sparse_weight_quant) From c1797cb870c3ed104e68d86912054c902f216d58 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 18 Jul 2024 13:55:51 -0700 Subject: [PATCH 16/20] update init --- test/sparsity/test_sparse_api.py | 15 +++++++-------- torchao/sparsity/__init__.py | 8 +++++++- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py index 1329a03e1b..b846afa454 100644 --- a/test/sparsity/test_sparse_api.py +++ b/test/sparsity/test_sparse_api.py @@ -1,13 +1,16 @@ +import copy import logging import unittest -import copy import torch from torch import nn -from torchao.sparsity import apply_fake_sparsity, sparsify_, int8_dynamic_activation_int8_semi_sparse_weight -from torchao.sparsity.sparse_api import semi_sparse_weight -from torchao.utils import unwrap_tensor_subclass +from torchao.sparsity import ( + apply_fake_sparsity, + sparsify_, + int8_dynamic_activation_int8_semi_sparse_weight, + semi_sparse_weight, +) from torchao.quantization.quant_api import ( 
_replace_with_custom_fn_if_matches_filter, _get_subclass_inserter, @@ -15,10 +18,6 @@ int8_dynamic_activation_int8_weight, quantize_, ) -from torchao.quantization.subclass import ( -    LinearActQuantizedTensor, -) -from torchao.dtypes import AffineQuantizedTensor from torchao.utils import TORCH_VERSION_AFTER_2_3, unwrap_tensor_subclass from torch.testing._internal.common_utils import TestCase diff --git a/torchao/sparsity/__init__.py b/torchao/sparsity/__init__.py index 4540c8b7c3..c3b10f949a 100644 --- a/torchao/sparsity/__init__.py +++ b/torchao/sparsity/__init__.py @@ -6,12 +6,18 @@ from .wanda import WandaSparsifier # noqa: F403 from .utils import PerChannelNormObserver # noqa: F403 -from .sparse_api import apply_fake_sparsity, sparsify_, int8_dynamic_activation_int8_semi_sparse_weight +from .sparse_api import ( +    apply_fake_sparsity, +    sparsify_, +    semi_sparse_weight, +    int8_dynamic_activation_int8_semi_sparse_weight +) __all__ = [ "WandaSparsifier", "PerChannelNormObserver", "apply_fake_sparsity", "sparsify_", +    "semi_sparse_weight", "int8_dynamic_activation_int8_semi_sparse_weight" ] From 17f0ea1c9487a69b6538c83b5f816acd913d3f0b Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Wed, 24 Jul 2024 06:59:30 -0700 Subject: [PATCH 17/20] update --- scripts/sam/eval_combo.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/scripts/sam/eval_combo.py b/scripts/sam/eval_combo.py index 4be0109670..b9733bd98b 100644 --- a/scripts/sam/eval_combo.py +++ b/scripts/sam/eval_combo.py @@ -311,14 +311,13 @@ def mlp_only(mod, name): int8_dynamic_activation_int8_weight(), attn_only) quantize_(predictor.model.image_encoder, -                  int8_dynamic_activation_int8_2x4_sparse_weight(), +                  int8_dynamic_activation_int8_semi_sparse_weight(), mlp_lin1_only) -         +        sparsify_(predictor.model.image_encoder, +                  semi_sparse_weight(), +                  mlp_lin2_only) predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder) -        predictor.model.image_encoder = sparsify_(predictor.model.image_encoder, -                                                  semi_sparse_weight(), -                                                  mlp_lin2_only) else: assert compress is None, f"Unsupported compress mode {compress}" From dc0ab1641bb61d2c4e0a6484d0f89ffc4e670b26 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Wed, 24 Jul 2024 07:18:11 -0700 Subject: [PATCH 18/20] update --- torchao/quantization/quant_api.py | 119 ++++++++++++------------------ 1 file changed, 46 insertions(+), 73 deletions(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 92cefd8f25..e2e6874b31 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -14,12 +14,14 @@ come along with it and because that is how we access the intended quantized and mixed GEMM kernels """ +from functools import partial import torch import torchao import torch.nn as nn import torch.nn.functional as F from typing import Any, Callable, Union, Dict, Optional +from torchao.dtypes.utils import LayoutType from torchao.utils import ( TORCH_VERSION_AFTER_2_4, unwrap_tensor_subclass, @@ -419,47 +421,52 @@ def apply_int8wo_quant(weight): return _get_linear_subclass_inserter(apply_int8wo_quant) +def _apply_int8_dynamic_activation_int8_weight_quant(weight, layout_type : LayoutType): +    """ +    Helper function for int8 dynamic activation w/wo 2:4 sparsity. 
+ """ + in_features = weight.shape[1] + # int8 dynamic quantization only has benefit when in_feature > 16 + if in_features <= 16: + return weight + + # avoid circular dep + from torchao.dtypes import to_affine_quantized + # weight settings + mapping_type = MappingType.SYMMETRIC + def get_weight_block_size(x): + return (1, x.shape[1]) + target_dtype = torch.int8 + eps = torch.finfo(torch.float32).eps + zero_point_dtype = torch.int64 + + # input settings + def get_per_token_block_size(x): + block_size = list(x.shape) + for i in range(len(block_size)-1): + block_size[i] = 1 + return block_size + + input_mapping_type = MappingType.SYMMETRIC + input_target_dtype = torch.int8 + input_eps = 1e-5 + input_quant_min = -127 + input_quant_max = 127 + input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None) + + block_size = get_weight_block_size(weight) + weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, layout_type=layout_type) + weight = to_linear_act_quantized(weight, input_quant_func) + return weight + def int8_dynamic_activation_int8_weight(): """ Applies int8 dynamic symmetric per-token activation and int8 per-channel weight quantization to linear layers """ - def apply_int8_dynamic_activation_int8_weight_quant(weight): - in_features = weight.shape[1] - # int8 dynamic quantization only has benefit when in_feature > 16 - if in_features <= 16: - return weight - - # avoid circular dep - from torchao.dtypes import to_affine_quantized - # weight settings - mapping_type = MappingType.SYMMETRIC - def get_weight_block_size(x): - return (1, x.shape[1]) - target_dtype = torch.int8 - eps = torch.finfo(torch.float32).eps - zero_point_dtype = torch.int64 - - # input settings - def get_per_token_block_size(x): - block_size = list(x.shape) - for i in range(len(block_size)-1): - block_size[i] = 1 - return block_size - - input_mapping_type = MappingType.SYMMETRIC - input_target_dtype = torch.int8 - input_eps = 1e-5 - input_quant_min = -127 - input_quant_max = 127 - input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None) - - block_size = get_weight_block_size(weight) - weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype) - weight = to_linear_act_quantized(weight, input_quant_func) - return weight - - return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int8_weight_quant) + from torchao.dtypes import PlainLayoutType + _apply_int8_dynamic_activation_int8_weight_quant_layout = partial(_apply_int8_dynamic_activation_int8_weight_quant, layout_type=PlainLayoutType()) + return _get_linear_subclass_inserter(_apply_int8_dynamic_activation_int8_weight_quant_layout) def int8_dynamic_activation_int8_semi_sparse_weight(): @@ -467,40 +474,6 @@ def int8_dynamic_activation_int8_semi_sparse_weight(): Applies int8 dnynamic symmetric per-token activation and int8 per-channel weight quantization + 2:4 sparsity to linear layers. 
""" - def apply_int8_dynamic_activation_int8_semi_sparse_weight_quant(weight): - in_features = weight.shape[1] - # int8 dynamic quantization only has benefit when in_feature > 16 - if in_features <= 16: - return weight - - # avoid circular dep - from torchao.dtypes import to_affine_quantized, SemiSparseLayoutType - # weight settings - mapping_type = MappingType.SYMMETRIC - def get_weight_block_size(x): - return (1, x.shape[1]) - target_dtype = torch.int8 - eps = torch.finfo(torch.float32).eps - zero_point_dtype = torch.int64 - - # input settings - def get_per_token_block_size(x): - block_size = list(x.shape) - for i in range(len(block_size)-1): - block_size[i] = 1 - return block_size - - input_mapping_type = MappingType.SYMMETRIC - input_target_dtype = torch.int8 - input_eps = 1e-5 - input_quant_min = -127 - input_quant_max = 127 - input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None) - - block_size = get_weight_block_size(weight) - layout_type = SemiSparseLayoutType() - weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, layout_type=layout_type) - weight = to_linear_act_quantized(weight, input_quant_func) - return weight - - return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int8_semi_sparse_weight_quant) + from torchao.dtypes import SemiSparseLayoutType + _apply_int8_dynamic_activation_int8_weight_quant_layout = partial(_apply_int8_dynamic_activation_int8_weight_quant, layout_type=SemiSparseLayoutType()) + return _get_linear_subclass_inserter(_apply_int8_dynamic_activation_int8_weight_quant_layout) From c5444497d728fa1e97418ca25447a91734b7b9c4 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Wed, 24 Jul 2024 07:58:11 -0700 Subject: [PATCH 19/20] updated for packed shape --- torchao/dtypes/affine_quantized_tensor.py | 28 +++++------------------ torchao/quantization/quant_api.py | 3 ++- 2 files changed, 8 insertions(+), 23 deletions(-) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index ace11c84cf..807a588aed 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -41,8 +41,6 @@ def pre_process(self, input: torch.Tensor) -> torch.Tensor: temp.view(-1, 4).scatter_(1, pruning_inds, value=0) return temp - def post_process(self, input: torch.Tensor) -> torch.Tensor: - return torch._cslt_compress(input) @dataclass(frozen=True) class TensorCoreTiledLayoutType(LayoutType): @@ -490,24 +488,6 @@ class SemiSparseAQTLayout(PlainAQTLayout): """ Layout storage class for semi_sparse_cusparselt layout for affine quantized tensor """ - def __new__( - cls, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - layout_type: LayoutType, - ): - kwargs = {} - kwargs["device"] = int_data.device - kwargs["layout"] = ( - kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout - ) - kwargs["dtype"] = int_data.dtype - kwargs["requires_grad"] = False - shape = torch.Size([zero_point.shape[0], - int_data.numel() * 16 // (10 * zero_point.shape[0])]) - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): kwargs = {} if kwargs is None else kwargs @@ -522,8 +502,11 @@ def 
__torch_dispatch__(cls, func, types, args, kwargs): ) def get_plain(self): + # Currently we don't have cuSPARSELt expansion routines, so we matmul by + # the identity matrix to get the original dense matrix. This is slow though. + cols = self.int_data.numel() * 16 // (10 * self.scale.shape[0]) int_data_expanded = torch._cslt_sparse_mm(self.int_data, - torch.eye(self.shape[1], + torch.eye(cols, dtype=self.int_data.dtype, device=self.int_data.device).t()) return int_data_expanded, self.scale, self.zero_point @@ -537,7 +520,8 @@ def from_plain( layout_type: LayoutType, ): assert isinstance(layout_type, SemiSparseLayoutType) - return cls(int_data, scale, zero_point, layout_type) + int_data_compressed = torch._cslt_compress(int_data) + return cls(int_data_compressed, scale, zero_point, layout_type) @register_layout_cls(TensorCoreTiledLayoutType) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index e2e6874b31..72c9a1b713 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -423,7 +423,8 @@ def apply_int8wo_quant(weight): def _apply_int8_dynamic_activation_int8_weight_quant(weight, layout_type : LayoutType): """ - Helper function for int8 dynamic activation w/wo 2:4 sparsity. + Helper function to specify layout_type for int8 dynamic activation int8 dynamic weight quantization. + Used to compose with semi-structured sparsity. """ in_features = weight.shape[1] # int8 dynamic quantization only has benefit when in_feature > 16 From ce567c3632064424011c9c2165ba1f723ab48766 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 25 Jul 2024 13:50:27 -0700 Subject: [PATCH 20/20] refactor quant api --- torchao/quantization/quant_api.py | 84 ++++++++++++++----------------- 1 file changed, 39 insertions(+), 45 deletions(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index b588620b3b..161a84c4e4 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -21,7 +21,7 @@ import torch.nn.functional as F from typing import Any, Callable, Union, Dict, Optional -from torchao.dtypes.utils import LayoutType +from torchao.dtypes import PlainLayoutType from torchao.utils import ( TORCH_VERSION_AFTER_2_4, unwrap_tensor_subclass, @@ -412,53 +412,48 @@ def apply_int8wo_quant(weight): return _get_linear_subclass_inserter(apply_int8wo_quant) -def _apply_int8_dynamic_activation_int8_weight_quant(weight : torch.Tensor, layout_type : LayoutType) -> torch.Tensor: - """ - Helper function to specify layout_type for int8 dynamic activation int8 dynamic weight quantization. - Used to compose with semi-structured sparsity. 
- """ - in_features = weight.shape[1] - # int8 dynamic quantization only has benefit when in_feature > 16 - if in_features <= 16: - return weight - # avoid circular dep - from torchao.dtypes import to_affine_quantized - # weight settings - mapping_type = MappingType.SYMMETRIC - def get_weight_block_size(x): - return (1, x.shape[1]) - target_dtype = torch.int8 - eps = torch.finfo(torch.float32).eps - zero_point_dtype = torch.int64 - - # input settings - def get_per_token_block_size(x): - block_size = list(x.shape) - for i in range(len(block_size)-1): - block_size[i] = 1 - return block_size - - input_mapping_type = MappingType.SYMMETRIC - input_target_dtype = torch.int8 - input_eps = 1e-5 - input_quant_min = -127 - input_quant_max = 127 - input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None) - - block_size = get_weight_block_size(weight) - weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, layout_type=layout_type) - weight = to_linear_act_quantized(weight, input_quant_func) - return weight - -def int8_dynamic_activation_int8_weight(): +def int8_dynamic_activation_int8_weight(layout_type=PlainLayoutType()): """ Applies int8 dynamic symmetric per-token activation and int8 per-channel weight quantization to linear layers """ - from torchao.dtypes import PlainLayoutType - _apply_int8_dynamic_activation_int8_weight_quant_layout = partial(_apply_int8_dynamic_activation_int8_weight_quant, layout_type=PlainLayoutType()) - return _get_linear_subclass_inserter(_apply_int8_dynamic_activation_int8_weight_quant_layout) + def apply_int8_dynamic_activation_int8_weight_quant(weight): + in_features = weight.shape[1] + # int8 dynamic quantization only has benefit when in_feature > 16 + if in_features <= 16: + return weight + + # avoid circular dep + from torchao.dtypes import to_affine_quantized + # weight settings + mapping_type = MappingType.SYMMETRIC + def get_weight_block_size(x): + return (1, x.shape[1]) + target_dtype = torch.int8 + eps = torch.finfo(torch.float32).eps + zero_point_dtype = torch.int64 + + # input settings + def get_per_token_block_size(x): + block_size = list(x.shape) + for i in range(len(block_size)-1): + block_size[i] = 1 + return block_size + + input_mapping_type = MappingType.SYMMETRIC + input_target_dtype = torch.int8 + input_eps = 1e-5 + input_quant_min = -127 + input_quant_max = 127 + input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None) + + block_size = get_weight_block_size(weight) + weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, layout_type=layout_type) + weight = to_linear_act_quantized(weight, input_quant_func) + return weight + + return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int8_weight_quant) def int8_dynamic_activation_int8_semi_sparse_weight(): @@ -467,5 +462,4 @@ def int8_dynamic_activation_int8_semi_sparse_weight(): quantization + 2:4 sparsity to linear layers. 
""" from torchao.dtypes import SemiSparseLayoutType - _apply_int8_dynamic_activation_int8_weight_quant_layout = partial(_apply_int8_dynamic_activation_int8_weight_quant, layout_type=SemiSparseLayoutType()) - return _get_linear_subclass_inserter(_apply_int8_dynamic_activation_int8_weight_quant_layout) + return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())