1 file changed: +4 -1 lines changed
Original file line number | Diff line number | Diff line change
@@ -334,6 +334,7 @@ def _quantize_affine(
334334 zero_point ,
335335 quant_min ,
336336 quant_max ,
337+ output_dtype ,
337338 zero_point_domain ,
338339 ).to (output_dtype )
339340
@@ -345,6 +346,7 @@ def _quantize_affine_no_dtype_cast(
345346 zero_point : Optional [torch .Tensor ],
346347 quant_min : Union [int , float ],
347348 quant_max : Union [int , float ],
349+ quant_dtype : Optional [torch .dtype ],
348350 zero_point_domain : Optional [str ] = ZeroPointDomain .INT .name ,
349351) -> torch .Tensor :
350352 """
@@ -389,7 +391,7 @@ def _quantize_affine_no_dtype_cast(
389391 assert (
390392 zero_point is None
391393 ), "zero_point should be None when zero_point_domain is NONE"
392- if _is_float8_type (input . dtype ):
394+ if _is_float8_type (quant_dtype ):
393395 quant = torch .clamp (input * scale .reciprocal (), quant_min , quant_max )
394396 else :
395397 quant = torch .clamp (torch .round (input * (1.0 / scale )), quant_min , quant_max )
@@ -661,6 +663,7 @@ def _do_fake_quantize_affine(
661663 zero_point ,
662664 quant_min ,
663665 quant_max ,
666+ quant_dtype ,
664667 zero_point_domain .name ,
665668 )
666669 dq = _dequantize_affine_no_dtype_check (
You can’t perform that action at this time.
0 commit comments