diff --git a/python/tvm/dlight/gpu/reduction.py b/python/tvm/dlight/gpu/reduction.py
index f07ee45f3729..651e09dc5232 100644
--- a/python/tvm/dlight/gpu/reduction.py
+++ b/python/tvm/dlight/gpu/reduction.py
@@ -281,7 +281,7 @@ def _sch_inner_spatial(
         # Schedule epilogue
         if epilogue_info is not None:
             epilogue = epilogue_info.block_rv
-            sch.reverse_compute_at(epilogue, bx)
+            sch.reverse_compute_at(epilogue, bx, preserve_unit_loops=True)
             if is_broadcast_epilogue(sch, block, epilogue):
                 sch.set_scope(block, 0, "shared")
                 _, *s = sch.get_loops(epilogue)  # pylint: disable=invalid-name
diff --git a/tests/python/dlight/test_gpu_reduction.py b/tests/python/dlight/test_gpu_reduction.py
index 75d2eeeb0716..def124a9b29a 100644
--- a/tests/python/dlight/test_gpu_reduction.py
+++ b/tests/python/dlight/test_gpu_reduction.py
@@ -1006,5 +1006,88 @@ def fused_relax_repeat_relax_permute_dims_relax_matmul1(p_lv716: T.handle, p_ast
     assert_structural_equal(mod, Expected)
 
 
+def test_gemv_dyn_shape_epilogue():
+    @I.ir_module
+    class Module:
+        @T.prim_func(private=True)
+        def main(
+            var_A: T.handle,
+            B: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"),
+            var_C: T.handle,
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            vocab_size = T.int64()
+            A = T.match_buffer(var_A, (T.int64(4096), vocab_size), "float16")
+            C = T.match_buffer(var_C, (T.int64(1), T.int64(1), vocab_size))
+            C_temp = T.alloc_buffer((T.int64(1), T.int64(1), vocab_size), "float16")
+            for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), vocab_size, T.int64(4096)):
+                with T.block("matmul"):
+                    v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+                    T.reads(B[v_i0, v_i1, v_k], A[v_k, v_i2])
+                    T.writes(C_temp[v_i0, v_i1, v_i2])
+                    with T.init():
+                        C_temp[v_i0, v_i1, v_i2] = T.float16(0)
+                    C_temp[v_i0, v_i1, v_i2] = (
+                        C_temp[v_i0, v_i1, v_i2] + B[v_i0, v_i1, v_k] * A[v_k, v_i2]
+                    )
+            for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), vocab_size):
+                with T.block("epilogue"):
+                    v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+                    T.reads(C_temp[v_i0, v_i1, v_i2])
+                    T.writes(C[v_i0, v_i1, v_i2])
+                    C[v_i0, v_i1, v_i2] = T.Cast("float32", C_temp[v_i0, v_i1, v_i2])
+
+    # fmt: off
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def main(var_A: T.handle, B: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_C: T.handle):
+            T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+            vocab_size = T.int64()
+            A = T.match_buffer(var_A, (T.int64(4096), vocab_size), "float16")
+            C = T.match_buffer(var_C, (T.int64(1), T.int64(1), vocab_size))
+            # with T.block("root"):
+            C_temp_local = T.alloc_buffer((T.int64(1), T.int64(1), vocab_size), "float16", scope="local")
+            C_temp_rf_local = T.alloc_buffer((T.int64(16), T.int64(1), T.int64(1), vocab_size), "float16", scope="local")
+            for ax0_fused_0 in T.thread_binding(vocab_size, thread="blockIdx.x"):
+                for ax0_fused_1 in T.thread_binding(T.int64(1), thread="threadIdx.x"):
+                    for ax1_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                        with T.block("matmul_rf_init"):
+                            vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
+                            v0 = T.axis.spatial(vocab_size, ax0_fused_0 + ax0_fused_1)
+                            T.reads()
+                            T.writes(C_temp_rf_local[vax1_fused_1, T.int64(0), T.int64(0), v0])
+                            C_temp_rf_local[vax1_fused_1, T.int64(0), T.int64(0), v0] = T.float16(0)
+                        for ax1_fused_0, u in T.grid(T.int64(256), 1):
+                            with T.block("matmul_rf_update"):
+                                vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
+                                v0 = T.axis.spatial(vocab_size, ax0_fused_0 + ax0_fused_1)
+                                vax1_fused_0 = T.axis.reduce(T.int64(256), ax1_fused_0)
+                                T.reads(C_temp_rf_local[vax1_fused_1, T.int64(0), T.int64(0), v0], B[T.int64(0), T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1], A[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0])
+                                T.writes(C_temp_rf_local[vax1_fused_1, T.int64(0), T.int64(0), v0])
+                                C_temp_rf_local[vax1_fused_1, T.int64(0), T.int64(0), v0] = C_temp_rf_local[vax1_fused_1, T.int64(0), T.int64(0), v0] + B[T.int64(0), T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1] * A[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0]
+                for ax1_fused in T.thread_binding(T.int64(1), thread="threadIdx.x"):
+                    for ax0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                        with T.block("matmul"):
+                            vax1_fused_1, v0 = T.axis.remap("RS", [ax0, ax0_fused_0])
+                            T.reads(C_temp_rf_local[vax1_fused_1, T.int64(0), T.int64(0), v0])
+                            T.writes(C_temp_local[T.int64(0), T.int64(0), v0])
+                            with T.init():
+                                C_temp_local[T.int64(0), T.int64(0), v0] = T.float16(0)
+                            C_temp_local[T.int64(0), T.int64(0), v0] = C_temp_local[T.int64(0), T.int64(0), v0] + C_temp_rf_local[vax1_fused_1, T.int64(0), T.int64(0), v0]
+                for ax0_fused_0_1 in T.thread_binding(T.int64(1), thread="threadIdx.x"):
+                    for ax0_fused_1 in range(T.int64(1)):
+                        with T.block("epilogue"):
+                            v0 = T.axis.spatial(vocab_size, ax0_fused_0)
+                            T.reads(C_temp_local[T.int64(0), T.int64(0), v0])
+                            T.writes(C[T.int64(0), T.int64(0), v0])
+                            C[T.int64(0), T.int64(0), v0] = T.Cast("float32", C_temp_local[T.int64(0), T.int64(0), v0])
+    # fmt: on
+
+    with Target("nvidia/geforce-rtx-3090-ti"):
+        mod = dl.ApplyDefaultSchedule(dl.gpu.Reduction())(Module)  # pylint: disable=not-callable
+    assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
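
A minimal sketch (not part of the patch) of what the one-line fix changes: `tir.Schedule.reverse_compute_at` regenerates the moved block's loop nest under the target loop, and passing `preserve_unit_loops=True` keeps the unit-extent loops that the reduction rule later binds to `threadIdx` (visible as the `for ax0_fused_1 in range(T.int64(1))` loop under the `epilogue` block in `Expected`). The toy `before` function below is illustrative and not taken from the test suite.

import tvm
from tvm.script import tir as T


@T.prim_func
def before(A: T.Buffer((128,), "float16"), C: T.Buffer((128,), "float32")):
    B = T.alloc_buffer((128,), "float16")
    for i in range(128):
        with T.block("compute"):
            vi = T.axis.spatial(128, i)
            B[vi] = A[vi] * T.float16(2)
    for i in range(128):
        with T.block("epilogue"):
            vi = T.axis.spatial(128, i)
            C[vi] = T.Cast("float32", B[vi])


sch = tvm.tir.Schedule(before)
(i,) = sch.get_loops(sch.get_block("compute"))
# With preserve_unit_loops=True the regenerated epilogue keeps a
# `for ax0 in range(1)` loop under `i`; with the default (False) the unit
# loop is elided, leaving no loop for later thread bindings to attach to,
# which is what broke the dynamic-shape GEMV epilogue schedule.
sch.reverse_compute_at(sch.get_block("epilogue"), i, preserve_unit_loops=True)
print(sch.mod)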