Commit cbcc080

do conversion in fp8/fpx related handlers
1 parent cef444a commit cbcc080

File tree

1 file changed, +30 -8 lines changed


torchao/quantization/quant_api.py

Lines changed: 30 additions & 8 deletions
@@ -1661,15 +1661,10 @@ def _float8_weight_only_transform(
         "applying int8 weight only quant requires module to have weight attribute"
         + " but {module} does not have one"
     )
-    # If model we're quantizing for inference was trained with torchao float8 training
-    # and checkpointed with the Float8Linears, we need to convert them back to
-    # regular nn.Linears so we can apply inference quantization techniques to them.
+
     if isinstance(module, Float8Linear):
-        with torch.device("meta"):
-            new_module = nn.Linear(module.in_features, module.out_features)
-        new_module.weight = module.weight
-        new_module.bias = module.bias
-        module = new_module
+        module = _unwrap_float8_linear(module)
+
     new_weight = _float8_weight_only_quant_tensor(module.weight, config)
 
     module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
@@ -1879,6 +1874,9 @@ def _float8_dynamic_activation_float8_weight_transform(
         "applying float8 dynamic activation quant requires module to have weight attribute"
         + f"but {module} does not have one"
     )
+    if isinstance(module, Float8Linear):
+        module = _unwrap_float8_linear(module)
+
     quantized_weight = _float8_dynamic_activation_float8_weight_quantize_tensor(
         module.weight, config
     )
@@ -1914,6 +1912,9 @@ def _float8_dynamic_activation_float8_semi_sparse_weight_transform(
 ):
     assert is_sm_at_least_90(), "Float8 quantization is only supported on CUDA>=9.0"
 
+    if isinstance(module, Float8Linear):
+        module = _unwrap_float8_linear(module)
+
     weight = module.weight
     weight_dtype = config.weight_dtype
     activation_dtype = config.activation_dtype
@@ -1978,6 +1979,9 @@ def _float8_static_activation_float8_weight_transform(
         "Float8 static activation quantization is only supported on CUDA 8.9 and above"
     )
 
+    if isinstance(module, Float8Linear):
+        module = _unwrap_float8_linear(module)
+
     scale = config.scale
     activation_dtype = config.activation_dtype
     weight_dtype = config.weight_dtype
@@ -2337,6 +2341,9 @@ def _fpx_weight_only_transform(
     if config.set_inductor_config:
         torchao.quantization.utils.recommended_inductor_config_setter()
 
+    if isinstance(module, Float8Linear):
+        module = _unwrap_float8_linear(module)
+
     from torchao.dtypes import to_affine_quantized_fpx
     from torchao.dtypes.floatx import FloatxTensorCoreLayout
 
@@ -2395,6 +2402,21 @@ def _module_fqn_to_config_handler(
     return module
 
 
+def _unwrap_float8_linear(module: Float8Linear) -> nn.Linear:
+    """
+    Unwrap a torchao Float8Linear by returning a nn.Linear with the same weights and bias.
+
+    Torchao inference quantization techniques are generally only applicable to nn.Linear
+    layers, so this helper is useful for unwrapping models trained with torchao float8 training,
+    which replaces nn.Linear layers with Float8Linear layers.
+    """
+    with torch.device("meta"):
+        new_module = nn.Linear(module.in_features, module.out_features)
+    new_module.weight = module.weight
+    new_module.bias = module.bias
+    return new_module
+
+
 torch.serialization.add_safe_globals(
     [
         _int8_asymm_per_token_quant,
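For context, a minimal usage sketch of the workflow these handlers now support (this example is not part of the commit and assumes torchao's public convert_to_float8_training, quantize_, and Float8WeightOnlyConfig APIs): a model whose nn.Linear layers were replaced by Float8Linear during float8 training can be passed straight to quantize_, since each fp8/fpx handler unwraps the Float8Linear before quantizing.

import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training      # swaps nn.Linear -> Float8Linear for training
from torchao.quantization import quantize_, Float8WeightOnlyConfig

# Toy model; float8 paths target recent CUDA hardware (SM 8.9+ / 9.0+ depending on the recipe).
model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 4096)).cuda()
convert_to_float8_training(model)   # model now contains Float8Linear modules
# ... training and checkpoint loading happen here ...

# Inference quantization: with this commit, _float8_weight_only_transform calls
# _unwrap_float8_linear on each Float8Linear before quantizing its weight.
quantize_(model, Float8WeightOnlyConfig())

Centralizing the unwrap in _unwrap_float8_linear gives every fp8/fpx handler the same behavior, rather than only the float8 weight-only path handling Float8Linear inputs.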

0 commit comments
