
Commit 6cf26bd

Fix Float8Tensor quantize op kernel preference dispatch
Summary:

Previously we didn't handle kernel_preference == "fbgemm" properly for the quantize op. This PR makes sure we dispatch to fbgemm kernels when kernel_preference is fbgemm.

This doesn't have much impact on BC: serialized checkpoints use AUTO, which is dispatched to the triton op for quantize. The only change is fixing the kernel choice for the fbgemm kernel preference, which is meant to be a developer-facing API (we expect most users to just use AUTO without worrying about the details).

Test Plan:
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_kernel_preference_numerical_equivalence

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2883, branch: jerryzh168/stack/59
1 parent 9056c46 commit 6cf26bd
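
For context, here is a minimal usage sketch of the behavior this commit targets. The Float8Tensor.from_hp and KernelPreference names come from the diffs below, but the import paths are assumptions and the snippet is illustrative only (it assumes a CUDA SM 9.0+ GPU with fbgemm_gpu_genai installed); it is not code from this PR.

import torch
from torchao.quantization.granularity import PerRow
from torchao.quantization.quantize_.workflows.float8.float8_tensor import (
    Float8Tensor,  # assumed import location
    KernelPreference,  # assumed import location
)

x = torch.randn(16, 64, device="cuda", dtype=torch.bfloat16)

# AUTO: per-row e4m3 quantization keeps taking the triton fast path (unchanged by this commit).
w_auto = Float8Tensor.from_hp(
    x, granularity=PerRow(), kernel_preference=KernelPreference.AUTO
)

# FBGEMM: with this fix, the quantize op itself now dispatches to
# torch.ops.fbgemm.quantize_fp8_per_row / quantize_fp8_per_tensor.
w_fbgemm = Float8Tensor.from_hp(
    x, granularity=PerRow(), kernel_preference=KernelPreference.FBGEMM
)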

5 files changed: +47 -14 lines changed


test/dtypes/test_affine_quantized_float.py

Lines changed: 1 addition & 1 deletion
@@ -789,7 +789,7 @@ def test_expected_kernels_on_gpu(self, granularity, float8_config_version):
         # three triton kernels for quantizing the activation:
         # kernel 1: x_max_tmp = max(x, ...)
         # kernel 2: x_max = max(x_max_tmp)
-        # kernel 3: x_float8 = to_float8(x, x_max)
+        # kernel 3: x_float8 = Float8Tensor.from_hp(x, x_max)
         FileCheck().check("def call(").check_count(".run(", 3, exactly=True).run(
             code[0]
         )

test/quantization/test_qat.py

Lines changed: 1 addition & 1 deletion
@@ -1859,7 +1859,7 @@ def test_float8_fake_quantize(self, granularity: Granularity):
         torch.manual_seed(self.SEED)
         x = torch.randn(32, 64)
         out = fake_quantizer(x)
-        out_expected = Float8Tensor.to_float8(x, dtype, granularity).dequantize()
+        out_expected = Float8Tensor.from_hp(x, dtype, granularity).dequantize()
         sqnr = compute_error(out, out_expected)
         self.assertGreater(sqnr, 16)

torchao/quantization/quant_api.py

Lines changed: 2 additions & 2 deletions
@@ -1546,7 +1546,7 @@ def _float8_weight_only_quant_tensor(weight, config):
     else:
         assert config.version == 2, f"Unexpected version: {config.version}"
         weight_dtype = config.weight_dtype
-        new_weight = Float8Tensor.to_float8(
+        new_weight = Float8Tensor.from_hp(
             weight, float8_dtype=weight_dtype, granularity=PerRow()
         )
         return new_weight
@@ -1744,7 +1744,7 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
         kernel_preference=kernel_preference,
     )
 
-    quantized_weight = Float8Tensor.to_float8(
+    quantized_weight = Float8Tensor.from_hp(
         weight,
         float8_dtype=weight_dtype,
         granularity=weight_granularity,
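
Both code paths above now build the quantized weight with Float8Tensor.from_hp. A rough end-to-end sketch of how this is reached from the public API (illustrative only; that the config forwards its kernel preference, AUTO by default, down to the quantize call is inferred from the kernel_preference argument in the hunk above):

import torch
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, quantize_
from torchao.quantization.granularity import PerRow

model = torch.nn.Sequential(torch.nn.Linear(64, 128)).cuda().to(torch.bfloat16)

# The config's kernel preference (AUTO by default) ends up in Float8Tensor.from_hp,
# which is where this commit fixes the fbgemm dispatch.
quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))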

torchao/quantization/quantize_/common/quantize_tensor_kwargs.py

Lines changed: 2 additions & 2 deletions
@@ -22,7 +22,7 @@ class QuantizeTensorKwargs(abc.ABC):
 
     class Float8Tensor(...)
         @classmethod
-        def to_float8(cls, tensor, quant_kwargs: QuantizeTensorKwargs)
+        def from_hp(cls, tensor, quant_kwargs: QuantizeTensorKwargs)
             ...
     """
 
@@ -43,7 +43,7 @@ def _choose_quant_func_and_quantize_tensor(
     )
 
     if isinstance(quant_kwargs, QuantizeTensorToFloat8Kwargs):
-        return Float8Tensor.to_float8(
+        return Float8Tensor.from_hp(
             tensor,
             quant_kwargs.float8_dtype,
             quant_kwargs.granularity,

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 41 additions & 8 deletions
@@ -22,7 +22,7 @@
     preprocess_data,
     preprocess_scale,
 )
-from torchao.quantization.granularity import PerRow
+from torchao.quantization.granularity import PerRow, PerTensor
 from torchao.quantization.observer import get_block_size
 from torchao.quantization.quant_primitives import (
     _choose_scale_float8,
@@ -163,7 +163,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
         return _dequantize_affine_float8(qdata, scale, output_dtype)
 
     @classmethod
-    def to_float8(
+    def from_hp(
         cls,
         hp_tensor: torch.Tensor,
         float8_dtype: torch.dtype = torch.float8_e4m3fn,
@@ -177,18 +177,29 @@ def to_float8(
         block_size = get_block_size(hp_tensor.shape, granularity)
         block_size = list(block_size)
 
+        kernel_choice = None
         # for per row quantization and kernel_preference default setting, we'll use triton kernel for best performance
         if (
             kernel_preference == KernelPreference.AUTO
             and _is_fbgemm_genai_gpu_available()
-            and (
-                tuple(block_size)
-                == (1,) * (hp_tensor.ndim - 1) + (hp_tensor.shape[-1],)
-            )
+            and is_sm_at_least_90()
+            and isinstance(granularity, PerRow)
+            and float8_dtype == torch.float8_e4m3fn
+            and hp_value_lb is None
         ):
-            assert float8_dtype == torch.float8_e4m3fn, (
-                f"Only torch.float8_e4m3fn is supported, got: {float8_dtype}"
+            # optimized path for auto and per row quantization
+            kernel_choice = "triton"
+        elif kernel_preference == KernelPreference.FBGEMM and hp_value_lb is None:
+            assert _is_fbgemm_genai_gpu_available() and is_sm_at_least_90(), (
+                "Specified fbgemm but fbgemm_gpu_genai is not installed or hardware is not >= SM 9.0 (> H100)"
             )
+            kernel_choice = "fbgemm"
+        else:
+            # fallback path for everything else will be torch
+            kernel_choice = "torch"
+
+        if kernel_choice == "triton":
+            assert hp_value_lb is None, f"{hp_value_lb=} is not supported"
             if hp_value_ub is not None:
                 maybe_hp_value_ub_tensor = torch.tensor(
                     hp_value_ub, dtype=torch.float, device=hp_tensor.device
@@ -202,7 +213,29 @@ def to_float8(
             for i in range(hp_tensor.ndim):
                 scale_shape.append(hp_tensor.shape[i] // block_size[i])
             scale = scale.reshape(*scale_shape)
+        elif kernel_choice == "fbgemm":
+            assert hp_value_lb is None, f"{hp_value_lb=} is not supported"
+            if hp_value_ub is not None:
+                maybe_hp_value_ub_tensor = torch.tensor(
+                    hp_value_ub, dtype=torch.float, device=hp_tensor.device
+                )
+            else:
+                maybe_hp_value_ub_tensor = None
+            # not used
+            num_tokens = torch.empty([hp_tensor.size(0)], device=hp_tensor.device)
+            if isinstance(granularity, PerRow):
+                data, scale = torch.ops.fbgemm.quantize_fp8_per_row(
+                    hp_tensor, num_tokens, scale_ub=maybe_hp_value_ub_tensor
+                )
+            else:
+                assert isinstance(granularity, PerTensor), (
+                    f"Expected per tensor, got {granularity}"
+                )
+                data, scale = torch.ops.fbgemm.quantize_fp8_per_tensor(
+                    hp_tensor, num_tokens, scale_ub=maybe_hp_value_ub_tensor
+                )
         else:
+            assert kernel_choice == "torch", f"Expected torch, got {kernel_choice}"
             scale = _choose_scale_float8(
                 hp_tensor,
                 float8_dtype=float8_dtype,
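
The dispatch above now has three explicit branches (triton, fbgemm, torch), and the test in the Test Plan checks that they agree numerically. A hedged sketch of that kind of check follows; it is not the actual test, and it assumes the KernelPreference enum also exposes a TORCH member, the import paths shown, and a supported GPU/fbgemm setup.

import torch
from torchao.quantization.granularity import PerRow
from torchao.quantization.quantize_.workflows.float8.float8_tensor import (
    Float8Tensor,  # assumed import location
    KernelPreference,  # assumed import location
)

x = torch.randn(32, 64, device="cuda", dtype=torch.bfloat16)

# Each preference should pick a different quantize kernel but produce
# (near-)identical dequantized values.
dequants = [
    Float8Tensor.from_hp(x, granularity=PerRow(), kernel_preference=pref).dequantize()
    for pref in (KernelPreference.AUTO, KernelPreference.FBGEMM, KernelPreference.TORCH)
]
for d in dequants[1:]:
    assert torch.allclose(dequants[0], d, atol=1e-2, rtol=1e-2)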

0 commit comments
