From a68fb7e4c200fa041900d04408f095eb7267a220 Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Fri, 1 Mar 2024 13:43:23 -0500
Subject: [PATCH] [Dlight] Skip GeMV when normalization fails

Prior to this PR, GeMV does not skip the cases of normalization failure,
which leads to an error. This PR fixes this issue.

A unit test is added accordingly.
---
 python/tvm/dlight/gpu/gemv.py        |  2 ++
 tests/python/dlight/test_gpu_gemv.py | 33 ++++++++++++++++++++++++++++
 2 files changed, 35 insertions(+)

diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py
index d453b84bc055..d1a195fbad6f 100644
--- a/python/tvm/dlight/gpu/gemv.py
+++ b/python/tvm/dlight/gpu/gemv.py
@@ -180,6 +180,8 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return-
         sch = tir.Schedule(func)
         block_infos = normalize_prim_func(sch)
         block_infos = try_inline_contiguous_spatial(sch, block_infos)
+        if block_infos is None:
+            return None
         if len(block_infos) == 1:
             epilogue = None
         elif len(block_infos) == 2:
diff --git a/tests/python/dlight/test_gpu_gemv.py b/tests/python/dlight/test_gpu_gemv.py
index b5e8b82ab7e3..8903babbc0b4 100644
--- a/tests/python/dlight/test_gpu_gemv.py
+++ b/tests/python/dlight/test_gpu_gemv.py
@@ -996,5 +996,38 @@ def expected(x: T.Buffer((1, 4096), "float16"), w: T.Buffer((8, 16384, 4096), "f
     tvm.ir.assert_structural_equal(mod["main"], expected)
 
 
+def test_func_to_skip():
+    @T.prim_func
+    def before(var_A: T.handle, var_exclusive_scan_thrust: T.handle, seq_len: T.int64):
+        data_buf = T.match_buffer(var_A, (seq_len * T.int64(8),), "int32", align=8)
+        output_buf = T.match_buffer(
+            var_exclusive_scan_thrust, (seq_len * T.int64(8),), "int32", align=8
+        )
+        with T.block("exclusive_scan_thrust"):
+            T.reads()
+            T.writes()
+            T.call_packed(
+                "tvm.contrib.thrust.sum_scan",
+                T.tvm_stack_make_array(
+                    data_buf.data, T.tvm_stack_make_shape(seq_len * T.int64(8)), 0, 1, 0, T.int64(0)
+                ),
+                T.tvm_stack_make_array(
+                    output_buf.data,
+                    T.tvm_stack_make_shape(seq_len * T.int64(8)),
+                    0,
+                    1,
+                    0,
+                    T.int64(0),
+                ),
+                T.bool(False),
+            )
+
+    # This function should be skipped.
+    mod = tvm.IRModule({"main": before})
+    with Target("metal"):
+        mod = dl.ApplyDefaultSchedule(dl.gpu.GEMV())(mod)
+    tvm.ir.assert_structural_equal(mod["main"], before)
+
+
 if __name__ == "__main__":
     tvm.testing.main()