 except RuntimeError:
     pytest.skip("torchao.ops not available")

+from torchao.quantization import PerGroup, PerRow, PerTensor
+from torchao.quantization.quant_primitives import (
+    _choose_scale_float8,
+    _dequantize_affine_float8,
+    _quantize_affine_float8,
+)
 from torchao.quantization.utils import (
+    get_block_size,
     get_groupwise_affine_qparams,
     groupwise_affine_dequantize_tensor_from_qparams,
     groupwise_affine_quantize_tensor_from_qparams,
@@ -901,5 +908,91 @@ def _test_scaled_embedding_bag_cpu_helper(
     torch.testing.assert_close(refe_out, test_out, atol=1e-5, rtol=1e-5)


+@pytest.mark.skipif(
+    "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
+    reason="cpp kernels not built",
+)
+@pytest.mark.parametrize(
+    "multi_hot, batch_size, vector_size, index_type",
+    EMBEDINGBAG_TEST_PARAMS,
+    ids=str,
+)
+def test_scaled_embedding_bag_int8_cpu(multi_hot, batch_size, vector_size, index_type):
+    _test_scaled_embedding_bag_cpu_helper(
+        multi_hot, batch_size, vector_size, index_type, torch.int8
+    )
+
+
+@pytest.mark.skipif(
+    "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
+    reason="cpp kernels not built",
+)
+@pytest.mark.parametrize(
+    "multi_hot, batch_size, vector_size, index_type",
+    EMBEDINGBAG_TEST_PARAMS,
+    ids=str,
+)
+def test_scaled_embedding_bag_fp8_cpu(multi_hot, batch_size, vector_size, index_type):
+    _test_scaled_embedding_bag_cpu_helper(
+        multi_hot, batch_size, vector_size, index_type, torch.float8_e4m3fn
+    )
+
+
+@pytest.mark.skipif(
+    "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_prepack_cpu")
+    or "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
+    reason="cpp kernels not built",
+)
+@pytest.mark.skipif(
+    not torch_version_at_least("2.6.0"), reason="Test only enabled for 2.6+"
+)
+@pytest.mark.parametrize("shape", [(64, 64), (256, 256)])
+@pytest.mark.parametrize("bs", [1, 160])
+@pytest.mark.parametrize("out_dtype", [torch.float, torch.bfloat16, torch.half])
+@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("x_granularity", [PerTensor(), PerRow(), PerGroup(128)])
+@pytest.mark.parametrize("w_granularity", [PerTensor(), PerRow(), PerGroup(128)])
+def test_float8_linear_cpu(shape, bs, out_dtype, bias, x_granularity, w_granularity):
+    in_feature, out_feature = shape
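+    # Skip combinations this test does not exercise: a quantization group must
+    # be smaller than the K dim, and per-group activations need per-group weights.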
+    if isinstance(x_granularity, PerGroup):
+        if x_granularity.group_size >= in_feature:
+            return
+        if not isinstance(w_granularity, PerGroup):
+            return
+    if isinstance(w_granularity, PerGroup):
+        if w_granularity.group_size >= in_feature:
+            return
+    m = torch.nn.Linear(in_feature, out_feature, bias=bias).eval()
+    b = m.bias
+    x = torch.randn(bs, in_feature)
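+    # Quantize the activation to float8_e4m3fn at the requested granularity.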
+    x_block_size = get_block_size(x.shape, x_granularity)
+    x_scale = _choose_scale_float8(
+        x,
+        float8_dtype=torch.float8_e4m3fn,
+        block_size=x_block_size,
+    )
+    x_fp8 = _quantize_affine_float8(x, x_scale, torch.float8_e4m3fn)
+
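+    # Quantize the weight the same way, at its own granularity.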
+    w = m.weight.detach()
+    w_block_size = get_block_size(w.shape, w_granularity)
+    w_scale = _choose_scale_float8(
+        w,
+        float8_dtype=torch.float8_e4m3fn,
+        block_size=w_block_size,
+    )
+    w_fp8 = _quantize_affine_float8(w, w_scale, torch.float8_e4m3fn)
+
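+    # Reference result: dequantize both operands and run a plain linear.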
+    x_dq = _dequantize_affine_float8(x_fp8, x_scale)
+    w_dq = _dequantize_affine_float8(w_fp8, w_scale)
+    ref = torch.nn.functional.linear(x_dq, w_dq, b).to(out_dtype)
+
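+    # Pack the fp8 weight and its scales, then run the fused CPU kernel under test.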
+    packed_w, packed_scale = torch.ops.torchao.float8_linear_prepack_cpu(w_fp8, w_scale)
+    y = torch.ops.torchao.float8_linear_cpu(
+        x_fp8, x_scale, packed_w, packed_scale, b, out_dtype
+    )
+
+    torch.testing.assert_close(y, ref, atol=1e-2, rtol=1e-2)
+
+
 if __name__ == "__main__":
     pytest.main(sys.argv)