diff --git a/README.md b/README.md
index 0da273f91c..e3cdc60aba 100644
--- a/README.md
+++ b/README.md
@@ -29,16 +29,16 @@ For inference, we have the option of
 ```python
 from torchao.quantization.quant_api import (
     quantize_,
-    int8_dynamic_activation_int8_weight,
-    int4_weight_only,
-    int8_weight_only
+    Int8DynamicActivationInt8WeightConfig,
+    Int4WeightOnlyConfig,
+    Int8WeightOnlyConfig
 )
-quantize_(m, int4_weight_only())
+quantize_(m, Int4WeightOnlyConfig())
 ```

-For gpt-fast `int4_weight_only()` is the best option at bs=1 as it **2x the tok/s and reduces the VRAM requirements by about 65%** over a torch.compiled baseline.
+For gpt-fast `Int4WeightOnlyConfig()` is the best option at bs=1 as it **2x the tok/s and reduces the VRAM requirements by about 65%** over a torch.compiled baseline.

-If you don't have enough VRAM to quantize your entire model on GPU and you find CPU quantization to be too slow then you can use the device argument like so `quantize_(model, int8_weight_only(), device="cuda")` which will send and quantize each layer individually to your GPU.
+If you don't have enough VRAM to quantize your entire model on GPU and you find CPU quantization to be too slow then you can use the device argument like so `quantize_(model, Int8WeightOnlyConfig(), device="cuda")` which will send and quantize each layer individually to your GPU.

 If you see slowdowns with any of these techniques or you're unsure which option to use, consider using [autoquant](./torchao/quantization/README.md#autoquantization) which will automatically profile layers and pick the best way to quantize each layer.
@@ -63,12 +63,12 @@ Post-training quantization can result in a fast and compact model, but may also
 ```python
 from torchao.quantization import (
     quantize_,
-    int8_dynamic_activation_int4_weight,
+    Int8DynamicActivationInt4WeightConfig,
 )
 from torchao.quantization.qat import (
     FakeQuantizeConfig,
-    from_intx_quantization_aware_training,
-    intx_quantization_aware_training,
+    FromIntXQuantizationAwareTrainingConfig,
+    IntXQuantizationAwareTrainingConfig,
 )

 # Insert fake quantization
@@ -76,14 +76,14 @@ activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=Fal
 weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
 quantize_(
     my_model,
-    intx_quantization_aware_training(activation_config, weight_config),
+    IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
 )

 # Run training... (not shown)

 # Convert fake quantization to actual quantized operations
-quantize_(my_model, from_intx_quantization_aware_training())
-quantize_(my_model, int8_dynamic_activation_int4_weight(group_size=32))
+quantize_(my_model, FromIntXQuantizationAwareTrainingConfig())
+quantize_(my_model, Int8DynamicActivationInt4WeightConfig(group_size=32))
 ```

 ### Float8
@@ -139,7 +139,7 @@ The best example we have combining the composability of lower bit dtype with com

 We've added support for authoring and releasing [custom ops](./torchao/csrc/) that do not graph break with `torch.compile()` so if you love writing kernels but hate packaging them so they work all operating systems and cuda versions, we'd love to accept contributions for your custom ops. We have a few examples you can follow

-1. [fp6](torchao/dtypes/floatx) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fpx_weight_only(3, 2))`
+1. [fp6](torchao/dtypes/floatx) for 2x faster inference over fp16 with an easy to use API `quantize_(model, FPXWeightOnlyConfig(3, 2))`
 2. [2:4 Sparse Marlin GEMM](https://github.com/pytorch/ao/pull/733) 2x speedups for FP16xINT4 kernels even at batch sizes up to 256
 3. [int4 tinygemm unpacker](https://github.com/pytorch/ao/pull/415) which makes it easier to switch quantized backends for inference

diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md
index ace4d8c14c..655a942718 100644
--- a/torchao/quantization/README.md
+++ b/torchao/quantization/README.md
@@ -82,7 +82,7 @@ model(input)

 When used as in the example above, when the `autoquant` api is called alongside torch.compile, autoquant sets up the model so that when its run on the next input, the autoquantization and torch.compile processes leave you with a heavily optimized model.

-When `model(input)` is called, (under the hood) the tool does a preliminary run with the input where each linear layer keeps track of the different shapes and types of activations that it sees. Once the preliminary run is complete, the next step is to check each linear layer and benchmark the tracked shapes for different types of quantization techniques in order to pick the fastest one, attempting to take into account fusions where possible. Finally once the best class is found for each layer, the next step is to apply the necessary quantization technique to each layer, before finally allowing the normal `torch.compile` process to occur on the now quantized model. By default the api only uses int8 techniques, i.e. it chooses between no quantization, int8 dynamic quantization and int8 weight only quantization for each layer, though there is also an option add int4 quantization which can be used for maximum performance or to avoid perf regressions from `int4_weight_only()` since for certain (compute bound) regimes, int4 weight only quantization can be very slow.
+When `model(input)` is called, (under the hood) the tool does a preliminary run with the input where each linear layer keeps track of the different shapes and types of activations that it sees. Once the preliminary run is complete, the next step is to check each linear layer and benchmark the tracked shapes for different types of quantization techniques in order to pick the fastest one, attempting to take into account fusions where possible. Finally once the best class is found for each layer, the next step is to apply the necessary quantization technique to each layer, before finally allowing the normal `torch.compile` process to occur on the now quantized model. By default the api only uses int8 techniques, i.e. it chooses between no quantization, int8 dynamic quantization and int8 weight only quantization for each layer, though there is also an option add int4 quantization which can be used for maximum performance or to avoid perf regressions from `Int4WeightOnlyConfig()` since for certain (compute bound) regimes, int4 weight only quantization can be very slow.

 Sometimes it is desirable to reuse a quantization plan that `autoquant` came up with. `torchao.quantization.AUTOQUANT_CACHE` is a dictionary holding autoquant's benchmark results. We can save it and restore it later, which will cause `autoquant` to choose the same quantization methods.

@@ -109,13 +109,13 @@ be applied individually. While there are a large variety of quantization apis, t

 ```python
 # for torch 2.4+
-from torchao.quantization import quantize_, int4_weight_only
+from torchao.quantization import quantize_, Int4WeightOnlyConfig
 group_size = 32

 # you can enable [hqq](https://github.com/mobiusml/hqq/tree/master) quantization which is expected to improves accuracy through
-# use_hqq flag for `int4_weight_only` quantization
+# use_hqq flag for `Int4WeightOnlyConfig` quantization
 use_hqq = False
-quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq))
+quantize_(model, Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq))

 # for torch 2.2.2 and 2.3
 from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
@@ -128,8 +128,8 @@ Note: The quantization error incurred by applying int4 quantization to your mode

 ```python
 # for torch 2.4+
-from torchao.quantization import quantize_, int8_weight_only
-quantize_(model, int8_weight_only())
+from torchao.quantization import quantize_, Int8WeightOnlyConfig
+quantize_(model, Int8WeightOnlyConfig())

 # for torch 2.2.2 and 2.3
 from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
@@ -140,8 +140,8 @@ change_linear_weights_to_int8_woqtensors(model)

 ```python
 # for torch 2.4+
-from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
-quantize_(model, int8_dynamic_activation_int8_weight())
+from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig
+quantize_(model, Int8DynamicActivationInt8WeightConfig())

 # for torch 2.2.2 and 2.3
 from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
@@ -152,8 +152,8 @@ change_linear_weights_to_int8_dqtensors(model)

 ```python
 # for torch 2.5+
-from torchao.quantization import quantize_, float8_weight_only
-quantize_(model, float8_weight_only())
+from torchao.quantization import quantize_, Float8WeightOnlyConfig
+quantize_(model, Float8WeightOnlyConfig())
 ```

 Supports all dtypes for original weight and activation. This API is only tested on H100. Hardware with CUDA compute capability 8.9 or greater is required.
@@ -162,8 +162,8 @@ Supports all dtypes for original weight and activation. This API is only tested

 ```python
 # for torch 2.4+
-from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, PerTensor
-quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerTensor()))
+from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig, PerTensor
+quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()))
 ```

 Supports all dtypes for original weight and activation. This API is only tested on H100. Hardware with CUDA compute capability 8.9 or greater is required.
@@ -172,8 +172,8 @@ Supports all dtypes for original weight and activation. This API is only tested

 ```python
 # for torch 2.5+
-from torchao.quantization import quantize_, PerRow, float8_dynamic_activation_float8_weight
-quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow()))
+from torchao.quantization import quantize_, PerRow, Float8DynamicActivationFloat8WeightConfig
+quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
 ```

 Per-row scaling is only supported for bfloat16 weight and activation. This API is only tested on H100. Hardware with CUDA compute capability 8.9 or greater is required.
@@ -182,14 +182,14 @@ Per-row scaling is only supported for bfloat16 weight and activation. This API i

 ```python
 # for torch 2.4+
-from torchao.quantization import quantize_, fpx_weight_only
-quantize_(model, fpx_weight_only(3, 2))
+from torchao.quantization import quantize_, FPXWeightOnlyConfig
+quantize_(model, FPXWeightOnlyConfig(3, 2))
 ```

 You can find more information [here](../dtypes/floatx/README.md). It should be noted where most other TorchAO apis and benchmarks have focused on applying techniques on top of a bf16 model, performance, fp6 works primarily with the fp16 dtype.

 ## Affine Quantization Details
-Affine quantization refers to the type of quantization that maps from high precision floating point numbers to quantized numbers (low precision integer or floating point dtypes) with an affine transformation, i.e.: `quantized_val = high_preicsion_float_val / scale + zero_point` where `scale` and `zero_point` are quantization parameters for some granularity and based on some data (also some dtypes may not require a `zero_point`). Each of the techniques in the above section qualify as Affine Quantization.
+Affine quantization refers to the type of quantization that maps from high precision floating point numbers to quantized numbers (low precision integer or floating point dtypes) with an affine transformation, i.e.: `quantized_val = high_precision_float_val / scale + zero_point` where `scale` and `zero_point` are quantization parameters for some granularity and based on some data (also some dtypes may not require a `zero_point`). Each of the techniques in the above section qualify as Affine Quantization.

 ### Quantization Primitives
 We used to have different quantize and dequantize operators for quantization with different granularities. But in the end these can all be expressed with a `block_size` argument with different settings, so we unified existing quant primitives to `choose_qparams_affine`, `quantize_affine` and `dequantize_affine` that can represent symmetric/asymmetric per tensor/channel/token/channel_group quantization, this can be used to implement the unified quantized tensor subclass.
@@ -200,7 +200,7 @@ Note: these primitive ops supports two "types" of quantization, distinguished by
 We also have a unified quantized tensor subclass that implements how to get a quantized tensor from floating point tensor and what does it mean to call linear ops on an instance of the tensor, e.g. `F.linear` and `aten.addmm`, with this we could dispatch to different operators (e.g. `int4mm` op) based on device (cpu, cuda) and quantization settings (`int4`, `int8`) and also packing formats (e.g. format optimized for cpu int4 mm kernel)

 #### Layouts
-We extended the `layout` concept to represent different packing formats for a tensor. `AffineQuantizedTensor` supports `plain` and `tensor_core_tiled` layout. `plain` layout is used for `int8_weight_only` and `int8_dynamic_activation_int8_weight` and also as a default layout. `tensor_core_tiled` layout is used for `int4_weight_only` quantization and is packing the weights in a format that is compatible with tinygemm [int4mm](https://github.com/pytorch/pytorch/blob/39357ba06f48cda7d293a4995aa5eba2a46598b5/aten/src/ATen/native/native_functions.yaml#L4138) kernels.
+We extended the `layout` concept to represent different packing formats for a tensor. `AffineQuantizedTensor` supports `plain` and `tensor_core_tiled` layout. `plain` layout is used for workflows backing `Int8WeightOnlyConfig` and `Int8DynamicActivationInt8WeightConfig` and also as a default layout. `tensor_core_tiled` layout is used for workflows backing `Int4WeightOnlyConfig` quantization and is packing the weights in a format that is compatible with tinygemm [int4mm](https://github.com/pytorch/pytorch/blob/39357ba06f48cda7d293a4995aa5eba2a46598b5/aten/src/ATen/native/native_functions.yaml#L4138) kernels.

 ### Zero Point Domains
 ```ZeroPointDomain``` is used to control the data types of zero points. ```ZeroPointDomain.None``` means zero_point is None, ```ZeroPointDomain.FLOAT``` means zero_point is in the floating point domain and ```ZeroPointDomain.INT``` means integer domain. For detailed implementation of different zero point data types, refer to [the reference implementation](../../test/quantization/test_quant_primitives.py).
@@ -223,7 +223,7 @@ from torchao.dtypes import to_affine_quantized_intx
 import copy
 from torchao.quantization.quant_api import (
     quantize_,
-    int4_weight_only,
+    Int4WeightOnlyConfig,
 )

 class ToyLinearModel(torch.nn.Module):
@@ -249,9 +249,9 @@ m_bf16 = torch.compile(m_bf16, mode='max-autotune')
 # apply int4 weight only quant (compatible with tinygemm int4 weight only quant mm kernel in torchao)
 group_size = 32
 # only works for torch 2.4+
-quantize_(m, int4_weight_only(group_size=group_size))
+quantize_(m, Int4WeightOnlyConfig(group_size=group_size))
 ## If different zero_point_domain needed
-# quantize_(m, int4_weight_only(group_size=group_size), zero_point_domain=ZeroPointDomain.FLOAT)
+# quantize_(m, Int4WeightOnlyConfig(group_size=group_size, zero_point_domain=ZeroPointDomain.FLOAT))

 # temporary workaround for tensor subclass + torch.compile
 # NOTE: this is only need for torch version < 2.5+
@@ -360,7 +360,7 @@ We're trying to develop kernels for low bit quantization for intx quantization f
 | | uintx-4-64-hqq | 8.124 | 47.85 | 213.24 | 11.85 | 4.46 |
 | | uintx-2-8-hqq | 39.605 | 34.83 | 261.42 | 14.99 | 7.51 |

-You try can out these apis with the `quantize_` api as above alongside the constructor `uintx_weight_only` an example can be found in in `torchao/_models/llama/generate.py`.
+You try can out these apis with the `quantize_` api as above alongside the config `UIntXWeightOnlyConfig`. An example can be found in in `torchao/_models/llama/generate.py`.

 ### int8_dynamic_activation_intx_weight Quantization
 We have kernels that do 8-bit dynamic quantization of activations and uintx groupwise quantization of weights. These kernels are experimental and can only be run on a device with an ARM CPU (e.g., a Mac computers with Apple silicon). The benchmarks below were run on an M1 Mac Pro, with 8 perf cores, and 2 efficiency cores, and 32GB of RAM. In all cases, torch.compile was used.
diff --git a/torchao/quantization/qat/README.md b/torchao/quantization/qat/README.md
index 813b628af7..0f024dbf61 100644
--- a/torchao/quantization/qat/README.md
+++ b/torchao/quantization/qat/README.md
@@ -71,9 +71,9 @@ def train_loop(m: torch.nn.Module):

 The recommended way to run QAT in torchao is through the `quantize_` API:
 1. **Prepare:** specify how weights and/or activations are to be quantized through
-[`FakeQuantizeConfig`](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/api.py#L29) and passing these to [`intx_quantization_aware_training`](https://github.com/pytorch/ao/blob/cedadc741954f47a9e9efac2aa584701f125bc73/torchao/quantization/qat/api.py#L242)
+[`FakeQuantizeConfig`](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/api.py#L29) and passing these to [`IntXQuantizationAwareTrainingConfig`](https://github.com/pytorch/ao/blob/cedadc741954f47a9e9efac2aa584701f125bc73/torchao/quantization/qat/api.py#L242)
 2. **Convert:** quantize the model using the standard post-training quantization (PTQ)
-functions such as [`int8_dynamic_activation_int4_weight`](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/quant_api.py#L606)
+functions such as [`Int8DynamicActivationInt4WeightConfig`](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/quant_api.py#L606)

 For example:

@@ -81,12 +81,12 @@ For example:
 ```python
 from torchao.quantization import (
     quantize_,
-    int8_dynamic_activation_int4_weight,
+    Int8DynamicActivationInt4WeightConfig,
 )
 from torchao.quantization.qat import (
     FakeQuantizeConfig,
-    from_intx_quantization_aware_training,
-    intx_quantization_aware_training,
+    FromIntXQuantizationAwareTrainingConfig,
+    IntXQuantizationAwareTrainingConfig,
 )

 model = get_model()
@@ -96,7 +96,7 @@ activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=Fal
 weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
 quantize_(
     model,
-    intx_quantization_aware_training(activation_config, weight_config),
+    IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
 )

 # train
@@ -105,8 +105,8 @@ train_loop(model)
 # convert: transform fake quantization ops into actual quantized ops
 # swap `FakeQuantizedLinear` back to `torch.nn.Linear` and inserts
 # quantized activation and weight tensor subclasses
-quantize_(model, from_intx_quantization_aware_training())
-quantize_(model, int8_dynamic_activation_int4_weight(group_size=32))
+quantize_(model, FromIntXQuantizationAwareTrainingConfig())
+quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))

 # inference or generate
 ```
@@ -117,7 +117,7 @@ the following with a filter function during the prepare step:
 ```
 quantize_(
     m,
-    intx_quantization_aware_training(weight_config=weight_config),
+    IntXQuantizationAwareTrainingConfig(weight_config=weight_config),
     filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
 )
 ```
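
Review note on the affine quantization formula whose spelling this diff corrects in `torchao/quantization/README.md` (`quantized_val = high_precision_float_val / scale + zero_point`): the mapping can be sketched in plain Python for the asymmetric, per-tensor int8 case. This is a minimal illustration and does not call torchao. The helper names `choose_qparams`, `quantize` and `dequantize` are hypothetical and only loosely mirror the `choose_qparams_affine`, `quantize_affine` and `dequantize_affine` primitives the README mentions. Rounding and clamping to the int8 range are assumptions added here so the result lands on valid integers.

```python
def choose_qparams(values, quant_min=-128, quant_max=127):
    """Pick a scale and integer zero point that cover the range of `values`."""
    lo = min(min(values), 0.0)  # include 0.0 so that zero maps to an exact integer
    hi = max(max(values), 0.0)
    scale = (hi - lo) / (quant_max - quant_min)
    if scale == 0.0:  # all-zero input: any positive scale works
        scale = 1.0
    zero_point = round(quant_min - lo / scale)
    zero_point = max(quant_min, min(quant_max, zero_point))
    return scale, zero_point


def quantize(values, scale, zero_point, quant_min=-128, quant_max=127):
    """quantized_val = round(high_precision_val / scale + zero_point), clamped."""
    return [
        max(quant_min, min(quant_max, round(v / scale + zero_point)))
        for v in values
    ]


def dequantize(qvalues, scale, zero_point):
    """Inverse of the affine map: high_precision_val ~= (q - zero_point) * scale."""
    return [(q - zero_point) * scale for q in qvalues]


weights = [-1.0, -0.25, 0.0, 0.5, 2.0]
scale, zero_point = choose_qparams(weights)
q = quantize(weights, scale, zero_point)
restored = dequantize(q, scale, zero_point)
```

With these assumptions the smallest and largest inputs map to the ends of the int8 range, and each restored value differs from its input by at most one quantization step (`scale`).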