
Commit 479111c

[Reland][CPU] Add ops for float8 linear
1 parent 5cbbd73

File tree

4 files changed: +888 −1 lines changed
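
The new tests below exercise two new CPU ops, torch.ops.torchao.float8_linear_prepack_cpu and torch.ops.torchao.float8_linear_cpu, plus a float8 variant of the existing _scaled_embedding_bag test. As a minimal sketch of the intended call sequence (assuming only the op signatures and helper names visible in the diff; the per-row granularity and the shapes are arbitrary choices for illustration):

import torch
from torchao.quantization import PerRow
from torchao.quantization.quant_primitives import (
    _choose_scale_float8,
    _quantize_affine_float8,
)
from torchao.quantization.utils import get_block_size

x = torch.randn(4, 64)   # activation
w = torch.randn(32, 64)  # weight, (out_features, in_features)

# Quantize activation and weight to float8_e4m3fn with per-row scales.
x_scale = _choose_scale_float8(
    x, float8_dtype=torch.float8_e4m3fn, block_size=get_block_size(x.shape, PerRow())
)
x_fp8 = _quantize_affine_float8(x, x_scale, torch.float8_e4m3fn)
w_scale = _choose_scale_float8(
    w, float8_dtype=torch.float8_e4m3fn, block_size=get_block_size(w.shape, PerRow())
)
w_fp8 = _quantize_affine_float8(w, w_scale, torch.float8_e4m3fn)

# The two ops added by this commit: pack the weight once, then run the
# fused float8 linear (bias=None here, bfloat16 output).
packed_w, packed_scale = torch.ops.torchao.float8_linear_prepack_cpu(w_fp8, w_scale)
y = torch.ops.torchao.float8_linear_cpu(
    x_fp8, x_scale, packed_w, packed_scale, None, torch.bfloat16
)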

test/test_ops.py

Lines changed: 93 additions & 0 deletions
@@ -40,7 +40,14 @@
 except RuntimeError:
     pytest.skip("torchao.ops not available")
 
+from torchao.quantization import PerGroup, PerRow, PerTensor
+from torchao.quantization.quant_primitives import (
+    _choose_scale_float8,
+    _dequantize_affine_float8,
+    _quantize_affine_float8,
+)
 from torchao.quantization.utils import (
+    get_block_size,
     get_groupwise_affine_qparams,
     groupwise_affine_dequantize_tensor_from_qparams,
     groupwise_affine_quantize_tensor_from_qparams,
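
The helpers imported above implement the float8 affine round trip used by the reference path of the new test. Roughly, and stated as an assumption for orientation rather than a description of the torchao internals, _choose_scale_float8 picks one scale per block from the block's absolute maximum, and quantize/dequantize are a divide and a multiply by that scale:

import torch

# Illustrative per-tensor version of the assumed semantics; the real helpers
# also handle per-row and per-group block shapes.
def choose_scale_fp8(x: torch.Tensor) -> torch.Tensor:
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
    return x.abs().amax() / fp8_max

x = torch.randn(8, 64)
scale = choose_scale_fp8(x)
x_fp8 = (x / scale).to(torch.float8_e4m3fn)  # quantize
x_dq = x_fp8.to(torch.float32) * scale       # dequantize
print((x - x_dq).abs().max())                # small round-trip error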
@@ -901,5 +908,91 @@ def _test_scaled_embedding_bag_cpu_helper(
     torch.testing.assert_close(refe_out, test_out, atol=1e-5, rtol=1e-5)
 
 
+@pytest.mark.skipif(
+    "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
+    reason="cpp kernels not built",
+)
+@pytest.mark.parametrize(
+    "multi_hot, batch_size, vector_size, index_type",
+    EMBEDINGBAG_TEST_PARAMS,
+    ids=str,
+)
+def test_scaled_embedding_bag_int8_cpu(multi_hot, batch_size, vector_size, index_type):
+    _test_scaled_embedding_bag_cpu_helper(
+        multi_hot, batch_size, vector_size, index_type, torch.int8
+    )
+
+
+@pytest.mark.skipif(
+    "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
+    reason="cpp kernels not built",
+)
+@pytest.mark.parametrize(
+    "multi_hot, batch_size, vector_size, index_type",
+    EMBEDINGBAG_TEST_PARAMS,
+    ids=str,
+)
+def test_scaled_embedding_bag_fp8_cpu(multi_hot, batch_size, vector_size, index_type):
+    _test_scaled_embedding_bag_cpu_helper(
+        multi_hot, batch_size, vector_size, index_type, torch.float8_e4m3fn
+    )
+
+
+@pytest.mark.skipif(
+    "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_prepack_cpu")
+    or "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
+    reason="cpp kernels not built",
+)
+@pytest.mark.skipif(
+    not torch_version_at_least("2.6.0"), reason="Test only enabled for 2.6+"
+)
+@pytest.mark.parametrize("shape", [(64, 64), (256, 256)])
+@pytest.mark.parametrize("bs", [1, 160])
+@pytest.mark.parametrize("out_dtype", [torch.float, torch.bfloat16, torch.half])
+@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("x_granularity", [PerTensor(), PerRow(), PerGroup(128)])
+@pytest.mark.parametrize("w_granularity", [PerTensor(), PerRow(), PerGroup(128)])
+def test_float8_linear_cpu(shape, bs, out_dtype, bias, x_granularity, w_granularity):
+    in_feature, out_feature = shape
+    if isinstance(x_granularity, PerGroup):
+        if x_granularity.group_size >= in_feature:
+            return
+        if not isinstance(w_granularity, PerGroup):
+            return
+    if isinstance(w_granularity, PerGroup):
+        if w_granularity.group_size >= in_feature:
+            return
+    m = torch.nn.Linear(in_feature, out_feature, bias=bias).eval()
+    b = m.bias
+    x = torch.randn(bs, in_feature)
+    x_block_size = get_block_size(x.shape, x_granularity)
+    x_scale = _choose_scale_float8(
+        x,
+        float8_dtype=torch.float8_e4m3fn,
+        block_size=x_block_size,
+    )
+    x_fp8 = _quantize_affine_float8(x, x_scale, torch.float8_e4m3fn)
+
+    w = m.weight.detach()
+    w_block_size = get_block_size(w.shape, w_granularity)
+    w_scale = _choose_scale_float8(
+        w,
+        float8_dtype=torch.float8_e4m3fn,
+        block_size=w_block_size,
+    )
+    w_fp8 = _quantize_affine_float8(w, w_scale, torch.float8_e4m3fn)
+
+    x_dq = _dequantize_affine_float8(x_fp8, x_scale)
+    w_dq = _dequantize_affine_float8(w_fp8, w_scale)
+    ref = torch.nn.functional.linear(x_dq, w_dq, b).to(out_dtype)
+
+    packed_w, packed_scale = torch.ops.torchao.float8_linear_prepack_cpu(w_fp8, w_scale)
+    y = torch.ops.torchao.float8_linear_cpu(
+        x_fp8, x_scale, packed_w, packed_scale, b, out_dtype
+    )
+
+    torch.testing.assert_close(y, ref, atol=1e-2, rtol=1e-2)
+
+
 if __name__ == "__main__":
     pytest.main(sys.argv)
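
A note on the granularity parametrization above: get_block_size turns each granularity into a quantization block shape, which is why the test skips PerGroup cases whose group_size is not smaller than in_feature, and why a PerGroup activation requires a PerGroup weight. The assumed mapping (illustrative, following the usual torchao convention for a 2-D tensor):

from torchao.quantization import PerGroup, PerRow, PerTensor
from torchao.quantization.utils import get_block_size

shape = (160, 256)  # e.g. (batch, in_feature)
for g in (PerTensor(), PerRow(), PerGroup(128)):
    # Expected: PerTensor -> (160, 256), PerRow -> (1, 256), PerGroup(128) -> (1, 128)
    print(type(g).__name__, get_block_size(shape, g))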
