Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions test/dtypes/test_affine_quantized.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,16 +123,24 @@ def test_weights_only(self, apply_quant):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.parametrize("apply_quant", get_quantization_functions(False, False))
def test_to_device(self, apply_quant):
def _apply(module, config_or_subclass_inserter):
if isinstance(config_or_subclass_inserter, AOBaseConfig):
quantize_(module, config_or_subclass_inserter)
else:
# TODO(#1690): delete this once config migration is done
module = config_or_subclass_inserter(module)
return module

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(linear)
ql = _apply(linear, apply_quant)
ql.to("cuda")

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(linear)
ql = _apply(linear, apply_quant)
ql.to(device="cuda")

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(linear)
ql = _apply(linear, apply_quant)
ql.cuda()

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
Expand Down
41 changes: 36 additions & 5 deletions test/quantization/test_quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
Quantizer,
TwoStepQuantizer,
_replace_with_custom_fn_if_matches_filter,
float8_dynamic_activation_float8_weight,
float8_static_activation_float8_weight,
float8_weight_only,
int4_weight_only,
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_weight,
Expand All @@ -46,6 +49,7 @@
TORCH_VERSION_AT_LEAST_2_4,
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_6,
is_sm_at_least_89,
unwrap_tensor_subclass,
)

Expand Down Expand Up @@ -784,28 +788,55 @@ def test_int4wo_cpu(self, dtype, x_dim):
assert "_weight_int4pack_mm_for_cpu" in code[0]
assert "aten.mm.default" not in code[0]

# TODO(#1690): move to new config names
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_int4_weight_only_numerics(self):
@common_utils.parametrize(
"config",
[
int4_weight_only(),
float8_weight_only(),
float8_dynamic_activation_float8_weight(),
float8_static_activation_float8_weight(scale=torch.tensor([1.0])),
],
)
def test_workflow_e2e_numerics(self, config):
"""
Simple test of e2e int4_weight_only workflow, comparing numerics
to a bfloat16 baseline.
"""
if (
isinstance(
config,
(
float8_dynamic_activation_float8_weight,
float8_static_activation_float8_weight,
),
)
and not is_sm_at_least_89()
):
return unittest.skip("requires CUDA capability 8.9 or greater")

# scale has to be moved to cuda here because the parametrization init
# code happens before gating for cuda availability
if isinstance(config, float8_static_activation_float8_weight):
config.scale = config.scale.to("cuda")

# set up inputs
x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
# TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469
# is that expected?
m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
m_int4_wo = copy.deepcopy(m_ref)
m_q = copy.deepcopy(m_ref)

# quantize
quantize_(m_int4_wo, int4_weight_only())
quantize_(m_q, config)

with torch.no_grad():
y_ref = m_ref(x)
y_int4_wo = m_int4_wo(x)
y_q = m_q(x)

sqnr = compute_error(y_ref, y_int4_wo)
sqnr = compute_error(y_ref, y_q)
assert sqnr >= 20, f"SQNR {sqnr} is too low"


Expand Down
6 changes: 6 additions & 0 deletions torchao/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@
AffineQuantizedObserverBase,
)
from .quant_api import (
Float8DynamicActivationFloat8WeightConfig,
Float8StaticActivationFloat8WeightConfig,
Float8WeightOnlyConfig,
Int4WeightOnlyConfig,
float8_dynamic_activation_float8_weight,
float8_static_activation_float8_weight,
Expand Down Expand Up @@ -121,6 +124,9 @@
"gemlite_uintx_weight_only",
"swap_conv2d_1x1_to_linear",
"Int4WeightOnlyConfig",
"Float8WeightOnlyConfig",
"Float8DynamicActivationFloat8WeightConfig",
"Float8StaticActivationFloat8WeightConfig",
# smooth quant - subject to change
"get_scale",
"SmoothFakeDynQuantMixin",
Expand Down
Loading
Loading