From 3a1c09745a12b6d8d59c078e359e2c02724272ef Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Fri, 9 Aug 2024 10:25:42 +0530 Subject: [PATCH] [DLIGHT][ADRENO] Fix for opencl adreno matmul schedule Fixed the matmul schedule for the case of epilog blocks --- python/tvm/dlight/gpu/matmul.py | 50 +++++++++++---- tests/python/dlight/test_gpu_matmul.py | 89 ++++++++++++++------------ 2 files changed, 85 insertions(+), 54 deletions(-) diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py index 25cc649b44dd..5fb8e2469d54 100644 --- a/python/tvm/dlight/gpu/matmul.py +++ b/python/tvm/dlight/gpu/matmul.py @@ -941,7 +941,7 @@ def get_configs(self, target: Target) -> Config: inner_x=False, ) elif target.kind.name == "opencl" and ( - ("android" in str(target.host)) or ("windows" in str(target.host)) + ("android" in str(target.host)) or ("adreno" in str(target.attrs)) ): return Matmul.Config( block_size_x=32, @@ -991,7 +991,10 @@ def is_inner_reduction(block_stmt, iter_infos): end_it = block_stmt.reads[-1].region[-1].min return {it.var: it.kind for it in iter_infos}.get(end_it, "O") == "R" - if target.kind.name == "opencl" and not is_inner_reduction(block_stmt, iter_infos): + if ( + target.kind.name == "opencl" + and (("android" in str(target.host)) or ("adreno" in str(target.attrs))) + ) and not is_inner_reduction(block_stmt, iter_infos): ret = self.sch_outer_reduction(sch, config, main_block, blocks) if ret is not None: return ret @@ -1122,6 +1125,16 @@ def sch_outer_reduction( reduction_block: tir.schedule.BlockRV, blocks: List[tir.schedule.BlockRV], ) -> Optional[tir.Schedule]: + + """Get vectorization factor""" + + def get_max_factor(n, factors): + factors = sorted(factors, reverse=True) + for factor in factors: + if n % factor == 0: + return factor + return 1 + reduction_loops = sch.get_loops(reduction_block) if not len(reduction_loops) == 4: return None @@ -1140,13 +1153,17 @@ def sch_outer_reduction( config.vector_size, config.unroll, ) - - is_dequant_block = len(blocks) > 1 - if is_dequant_block: - compute_block, dequant_block, matmul_block = blocks - sch.compute_inline(compute_block) - else: - (matmul_block,) = blocks + VecSize = min(get_max_factor(sch.get(n).extent // Threads_X, [1, 2, 4, 8]), VecSize) + dequant_block = None + matmul_block = reduction_block + epilogue_block = None + if blocks[-1] is not matmul_block: + epilogue_block = blocks[-1] + for blk in blocks[:-1]: + if "dequantize" in sch.get(blk).name_hint: + dequant_block = blk + elif blk is not matmul_block: + sch.compute_inline(blk) m = sch.fuse(mb, ms) @@ -1162,12 +1179,13 @@ def sch_outer_reduction( sch.reorder(no, mo, ni, mi, k0, k1, k2, k3, mu, nv) sch.compute_at(rmat_block, k0) - if is_dequant_block: + if dequant_block is not None: sch.compute_at(dequant_block, k3) sch.reverse_compute_at(wmat_block, mi) sch.set_scope(rmat_block, 0, "shared") sch.set_scope(matmul_block, 0, "local") - if is_dequant_block: + + if dequant_block is not None: sch.set_scope(dequant_block, 0, "local") sch.bind(mo, "blockIdx.y") @@ -1175,7 +1193,7 @@ def sch_outer_reduction( sch.bind(mi, "threadIdx.y") sch.bind(ni, "threadIdx.x") sch.vectorize(sch.get_loops(matmul_block)[-1]) - if is_dequant_block: + if dequant_block is not None: sch.vectorize(sch.get_loops(dequant_block)[-1]) # Co-operative Memory Fetch @@ -1187,7 +1205,7 @@ def sch_outer_reduction( sch.vectorize(wv) # Scale and Quant Cache - if is_dequant_block: + if dequant_block is not None: qb = sch.cache_read(dequant_block, 0, "local") sb = sch.cache_read(dequant_block, 
1, "local") sch.compute_at(sb, k1) @@ -1197,5 +1215,11 @@ def sch_outer_reduction( sch.vectorize(sch.get_loops(qb)[-1]) sch.vectorize(sch.get_loops(sb)[-1]) + if epilogue_block is not None: + sch.reverse_compute_at(epilogue_block, mi, preserve_unit_loops=True) + sch.set_scope(wmat_block, 0, "local") + sch.compute_inline(wmat_block) + sch.vectorize(sch.get_loops(epilogue_block)[-1]) + sch.decompose_reduction(matmul_block, k0) return sch diff --git a/tests/python/dlight/test_gpu_matmul.py b/tests/python/dlight/test_gpu_matmul.py index 4cef7f1c27c3..dc5276e62a5f 100644 --- a/tests/python/dlight/test_gpu_matmul.py +++ b/tests/python/dlight/test_gpu_matmul.py @@ -685,47 +685,54 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), class TestFusedDequantMatmulAndroid(AndroidBeforeAfter): # fmt: off @T.prim_func - def before(lv840: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv841: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm260: T.handle, p_output0: T.handle): + def before(lv452: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv453: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm130: T.handle, transformer_h_0_attn_c_attn_bias3: T.Buffer((T.int64(12288),), "float16"), p_output0: T.handle): T.func_attr({"tir.noalias": T.bool(True)}) seq_len = T.int64() - rms_norm260 = T.match_buffer(p_rms_norm260, (T.int64(1), seq_len, T.int64(4096)), "float16") - matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(12288)), "float16") + rms_norm130 = T.match_buffer(p_rms_norm130, (T.int64(1), seq_len, T.int64(4096)), "float16") + T_add_intermediate_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(12288)), "float16") # with T.block("root"): compute = T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16") dequantize_intermediate_intermediate = T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16") + matmul_intermediate = T.alloc_buffer((T.int64(1), seq_len, T.int64(12288)), "float16") for i0, i1 in T.grid(T.int64(4096), T.int64(12288)): with T.block("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) - T.reads(lv840[v_i0 // T.int64(8), v_i1]) + T.reads(lv452[v_i0 // T.int64(8), v_i1]) T.writes(compute[v_i0, v_i1]) - compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(lv840[v_i0 // T.int64(8), v_i1], T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15))) + compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(lv452[v_i0 // T.int64(8), v_i1], T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15))) for i0, i1 in T.grid(T.int64(4096), T.int64(12288)): with T.block("dequantize"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) - T.reads(compute[v_i0, v_i1], lv841[v_i0 // T.int64(32), v_i1]) + T.reads(compute[v_i0, v_i1], lv453[v_i0 // T.int64(32), v_i1]) T.writes(dequantize_intermediate_intermediate[v_i0, v_i1]) - dequantize_intermediate_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7)) * lv841[v_i0 // T.int64(32), v_i1] + dequantize_intermediate_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7)) * lv453[v_i0 // T.int64(32), v_i1] for i0, i1, i2, k in T.grid(T.int64(1), seq_len, T.int64(12288), T.int64(4096)): with T.block("matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(rms_norm260[v_i0, v_i1, v_k], dequantize_intermediate_intermediate[v_k, v_i2]) + T.reads(rms_norm130[v_i0, v_i1, v_k], dequantize_intermediate_intermediate[v_k, v_i2]) T.writes(matmul_intermediate[v_i0, v_i1, v_i2]) 
with T.init(): matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - matmul_intermediate[v_i0, v_i1, v_i2] = matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm260[v_i0, v_i1, v_k] * dequantize_intermediate_intermediate[v_k, v_i2] + matmul_intermediate[v_i0, v_i1, v_i2] = matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm130[v_i0, v_i1, v_k] * dequantize_intermediate_intermediate[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(12288)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(matmul_intermediate[v_ax0, v_ax1, v_ax2], transformer_h_0_attn_c_attn_bias3[v_ax2]) + T.writes(T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2]) + T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2] = matmul_intermediate[v_ax0, v_ax1, v_ax2] + transformer_h_0_attn_c_attn_bias3[v_ax2] @T.prim_func - def expected(lv840: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv841: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm260: T.handle, p_output0: T.handle): + def expected(lv452: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv453: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm130: T.handle, transformer_h_0_attn_c_attn_bias3: T.Buffer((T.int64(12288),), "float16"), p_output0: T.handle): T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) seq_len = T.int64() - rms_norm260 = T.match_buffer(p_rms_norm260, (T.int64(1), seq_len, T.int64(4096)), "float16") - matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(12288)), "float16") + rms_norm130 = T.match_buffer(p_rms_norm130, (T.int64(1), seq_len, T.int64(4096)), "float16") + T_add_intermediate_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(12288)), "float16") # with T.block("root"): dequantize_intermediate_intermediate_local = T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16", scope="local") - rms_norm260_pad_shared = T.alloc_buffer((T.int64(1), (seq_len + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), "float16", scope="shared") + rms_norm130_pad_shared = T.alloc_buffer((T.int64(1), (seq_len + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), "float16", scope="shared") matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), (seq_len + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(12288)), "float16", scope="local") - lv840_local = T.alloc_buffer((T.int64(512), T.int64(12288)), "uint32", scope="local") - lv841_local = T.alloc_buffer((T.int64(128), T.int64(12288)), "float16", scope="local") + lv452_local = T.alloc_buffer((T.int64(512), T.int64(12288)), "uint32", scope="local") + lv453_local = T.alloc_buffer((T.int64(128), T.int64(12288)), "float16", scope="local") for i2_0 in T.thread_binding(T.int64(48), thread="blockIdx.x"): for i0_i1_fused_0 in T.thread_binding((seq_len + T.int64(31)) // T.int64(32), thread="blockIdx.y"): for i2_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): @@ -743,37 +750,37 @@ def expected(lv840: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv841: T for ax0 in range(T.int64(4)): for ax1_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): for ax1_1 in T.vectorized(T.int64(8)): - with T.block("rms_norm260_pad"): + with T.block("rms_norm130_pad"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) v1 = T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0) v2 = T.axis.spatial(T.int64(4096), k_0 * T.int64(256) + ax1_0 * 
T.int64(8) + ax1_1) - T.reads(rms_norm260[v0, v1, v2]) - T.writes(rms_norm260_pad_shared[v0, v1, v2]) - rms_norm260_pad_shared[v0, v1, v2] = T.if_then_else(v1 < seq_len, rms_norm260[v0, v1, v2], T.float16(0)) + T.reads(rms_norm130[v0, v1, v2]) + T.writes(rms_norm130_pad_shared[v0, v1, v2]) + rms_norm130_pad_shared[v0, v1, v2] = T.if_then_else(v1 < seq_len, rms_norm130[v0, v1, v2], T.float16(0)) for k_1 in range(T.int64(8)): for ax0 in T.vectorized(T.int64(8)): - with T.block("lv841_local"): + with T.block("lv453_local"): v0 = T.axis.spatial(T.int64(128), k_0 * T.int64(8) + k_1) v1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0) - T.reads(lv841[v0, v1]) - T.writes(lv841_local[v0, v1]) - lv841_local[v0, v1] = lv841[v0, v1] + T.reads(lv453[v0, v1]) + T.writes(lv453_local[v0, v1]) + lv453_local[v0, v1] = lv453[v0, v1] for k_2 in range(T.int64(4)): for ax0 in T.vectorized(T.int64(8)): - with T.block("lv840_local"): + with T.block("lv452_local"): v0 = T.axis.spatial(T.int64(512), k_0 * T.int64(32) + k_1 * T.int64(4) + k_2) v1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0) - T.reads(lv840[v0, v1]) - T.writes(lv840_local[v0, v1]) - lv840_local[v0, v1] = lv840[v0, v1] + T.reads(lv452[v0, v1]) + T.writes(lv452_local[v0, v1]) + lv452_local[v0, v1] = lv452[v0, v1] for k_3 in range(T.int64(8)): for ax0 in T.vectorized(T.int64(8)): with T.block("dequantize"): v_i0 = T.axis.spatial(T.int64(4096), k_0 * T.int64(256) + k_1 * T.int64(32) + k_2 * T.int64(8) + k_3) v_i1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0) - T.reads(lv840_local[v_i0 // T.int64(8), v_i1], lv841_local[v_i0 // T.int64(32), v_i1]) + T.reads(lv452_local[v_i0 // T.int64(8), v_i1], lv453_local[v_i0 // T.int64(32), v_i1]) T.writes(dequantize_intermediate_intermediate_local[v_i0, v_i1]) - dequantize_intermediate_intermediate_local[v_i0, v_i1] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv840_local[v_i0 // T.int64(8), v_i1], T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7)) * lv841_local[v_i0 // T.int64(32), v_i1] + dequantize_intermediate_intermediate_local[v_i0, v_i1] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv452_local[v_i0 // T.int64(8), v_i1], T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7)) * lv453_local[v_i0 // T.int64(32), v_i1] for i0_i1_fused_2 in range(T.int64(4)): for i2_2 in T.vectorized(T.int64(8)): with T.block("matmul_update"): @@ -781,19 +788,19 @@ def expected(lv840: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv841: T v_i1 = T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + i0_i1_fused_2) v_i2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + i2_2) v_k = T.axis.reduce(T.int64(4096), k_0 * T.int64(256) + k_1 * T.int64(32) + k_2 * T.int64(8) + k_3) - T.reads(matmul_intermediate_pad_local[v_i0, v_i1, v_i2], rms_norm260_pad_shared[v_i0, v_i1, v_k], dequantize_intermediate_intermediate_local[v_k, v_i2]) + T.reads(matmul_intermediate_pad_local[v_i0, v_i1, v_i2], rms_norm130_pad_shared[v_i0, v_i1, v_k], dequantize_intermediate_intermediate_local[v_k, v_i2]) T.writes(matmul_intermediate_pad_local[v_i0, v_i1, v_i2]) - matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + rms_norm260_pad_shared[v_i0, v_i1, v_k] * dequantize_intermediate_intermediate_local[v_k, v_i2] - for ax0 in range(T.int64(4)): - for ax1 
in T.vectorized(T.int64(8)): - with T.block("matmul_intermediate_pad"): - v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial(seq_len, i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0) - v2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax1) - T.where((i0_i1_fused_0 - (seq_len + T.int64(31)) // T.int64(32) < T.int64(0) or i0_i1_fused_0 == T.int64(0)) and i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0 < seq_len) - T.reads(matmul_intermediate_pad_local[v0, v1, v2]) - T.writes(matmul_intermediate[v0, v1, v2]) - matmul_intermediate[v0, v1, v2] = matmul_intermediate_pad_local[v0, v1, v2] + matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + rms_norm130_pad_shared[v_i0, v_i1, v_k] * dequantize_intermediate_intermediate_local[v_k, v_i2] + for ax0, ax1 in T.grid(T.int64(1), T.int64(4)): + for ax2 in T.vectorized(T.int64(8)): + with T.block("T_add"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(seq_len, i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax1) + v_ax2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax2) + T.where(i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax1 < seq_len) + T.reads(matmul_intermediate_pad_local[v_ax0, v_ax1, v_ax2], transformer_h_0_attn_c_attn_bias3[v_ax2]) + T.writes(T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2]) + T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2] = matmul_intermediate_pad_local[v_ax0, v_ax1, v_ax2] + transformer_h_0_attn_c_attn_bias3[v_ax2] # fmt: on
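
Note for reviewers: the retargeted guard in get_configs() (and the matching check before sch_outer_reduction is entered) now keys off "android" in the host string or "adreno" in the target attributes instead of a "windows" host. A minimal sketch of a target that satisfies the new condition, assuming the usual Adreno target string (illustrative only, not copied from the test fixtures):

import tvm

# Hypothetical Adreno/Android target: "-device=adreno" surfaces in target.attrs,
# and the aarch64-android triple surfaces in target.host.
target = tvm.target.Target(
    "opencl -device=adreno",
    host="llvm -mtriple=aarch64-linux-android",
)
assert ("android" in str(target.host)) or ("adreno" in str(target.attrs))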
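The new get_max_factor helper clamps the configured vector width to the largest candidate that evenly divides the per-thread extent of the vectorized axis, so shapes that are not a multiple of the configured width no longer produce an invalid vectorize. A standalone sketch with made-up extents (plain Python, no TVM needed):

def get_max_factor(n, factors):
    # Largest factor in `factors` that evenly divides `n`; 1 if none does.
    for factor in sorted(factors, reverse=True):
        if n % factor == 0:
            return factor
    return 1

# Hypothetical numbers: 384 columns split across 32 threads leaves 12 per
# thread, so the vector width is clamped from the configured 8 down to 4.
extent, threads_x, config_vector_size = 384, 32, 8
vec_size = min(get_max_factor(extent // threads_x, [1, 2, 4, 8]), config_vector_size)
assert vec_size == 4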
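The block handling in sch_outer_reduction also no longer assumes exactly one or three blocks. It scans the block list, keeps a "dequantize" block for local caching, inlines any other producer, and treats a trailing non-matmul block (the bias-add "T_add" block in the new test) as an epilogue that is reverse_compute_at'ed and vectorized. The selection logic, restated over plain block names for clarity (a sketch, not the schedule code itself):

def classify_blocks(block_names, matmul_name="matmul"):
    dequant, epilogue, inlined = None, None, []
    if block_names[-1] != matmul_name:
        epilogue = block_names[-1]        # e.g. the bias-add "T_add" block
    for name in block_names[:-1]:
        if "dequantize" in name:
            dequant = name                # later cached into "local" scope
        elif name != matmul_name:
            inlined.append(name)          # e.g. the bit-unpacking "compute" block
    return dequant, epilogue, inlined

# Matches the fused dequant + matmul + bias pattern in TestFusedDequantMatmulAndroid:
assert classify_blocks(["compute", "dequantize", "matmul", "T_add"]) == (
    "dequantize", "T_add", ["compute"]
)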