
Commit 1001602

Update
[ghstack-poisoned]
2 parents: 28f32b9 + 7d6bb6a

File tree: 3 files changed (+104, -17 lines)

docs/source/tutorials_source/pt2e_quantizer.rst

Lines changed: 14 additions & 14 deletions
@@ -32,16 +32,16 @@ Introduction
 Please see `here <https://pytorch.org/tutorials/prototype/pt2e_quant_ptq.html#motivation-of-pytorch-2-export-quantization>`__ For motivations for the new API and ``Quantizer``.

 An existing quantizer object defined for ``XNNPACK`` is in
-`QNNPackQuantizer <https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/pt2e/quantizer/xnnpack_quantizer.py>`__
+`XNNPackQuantizer <https://github.com/pytorch/executorch/blob/752f6a729d3a2090b43ace6915086d8b4e03644f/backends/xnnpack/quantizer/xnnpack_quantizer.py>`__

 Annotation API
 ^^^^^^^^^^^^^^^^^^^

 ``Quantizer`` uses annotation API to convey quantization intent for different operators/patterns.
 Annotation API mainly consists of
-`QuantizationSpec <https://github.com/pytorch/pytorch/blob/1ca2e993af6fa6934fca35da6970308ce227ddc7/torch/ao/quantization/_pt2e/quantizer/quantizer.py#L38>`__
+`QuantizationSpec <https://github.com/pytorch/ao/blob/b96354087db6d0480ebbc10d5a63a9ca49c19dfa/torchao/quantization/pt2e/quantizer/quantizer.py#L40>`__
 and
-`QuantizationAnnotation <https://github.com/pytorch/pytorch/blob/07104ca99c9d297975270fb58fda786e60b49b38/torch/ao/quantization/_pt2e/quantizer/quantizer.py#L144>`__.
+`QuantizationAnnotation <https://github.com/pytorch/ao/blob/b96354087db6d0480ebbc10d5a63a9ca49c19dfa/torchao/quantization/pt2e/quantizer/quantizer.py#L121>`__.

 ``QuantizationSpec`` is used to convey intent of how a tensor will be quantized,
 e.g. dtype, bitwidth, min, max values, symmetric vs. asymmetric etc.
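
For reference, a ``QuantizationSpec`` carrying this intent can be constructed roughly as below. This is a minimal sketch: the ``torchao.quantization.pt2e`` import paths are assumed from the updated URLs above, and ``HistogramObserver`` is just one possible observer choice:

    import torch
    from torchao.quantization.pt2e.observer import HistogramObserver
    from torchao.quantization.pt2e.quantizer import QuantizationSpec

    # int8 asymmetric per-tensor spec for activations, calibrated with a histogram observer
    act_qspec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-128,
        quant_max=127,
        qscheme=torch.per_tensor_affine,
        is_dynamic=False,
        observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
    )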
@@ -133,7 +133,7 @@ parameters can be shared among some tensors explicitly. Two typical use cases ar

 - Example 1: One example is for ``add`` where having both inputs sharing quantization
   parameters makes operator implementation much easier. Without using of
-  `SharedQuantizationSpec <https://github.com/pytorch/pytorch/blob/1ca2e993af6fa6934fca35da6970308ce227ddc7/torch/ao/quantization/_pt2e/quantizer/quantizer.py#L90>`__,
+  `SharedQuantizationSpec <https://github.com/pytorch/ao/blob/b96354087db6d0480ebbc10d5a63a9ca49c19dfa/torchao/quantization/pt2e/quantizer/quantizer.py#L97>`__,
   we must annotate ``add`` as example in above section 1, in which two inputs of ``add``
   has different quantization parameters.
 - Example 2: Another example is that of sharing quantization parameters between inputs and output.
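
For context on Example 1, a hedged sketch of how ``SharedQuantizationSpec`` ties the two inputs and the output of ``add`` together; ``add_node`` is a hypothetical FX node found during annotation, and ``act_qspec`` is a spec defined as in the sketch above:

    from torchao.quantization.pt2e.quantizer import (
        QuantizationAnnotation,
        SharedQuantizationSpec,
    )

    # add_node: an FX node for torch.ops.aten.add.Tensor located while pattern matching
    input_act0, input_act1 = add_node.args[0], add_node.args[1]

    # the second input and the output reuse the observer of the (input_act0 -> add) edge,
    # so all three tensors end up with identical quantization parameters
    share_qspec = SharedQuantizationSpec((input_act0, add_node))
    add_node.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map={input_act0: act_qspec, input_act1: share_qspec},
        output_qspec=share_qspec,
        _annotated=True,
    )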
@@ -211,7 +211,7 @@ as this:
 Another typical use case to annotate a quantized model is for tensors whose
 quantization parameters are known beforehand. For example, operator like ``sigmoid``, which has
 predefined and fixed scale/zero_point at input and output tensors.
-`FixedQParamsQuantizationSpec <https://github.com/pytorch/pytorch/blob/1ca2e993af6fa6934fca35da6970308ce227ddc7/torch/ao/quantization/_pt2e/quantizer/quantizer.py#L90>`__
+`FixedQParamsQuantizationSpec <https://github.com/pytorch/ao/blob/b96354087db6d0480ebbc10d5a63a9ca49c19dfa/torchao/quantization/pt2e/quantizer/quantizer.py#L76>`__
 is designed for this use case. To use ``FixedQParamsQuantizationSpec``, users need to pass in parameters
 of ``scale`` and ``zero_point`` explicitly.

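For reference, a minimal sketch of a fixed-qparams spec for ``sigmoid``; the scale/zero_point follow the common uint8 convention for a [0, 1] output range, and the import path is assumed from the URL above:

    import torch
    from torchao.quantization.pt2e.quantizer import FixedQParamsQuantizationSpec

    # sigmoid output always lies in [0, 1], so qparams can be fixed ahead of time
    fixed_qspec = FixedQParamsQuantizationSpec(
        dtype=torch.uint8,
        quant_min=0,
        quant_max=255,
        qscheme=torch.per_tensor_affine,
        scale=1.0 / 256.0,
        zero_point=0,
    )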
@@ -243,14 +243,14 @@ of ``scale`` and ``zero_point`` explicitly.
 Another use case is to define the constraint for tensors whose quantization parameters are derived from other tensors.
 For example, if we want to annotate a convolution node, and define the ``scale`` of its bias input tensor
 as product of the activation tensor's ``scale`` and weight tensor's ``scale``. We can use
-`DerivedQuantizationSpec <https://github.com/pytorch/pytorch/blob/1ca2e993af6fa6934fca35da6970308ce227ddc7/torch/ao/quantization/_pt2e/quantizer/quantizer.py#L102>`__
+`DerivedQuantizationSpec <https://github.com/pytorch/ao/blob/b96354087db6d0480ebbc10d5a63a9ca49c19dfa/torchao/quantization/pt2e/quantizer/quantizer.py#L107>`__
 to annotate this conv node.

 - Step 1: Identify the original floating point pattern in the FX graph. We can use the same
   methods introduced in ``QuantizationSpec`` example to identify the ``convolution`` pattern.
 - Step 2: Define ``derive_qparams_fn`` function, it accepts list of ``ObserverOrFakeQuantize`` (
-  `ObserverBase <https://github.com/pytorch/pytorch/blob/07104ca99c9d297975270fb58fda786e60b49b38/torch/ao/quantization/observer.py#L124>`__
-  or `FakeQuantizeBase <https://github.com/pytorch/pytorch/blob/07104ca99c9d297975270fb58fda786e60b49b38/torch/ao/quantization/fake_quantize.py#L60>`__)
+  `ObserverBase <https://github.com/pytorch/ao/blob/b96354087db6d0480ebbc10d5a63a9ca49c19dfa/torchao/quantization/pt2e/observer.py#L157>`__
+  or `FakeQuantizeBase <https://github.com/pytorch/ao/blob/b96354087db6d0480ebbc10d5a63a9ca49c19dfa/torchao/quantization/pt2e/fake_quantize.py#L78>`__)
   as input. From each ``ObserverOrFakeQuantize`` object, user can get the ``scale``, ``zero point`` value.
   User can define its heuristic about how to derive new ``scale``, ``zero point`` value based on the
   quantization parameters calculated from the observer or fake quant instances.
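
Putting the two steps above together, a hedged sketch of a derived bias spec; ``conv_node``, ``input_act_node``, and ``weight_node`` are hypothetical FX nodes identified in Step 1, and import paths are assumed from the URLs above:

    from typing import List, Tuple

    import torch
    from torchao.quantization.pt2e.quantizer import DerivedQuantizationSpec

    def derive_qparams_fn(obs_or_fqs: List) -> Tuple[torch.Tensor, torch.Tensor]:
        # obs_or_fqs holds the ObserverOrFakeQuantize objects for
        # (input activation, weight); bias scale = act_scale * weight_scale
        act_scale, _ = obs_or_fqs[0].calculate_qparams()
        weight_scale, _ = obs_or_fqs[1].calculate_qparams()
        bias_scale = act_scale * weight_scale
        bias_zero_point = torch.zeros_like(bias_scale, dtype=torch.int32)
        return bias_scale, bias_zero_point

    bias_qspec = DerivedQuantizationSpec(
        derived_from=[(input_act_node, conv_node), (weight_node, conv_node)],
        derive_qparams_fn=derive_qparams_fn,
        dtype=torch.int32,
        quant_min=-(2**31),
        quant_max=2**31 - 1,
        qscheme=torch.per_tensor_symmetric,
    )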
@@ -293,13 +293,13 @@ and run a `toy example <https://gist.github.com/leslie-fang-intel/b78ed682aa9b54
 with ``Torchvision Resnet18``. To better understand the final example, here are the classes and utility
 functions that are used in the example:

-- `QuantizationConfig <https://github.com/pytorch/pytorch/blob/73fd7235ad25ff061c087fa4bafc6e8df4d9c299/torch/ao/quantization/_pt2e/quantizer/quantizer.py#L103-L109>`__
+- `QuantizationConfig <https://github.com/pytorch/ao/blob/b96354087db6d0480ebbc10d5a63a9ca49c19dfa/torchao/quantization/pt2e/quantizer/utils.py#L21>`__
   consists of ``QuantizationSpec`` for activation, weight, and bias separately.
 - When annotating the model,
-  `get_input_act_qspec <https://github.com/pytorch/pytorch/blob/47cfcf566ab76573452787335f10c9ca185752dc/torch/ao/quantization/_pt2e/quantizer/utils.py#L10>`__,
-  `get_output_act_qspec <https://github.com/pytorch/pytorch/blob/47cfcf566ab76573452787335f10c9ca185752dc/torch/ao/quantization/_pt2e/quantizer/utils.py#L23>`__,
-  `get_weight_qspec <https://github.com/pytorch/pytorch/blob/47cfcf566ab76573452787335f10c9ca185752dc/torch/ao/quantization/_pt2e/quantizer/utils.py#L36>`__, and
-  `get_bias_qspec <https://github.com/pytorch/pytorch/blob/47cfcf566ab76573452787335f10c9ca185752dc/torch/ao/quantization/_pt2e/quantizer/utils.py#L53>`__
+  `get_input_act_qspec <https://github.com/pytorch/ao/blob/b96354087db6d0480ebbc10d5a63a9ca49c19dfa/torchao/quantization/pt2e/quantizer/utils.py#L48>`__,
+  `get_output_act_qspec <https://github.com/pytorch/ao/blob/b96354087db6d0480ebbc10d5a63a9ca49c19dfa/torchao/quantization/pt2e/quantizer/utils.py#L61>`__,
+  `get_weight_qspec <https://github.com/pytorch/ao/blob/b96354087db6d0480ebbc10d5a63a9ca49c19dfa/torchao/quantization/pt2e/quantizer/utils.py#L74>`__, and
+  `get_bias_qspec <https://github.com/pytorch/ao/blob/b96354087db6d0480ebbc10d5a63a9ca49c19dfa/torchao/quantization/pt2e/quantizer/utils.py#L92>`__
   can be used to get the ``QuantizationSpec`` from ``QuantizationConfig`` for a specific pattern.

 A Note on IR for PT2E Quantization Flow
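
For reference, a sketch of how these utilities fit together; the ``QuantizationConfig`` field names are assumptions based on the utils.py link above, and ``act_qspec``/``weight_qspec``/``bias_qspec`` stand for specs defined as in the earlier sketches:

    from torchao.quantization.pt2e.quantizer.utils import (
        QuantizationConfig,
        get_bias_qspec,
        get_input_act_qspec,
        get_output_act_qspec,
        get_weight_qspec,
    )

    # bundle the per-tensor specs once...
    quantization_config = QuantizationConfig(
        input_activation=act_qspec,
        output_activation=act_qspec,
        weight=weight_qspec,
        bias=bias_qspec,
    )

    # ...then pull the right spec back out while annotating each pattern
    input_qspec = get_input_act_qspec(quantization_config)
    output_qspec = get_output_act_qspec(quantization_config)
    w_qspec = get_weight_qspec(quantization_config)
    b_qspec = get_bias_qspec(quantization_config)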
@@ -378,4 +378,4 @@ Conclusion
 With this tutorial, we introduce the new quantization path in PyTorch 2. Users can learn about
 how to define a ``BackendQuantizer`` with the ``QuantizationAnnotation API`` and integrate it into the PyTorch 2 Export Quantization flow.
 Examples of ``QuantizationSpec``, ``SharedQuantizationSpec``, ``FixedQParamsQuantizationSpec``, and ``DerivedQuantizationSpec``
-are given for specific annotation use case. You can use `XNNPACKQuantizer <https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/quantizer/xnnpack_quantizer.py>`_ as an example to start implementing your own ``Quantizer``. After that please follow `this tutorial <https://pytorch.org/tutorials/prototype/pt2e_quant_ptq.html>`_ to actually quantize your model.
+are given for specific annotation use case. You can use `XNNPACKQuantizer <https://github.com/pytorch/executorch/blob/752f6a729d3a2090b43ace6915086d8b4e03644f/backends/xnnpack/quantizer/xnnpack_quantizer.py>`_ as an example to start implementing your own ``Quantizer``. After that please follow `this tutorial <https://pytorch.org/tutorials/prototype/pt2e_quant_ptq.html>`_ to actually quantize your model.

torchao/float8/__init__.py

Lines changed: 5 additions & 1 deletion
@@ -6,7 +6,10 @@
     ScalingGranularity,
     ScalingType,
 )
-from torchao.float8.float8_linear_utils import convert_to_float8_training
+from torchao.float8.float8_linear_utils import (
+    _auto_filter_for_recipe,
+    convert_to_float8_training,
+)
 from torchao.float8.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
@@ -44,6 +47,7 @@
     # top level UX
     "convert_to_float8_training",
     "precompute_float8_dynamic_scale_for_fsdp",
+    "_auto_filter_for_recipe",
     # types
     "FP8Granularity",
     # note: Float8Tensor and Float8Linear are not public APIs
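
The net effect of this change: the helper becomes importable from the package root alongside the existing entry point, e.g.:

    from torchao.float8 import _auto_filter_for_recipe, convert_to_float8_training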

torchao/float8/float8_linear_utils.py

Lines changed: 85 additions & 2 deletions
@@ -4,11 +4,12 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 import logging
-from typing import Callable, Optional
+from functools import partial
+from typing import Callable, List, Optional, Union

 import torch.nn as nn

-from torchao.float8.config import Float8LinearConfig
+from torchao.float8.config import Float8LinearConfig, Float8LinearRecipeName
 from torchao.float8.float8_linear import Float8Linear

 log = logging.getLogger(__name__)
@@ -113,3 +114,85 @@ def convert_to_float8_training(
         from_float,
         module_filter_fn=module_filter_fn,
     )
+
+
+def _auto_filter_for_recipe(
+    recipe: Union[str, Float8LinearRecipeName], filter_fqns: List[str]
+) -> Callable[[nn.Module, str], bool]:
+    """Returns a function which automatically filters out nn.Linear modules that meet at least one of the following criteria:
+
+    1. Dims not divisible by 16 (hardware requirement for float8).
+    2. Dim sizes below certain thresholds, which may result in worse performance.
+
+    NOTE: the thresholds are simple heuristics based on performance testing, and may not be optimal
+    for your model. For the best performance, we recommend defining your own module_filter_fn customized for
+    your module, using the performance tables for the given float8 recipe here:
+    https://github.com/pytorch/ao/tree/main/torchao/float8#performance. The benchmarks referenced for
+    auto-filtering layers were run on H100 GPUs, and may not be representative of other hardware.
+
+    This is an experimental API, the design may change in the future.
+    """
+    if isinstance(recipe, str):
+        recipe = Float8LinearRecipeName(recipe)
+    if recipe == Float8LinearRecipeName.TENSORWISE:
+        return partial(_auto_filter_for_tensorwise, filter_fqns=filter_fqns)
+    elif recipe == Float8LinearRecipeName.ROWWISE:
+        return partial(_auto_filter_for_rowwise, filter_fqns=filter_fqns)
+    elif recipe == Float8LinearRecipeName.ROWWISE_WITH_GW_HP:
+        raise NotImplementedError(f"Unsupported recipe: {recipe}")
+    else:
+        raise ValueError(f"Invalid recipe: {recipe}")
+
+
+def _auto_filter_for_rowwise(mod: nn.Module, fqn: str, filter_fqns: List[str]) -> bool:
+    if not isinstance(mod, nn.Linear):
+        return False
+
+    # If the fqn matches any filtered fqn, then we should not convert this module.
+    is_filtered_fqn = any(filter_fqn in fqn for filter_fqn in filter_fqns)
+    if is_filtered_fqn:
+        return False
+
+    # All dims must be divisible by 16 due to float8 hardware requirements.
+    N, K = mod.weight.shape
+    dims_multiples_of_16 = K % 16 == 0 and N % 16 == 0
+    if not dims_multiples_of_16:
+        return False
+
+    # Dims below these thresholds may result in worse performance
+    # (see https://github.com/pytorch/ao/tree/main/torchao/float8#rowwise-scaling).
+    # Note that these benchmarks referenced for auto filtering layers were run on
+    # H100 GPUs, and may not be representative of other hardware.
+    if N <= 2048:
+        return False
+    elif K <= 1024:
+        return False
+    elif N <= 4096 and K <= 2048:
+        return False
+    return True
+
+
+def _auto_filter_for_tensorwise(
+    mod: nn.Module, fqn: str, filter_fqns: List[str]
+) -> bool:
+    if not isinstance(mod, nn.Linear):
+        return False
+
+    # If the fqn matches any filtered fqn, then we should not convert this module.
+    is_filtered_fqn = any(filter_fqn in fqn for filter_fqn in filter_fqns)
+    if is_filtered_fqn:
+        return False
+
+    # All dims must be divisible by 16 due to float8 hardware requirements.
+    N, K = mod.weight.shape
+    dims_multiples_of_16 = K % 16 == 0 and N % 16 == 0
+    if not dims_multiples_of_16:
+        return False
+
+    # Dims below these thresholds may result in worse performance
+    # (see https://github.com/pytorch/ao/tree/main/torchao/float8#tensorwise-scaling).
+    # Note that these benchmarks referenced for auto filtering layers were run on
+    # H100 GPUs, and may not be representative of other hardware.
+    if K <= 4096 and N <= 1024:
+        return False
+    return True
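
Taken together, a minimal usage sketch of the new helper; the module shapes are illustrative, and ``Float8LinearConfig.from_recipe_name`` is assumed from ``torchao.float8.config``:

    import torch.nn as nn

    from torchao.float8 import _auto_filter_for_recipe, convert_to_float8_training
    from torchao.float8.config import Float8LinearConfig, Float8LinearRecipeName

    model = nn.Sequential(
        nn.Linear(4096, 8192),  # weight (N=8192, K=4096): kept by the rowwise filter
        nn.Linear(8192, 128),   # weight (N=128, K=8192): N <= 2048, filtered out
    )

    # build a filter for the rowwise recipe; also skip any fqn containing "output"
    module_filter_fn = _auto_filter_for_recipe("rowwise", filter_fqns=["output"])

    config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE)
    convert_to_float8_training(model, config=config, module_filter_fn=module_filter_fn)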
