Skip to content

Commit 837ee22

Browse files
committed
Fix _quantize_affine_no_dtype_cast for FP8 types
1 parent c53a9d5 commit 837ee22

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

torchao/quantization/quant_primitives.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -334,6 +334,7 @@ def _quantize_affine(
334334
zero_point,
335335
quant_min,
336336
quant_max,
337+
output_dtype,
337338
zero_point_domain,
338339
).to(output_dtype)
339340

@@ -345,6 +346,7 @@ def _quantize_affine_no_dtype_cast(
345346
zero_point: Optional[torch.Tensor],
346347
quant_min: Union[int, float],
347348
quant_max: Union[int, float],
349+
quant_dtype: Optional[torch.dtype],
348350
zero_point_domain: Optional[str] = ZeroPointDomain.INT.name,
349351
) -> torch.Tensor:
350352
"""
@@ -389,7 +391,7 @@ def _quantize_affine_no_dtype_cast(
389391
assert (
390392
zero_point is None
391393
), "zero_point should be None when zero_point_domain is NONE"
392-
if _is_float8_type(input.dtype):
394+
if _is_float8_type(quant_dtype):
393395
quant = torch.clamp(input * scale.reciprocal(), quant_min, quant_max)
394396
else:
395397
quant = torch.clamp(torch.round(input * (1.0 / scale)), quant_min, quant_max)
@@ -661,6 +663,7 @@ def _do_fake_quantize_affine(
661663
zero_point,
662664
quant_min,
663665
quant_max,
666+
quant_dtype,
664667
zero_point_domain.name,
665668
)
666669
dq = _dequantize_affine_no_dtype_check(

0 commit comments

Comments (0)