diff --git a/python/tvm/relax/backend/contrib/cublas.py b/python/tvm/relax/backend/contrib/cublas.py
index f66001d0e883..b8a0bad0ca08 100644
--- a/python/tvm/relax/backend/contrib/cublas.py
+++ b/python/tvm/relax/backend/contrib/cublas.py
@@ -20,6 +20,7 @@ from functools import reduce
 
 import tvm
+from tvm import DataType
 from tvm.relax import transform
 from tvm.relax.transform import PatternCheckContext
 
@@ -68,11 +69,30 @@ def _check_matmul(context: PatternCheckContext) -> bool:
             # Rows number must be multiples of 4 for IGEMM
             return False
     elif lhs_dtype == "e4m3_float8" and rhs_dtype == "e4m3_float8":
-        # Matrix dimensions must be multiples of 16. This requirement is missing from the cuBLAS
-        # docs, but it was observed during testing.
-        if not isinstance(rhs_shape[-1], (tvm.tir.expr.IntImm, int)) or rhs_shape[-1] % 16 != 0:
+        matmul_rhs_var = matmul_call.args[1]
+        rhs_transposed = False
+        if matmul_rhs_var in context.matched_bindings:
+            matmul_rhs_call = context.matched_bindings[matmul_rhs_var]
+            assert (
+                isinstance(matmul_rhs_call, tvm.relax.Call)
+                and matmul_rhs_call.op.name == "relax.permute_dims"
+            )
+            rhs_transposed = True
+
+        if not rhs_transposed:
+            # cuBLAS FP8 operations require rhs being transposed
             return False
-        if not isinstance(rhs_shape[-2], (tvm.tir.expr.IntImm, int)) or rhs_shape[-2] % 16 != 0:
+
+        # cuBLAS FP8 operations require all tensors being aligned to 16 bytes.
+        if (
+            not isinstance(rhs_shape[-1], (tvm.tir.expr.IntImm, int))
+            or rhs_shape[-1] % (16 // DataType(lhs_dtype).itemsize()) != 0
+        ):
+            return False
+        if (
+            not isinstance(rhs_shape[-2], (tvm.tir.expr.IntImm, int))
+            or rhs_shape[-2] % (16 // DataType(out_dtype).itemsize()) != 0
+        ):
             return False
 
     lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1)
diff --git a/tests/python/relax/test_codegen_cublas.py b/tests/python/relax/test_codegen_cublas.py
index 11247b380123..4f357626b804 100644
--- a/tests/python/relax/test_codegen_cublas.py
+++ b/tests/python/relax/test_codegen_cublas.py
@@ -269,17 +269,21 @@ def test_matmul_fp8_offload(
 
 
 @pytest.mark.parametrize(
-    "M, N, K, out_dtype, partition_done",
+    "M, N, K, out_dtype, transposed_y, partition_done",
     [
-        (15, 64, 32, "float32", True),
-        (15, 64, 32, "e4m3_float8", True),
-        (15, 64, 32, "e5m2_float8", False),
-        (16, 32, 60, "float32", False),
-        (16, 30, 64, "float32", False),
+        (15, 64, 32, "float32", True, True),
+        (15, 64, 32, "e4m3_float8", True, True),
+        (15, 64, 32, "e5m2_float8", True, False),
+        (16, 32, 60, "float32", True, False),
+        (16, 30, 64, "float32", True, False),
+        (16, 8, 16, "float16", True, True),
+        (16, 16, 16, "float16", False, False),
     ],
 )
-def test_cublas_partition_fp8_matmul(M, N, K, out_dtype, partition_done):
-    mod = get_relax_matmul_module((M, K), (N, K), "e4m3_float8", out_dtype, transposed_y=True)
+def test_cublas_partition_fp8_matmul(M, N, K, out_dtype, transposed_y, partition_done):
+    mod = get_relax_matmul_module(
+        (M, K), (N, K), "e4m3_float8", out_dtype, transposed_y=transposed_y
+    )
     mod = partition_for_cublas(mod)
     func_name = "relax_matmul_cublas" if partition_done else "R.matmul"
    assert func_name in mod["main"].script()
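
Note (not part of the patch): the new check encodes two cuBLAS FP8 requirements. The rhs of the matmul must arrive transposed (through relax.permute_dims), and the rhs dimensions must keep every tensor 16-byte aligned, i.e. be multiples of 16 bytes divided by the element size of the relevant dtype. Below is a minimal sketch of that arithmetic; it assumes a TVM build whose DataType exposes the itemsize() method used in the patch, and fp8_matmul_dims_ok is a hypothetical helper name, not an API introduced here.

from tvm import DataType

def fp8_matmul_dims_ok(rhs_shape, lhs_dtype="e4m3_float8", out_dtype="float16"):
    """Illustrative mirror of the alignment check added to _check_matmul."""
    # Translate 16-byte alignment into element counts per dimension.
    k_align = 16 // DataType(lhs_dtype).itemsize()  # e4m3_float8 (1 byte) -> 16 elements
    n_align = 16 // DataType(out_dtype).itemsize()  # float16 (2 bytes) -> 8 elements
    return rhs_shape[-1] % k_align == 0 and rhs_shape[-2] % n_align == 0

# Mirrors the new test rows: (N=8, K=16) with float16 output is accepted,
# while (N=30, K=64) with float32 output is rejected (30 is not a multiple of 4).
assert fp8_matmul_dims_ok((8, 16), out_dtype="float16")
assert not fp8_matmul_dims_ok((30, 64), out_dtype="float32")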