@@ -1661,15 +1661,10 @@ def _float8_weight_only_transform(
         "applying int8 weight only quant requires module to have weight attribute"
         + " but {module} does not have one"
     )
-    # If model we're quantizing for inference was trained with torchao float8 training
-    # and checkpointed with the Float8Linears, we need to convert them back to
-    # regular nn.Linears so we can apply inference quantization techniques to them.
+
     if isinstance(module, Float8Linear):
-        with torch.device("meta"):
-            new_module = nn.Linear(module.in_features, module.out_features)
-            new_module.weight = module.weight
-            new_module.bias = module.bias
-            module = new_module
+        module = _unwrap_float8_linear(module)
+
     new_weight = _float8_weight_only_quant_tensor(module.weight, config)

     module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
@@ -1879,6 +1874,9 @@ def _float8_dynamic_activation_float8_weight_transform(
         "applying float8 dynamic activation quant requires module to have weight attribute"
         + f"but {module} does not have one"
     )
+    if isinstance(module, Float8Linear):
+        module = _unwrap_float8_linear(module)
+
     quantized_weight = _float8_dynamic_activation_float8_weight_quantize_tensor(
         module.weight, config
     )
@@ -1914,6 +1912,9 @@ def _float8_dynamic_activation_float8_semi_sparse_weight_transform(
 ):
     assert is_sm_at_least_90(), "Float8 quantization is only supported on CUDA>=9.0"

+    if isinstance(module, Float8Linear):
+        module = _unwrap_float8_linear(module)
+
     weight = module.weight
     weight_dtype = config.weight_dtype
     activation_dtype = config.activation_dtype
@@ -1978,6 +1979,9 @@ def _float8_static_activation_float8_weight_transform(
         "Float8 static activation quantization is only supported on CUDA 8.9 and above"
     )

+    if isinstance(module, Float8Linear):
+        module = _unwrap_float8_linear(module)
+
     scale = config.scale
     activation_dtype = config.activation_dtype
     weight_dtype = config.weight_dtype
@@ -2337,6 +2341,9 @@ def _fpx_weight_only_transform(
     if config.set_inductor_config:
         torchao.quantization.utils.recommended_inductor_config_setter()

+    if isinstance(module, Float8Linear):
+        module = _unwrap_float8_linear(module)
+
     from torchao.dtypes import to_affine_quantized_fpx
     from torchao.dtypes.floatx import FloatxTensorCoreLayout
@@ -2395,6 +2402,21 @@ def _module_fqn_to_config_handler(
     return module


+def _unwrap_float8_linear(module: Float8Linear) -> nn.Linear:
+    """
+    Unwrap a torchao Float8Linear by returning a nn.Linear with the same weights and bias.
+
+    Torchao inference quantization techniques are generally only applicable to nn.Linear
+    layers, so this helper is useful for unwrapping models trained with torchao float8 training,
+    which replaces nn.Linear layers with Float8Linear layers.
+    """
+    with torch.device("meta"):
+        new_module = nn.Linear(module.in_features, module.out_features)
+    new_module.weight = module.weight
+    new_module.bias = module.bias
+    return new_module
+
+
 torch.serialization.add_safe_globals(
     [
         _int8_asymm_per_token_quant,
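
For context, below is a minimal standalone sketch of the pattern the new _unwrap_float8_linear helper uses, runnable without torchao installed. TrainedLinear and unwrap_to_linear are hypothetical stand-ins, not torchao APIs; the point is that constructing the replacement nn.Linear under torch.device("meta") avoids allocating fresh weight storage, since its parameters are immediately re-pointed at the trained module's tensors.

import torch
import torch.nn as nn


class TrainedLinear(nn.Linear):
    """Hypothetical stand-in for a training-time subclass such as Float8Linear."""


def unwrap_to_linear(module: nn.Linear) -> nn.Linear:
    # Build the plain nn.Linear on the meta device so no real weight/bias
    # storage is allocated; both parameters are re-pointed right below.
    with torch.device("meta"):
        new_module = nn.Linear(module.in_features, module.out_features)
    # Share (not copy) the already-trained parameters.
    new_module.weight = module.weight
    new_module.bias = module.bias
    return new_module


trained = TrainedLinear(16, 32)
plain = unwrap_to_linear(trained)
assert plain.weight is trained.weight  # parameters are shared, not copied
assert type(plain) is nn.Linear        # plain module, ready for inference quantization transforms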