diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py index fd866ae06c16..9b16fc2fbfee 100644 --- a/python/tvm/relax/frontend/nn/llm/kv_cache.py +++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py @@ -925,12 +925,8 @@ def _attention_decode( THREAD_LIMIT = 512 TILE_SIZE_PER_BDX = 2 - if target.kind.name == "opencl" and ( - ("android" in str(target.host)) or ("adreno" in str(target.attrs)) - ): - # Keeping lower thread limit for this kernel on adreno target - # to avoid register spill - THREAD_LIMIT = 256 + if target.kind.name == "opencl" and "android" in str(target.host): + THREAD_LIMIT = 256 if H_kv < 8 else 512 TILE_SIZE_PER_BDX = 1 max_num_threads_per_block = get_max_num_threads_per_block(target) thread_limit = min(max_num_threads_per_block, THREAD_LIMIT) @@ -1574,11 +1570,7 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], bdx = 32 num_warps = 4 - tile_x, tile_y, tile_z = ( - 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), - d, - 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), - ) + tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16 # Otherwise we would exceed maxComputeWorkgroupStorageSize if ( @@ -1588,12 +1580,6 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], tile_z = 8 num_warps = 2 - if target.kind.name == "opencl" and ( - ("android" in str(target.host)) or ("adreno" in str(target.attrs)) - ): - LOAD_VEC = 16 // ((DataType(dtype).bits + 7) // 8) # 16 bytes - NUM_BLKS = group_size * 8 - # fmt: off @T.prim_func def batch_prefill_ragged_kv( # pylint: disable=too-many-branches @@ -1722,6 +1708,8 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches for lz, ly in T.grid(tile_z, tile_y): with T.block("K_load"): i, j = T.axis.remap("SS", [lz, ly]) + T.reads() + T.writes() cur_L = L_kv_start + i if cur_L < kv_chunk_len[0]: K_smem[i, j] = T.if_then_else( @@ -1836,14 +1824,6 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches # fmt: on # pylint: enable=line-too-long,too-many-branches sch = tir.Schedule(batch_prefill_ragged_kv) - get_extent = lambda *lps: [int(sch.get(lp).extent) for lp in lps] - - def get_vecsize(extent): - return min(LOAD_VEC, (extent & ~(extent - 1))) - - def getxy_vecsize(x, y, t): - assert (x * y) % t == 0 - return min(get_vecsize(y), get_vecsize(x * y // t)) def get_tile_size(x, y, t): cnt = (x * y) // t @@ -1857,37 +1837,26 @@ def get_tile_size(x, y, t): def apply_to_qkv_load(sch: tir.Schedule, block): loop_x, loop_y = sch.get_loops(block)[-2:] - x_extent, y_extent = get_extent(loop_x, loop_y) - vec_size = getxy_vecsize(x_extent, y_extent, bdx * num_warps) - yo, yv = sch.split(loop_y, [None, vec_size]) - yo_extent = y_extent // vec_size - tile_x, tile_y = get_tile_size(x_extent, yo_extent, (bdx * num_warps)) - xo, xi = sch.split(loop_x, [tile_x, None]) - yo, yi = sch.split(yo, [tile_y, None]) - sch.reorder(xi, yi, xo, yo) - t = sch.fuse(xi, yi) - ty, tx = sch.split(t, [num_warps, bdx]) + loop = sch.fuse(loop_x, loop_y) + _, ty, tx, vec = sch.split( + loop, factors=[None, num_warps, bdx, LOAD_VEC], preserve_unit_iters=True + ) sch.bind(ty, "threadIdx.y") sch.bind(tx, "threadIdx.x") - sch.vectorize(yv) + sch.vectorize(vec) def apply_to_so_ewise(sch: tir.Schedule, block, tile): loop_x, loop_y = sch.get_loops(block)[-2:] xo, xi = sch.split(loop_x, factors=[None, tile[0]]) yo, yi = sch.split(loop_y, factors=[None, tile[1]]) sch.reorder(xo, yo, xi, yi) - sch.unroll(xi) - yiv_extent = get_vecsize(tile[1]) - yio, yiv = sch.split(yi, [None, yiv_extent]) - sch.unroll(yio) - sch.vectorize(yiv) t = sch.fuse(xo, yo) ty, tx = sch.split(t, factors=[None, bdx]) sch.bind(ty, "threadIdx.y") sch.bind(tx, "threadIdx.x") def apply_to_gemm( # pylint: disable=unused-argument - sch: tir.Schedule, block, tile, read_0, read_1, r_len=16, k_major=False + sch: tir.Schedule, block, tile, read_0, read_1, r_len=8, k_major=False ): loop_x, loop_y, loop_z = sch.get_loops(block)[-3:] xo, xi = sch.split(loop_x, factors=[None, tile[0]]) @@ -1903,12 +1872,6 @@ def apply_to_gemm( # pylint: disable=unused-argument sch.reorder(ko, xi, yi, ki) else: sch.reorder(ko, ki, xi, yi) - yiv_extent = get_vecsize(tile[1]) - yio, yiv = sch.split(yi, [None, yiv_extent]) - sch.unroll(yio) - sch.vectorize(yiv) - sch.unroll(xi) - sch.unroll(ki) sch.decompose_reduction(block, ty) def apply_to_md(sch, block): @@ -1917,7 +1880,6 @@ def apply_to_md(sch, block): sch.bind(ty, "threadIdx.y") sch.bind(tx, "threadIdx.x") - sch.transform_layout("K_load", ("write", 0), lambda i, j: (j, i)) tile_s = get_tile_size(tile_x, tile_z, bdx * num_warps) tile_o = get_tile_size(tile_x, tile_y, bdx * num_warps) apply_to_gemm(sch, sch.get_block("S_gemm"), tile_s, 0, 1, k_major=True)