From ee2183cff1ea764979da2e259001f9ba5c77a8cf Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Thu, 25 Jul 2024 00:28:25 -0400 Subject: [PATCH] [Pass] Rewrite FuseAddRMSNorm to avoid binding rewrite recursion This PR revamps the FuseAddRMSNorm pass with manual pattern matching, in purpose of avoiding `rewrite_bindings` which is recursive and may cause unaffordable time when the model is large. --- python/mlc_llm/compiler_pass/fuse_add_norm.py | 143 ++++++++++-------- 1 file changed, 78 insertions(+), 65 deletions(-) diff --git a/python/mlc_llm/compiler_pass/fuse_add_norm.py b/python/mlc_llm/compiler_pass/fuse_add_norm.py index 04adefc90d..60165ad8aa 100644 --- a/python/mlc_llm/compiler_pass/fuse_add_norm.py +++ b/python/mlc_llm/compiler_pass/fuse_add_norm.py @@ -1,16 +1,17 @@ """A compiler pass that fuses add + rms_norm.""" +# pylint: disable=invalid-name + +from typing import Optional + import tvm from tvm import relax -from tvm.relax.dpl import PatternContext, rewrite_bindings -from tvm.relax.dpl.pattern import is_op, wildcard +from tvm.relax.analysis import remove_all_unused +from tvm.relax.expr_functor import PyExprMutator, mutator from tvm.script import tir as T from ..support.max_thread_check import get_max_num_threads_per_block -# mypy: disable-error-code="attr-defined,valid-type" -# pylint: disable=too-many-locals,invalid-name - def _get_add_rms_norm_decode(hidden_size: int, eps: float, TX: int): inv_hidden_size = T.float32(1.0 / float(hidden_size)) @@ -18,7 +19,9 @@ def _get_add_rms_norm_decode(hidden_size: int, eps: float, TX: int): add_local_size = hidden_size // TX @T.prim_func(private=True) - def decode_add_rms(pA: T.handle, pB: T.handle, pC: T.handle, pO: T.handle, pAdd: T.handle): + def decode_add_rms( # pylint: disable=too-many-locals + pA: T.handle, pB: T.handle, pC: T.handle, pO: T.handle, pAdd: T.handle + ): T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) batch_size = T.int32() A = T.match_buffer(pA, (batch_size, 1, hidden_size), "float16") @@ -81,7 +84,9 @@ def _get_add_rms_norm_prefill(hidden_size: int, eps: float, TX: int): add_local_size = hidden_size // TX @T.prim_func(private=True) - def prefill_add_rms(pA: T.handle, pB: T.handle, pC: T.handle, pO: T.handle, pAdd: T.handle): + def prefill_add_rms( # pylint: disable=too-many-locals + pA: T.handle, pB: T.handle, pC: T.handle, pO: T.handle, pAdd: T.handle + ): T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) seq_len = T.int32() A = T.match_buffer(pA, (1, seq_len, hidden_size), "float16") @@ -147,68 +152,76 @@ def __init__(self, target: tvm.target.Target) -> None: target : tvm.target.Target Target device. """ - self.TX = 1024 # default - - max_num_threads_per_block = get_max_num_threads_per_block(target) - if max_num_threads_per_block < self.TX: - self.TX = max_num_threads_per_block + self.target = target def transform_module(self, mod: tvm.IRModule, _ctx: tvm.transform.PassContext) -> tvm.IRModule: """IRModule-level transformation.""" - with PatternContext() as ctx: - pat_x1 = wildcard() - pat_x2 = wildcard() - pat_y = is_op("relax.add")(pat_x1, pat_x2) - pat_w = wildcard() - pat_o = is_op("relax.nn.rms_norm")(pat_y, pat_w) - - def rewriter(matchings, bindings): - x1 = matchings[pat_x1] - x2 = matchings[pat_x2] - weight = matchings[pat_w] - y = matchings[pat_y] - o = matchings[pat_o] - eps = bindings[o].attrs.epsilon - if x1.struct_info.dtype != "float16": - return {} - n, _, h = x1.struct_info.shape - func_name = "fuse_add_norm_prefill" if n == 1 else "fuse_add_norm_decode" - - if all(gv.name_hint != func_name for gv in mod.functions): - h = int(h) - if h % self.TX != 0: - return {} - if n == 1: - func = _get_add_rms_norm_prefill(h, eps, self.TX) - else: - func = _get_add_rms_norm_decode(h, eps, self.TX) - mod[func_name] = func - gvar = mod.get_global_var(func_name) - relax.expr._update_struct_info( # pylint: disable=protected-access - gvar, - relax.FuncStructInfo.opaque_func(ret=relax.ObjectStructInfo()), + return _FuseAddRMSNormRewriter(mod.clone(), self.target).transform() + + +@mutator +class _FuseAddRMSNormRewriter(PyExprMutator): # pylint: disable=abstract-method + def __init__(self, mod: tvm.IRModule, target: tvm.target.Target): + super().__init__(mod) + self.mod = mod + self.prefill_norm_gv: Optional[tvm.ir.GlobalVar] = None + self.decode_norm_gv: Optional[tvm.ir.GlobalVar] = None + self.TX = min(1024, get_max_num_threads_per_block(target)) + + def transform(self) -> tvm.IRModule: # pylint: disable=too-many-locals + """Entry point of the transformation""" + for g_var, func in self.mod.functions_items(): + if not isinstance(func, relax.Function): + continue + new_func = self.visit_expr(func) + new_func = remove_all_unused(new_func) + self.builder_.update_func(g_var, new_func) + return self.builder_.finalize() + + def visit_call_(self, call: relax.Call) -> relax.Expr: # pylint: disable=arguments-renamed + call = super().visit_call_(call) + + # Match the "rms_norm(add(x1, x2), w)" pattern + if call.op != tvm.ir.Op.get("relax.nn.rms_norm") or call.struct_info.dtype != "float16": + return call + assert len(call.args) == 2 + weight = call.args[1] + eps = call.attrs.epsilon + assert isinstance(call.args[0], relax.Var) + y = self.lookup_binding(call.args[0]) + if not isinstance(y, relax.Call) or y.op != tvm.ir.Op.get("relax.add"): + return call + assert len(y.args) == 2 + x1 = y.args[0] + x2 = y.args[1] + # Extra check + n, _, h = x1.struct_info.shape + h = int(h) + if h % self.TX != 0: + return call + + is_prefill = n == 1 + func_gv = self.prefill_norm_gv if is_prefill else self.decode_norm_gv + if func_gv is None: + if is_prefill: + func_gv = self.builder_.add_func( + _get_add_rms_norm_prefill(h, eps, self.TX), "fuse_add_norm_prefill" ) + self.prefill_norm_gv = func_gv else: - gvar = mod.get_global_var(func_name) - o_y_tuple = relax.call_tir( - gvar, + func_gv = self.builder_.add_func( + _get_add_rms_norm_decode(h, eps, self.TX), "fuse_add_norm_decode" + ) + self.decode_norm_gv = func_gv + + tuple_output = self.builder_.emit( + relax.call_tir( + func_gv, [x1, x2, weight], - out_sinfo=[x1.struct_info, x1.struct_info], + out_sinfo=[x1.struct_info, x2.struct_info], ) - return { - o: relax.TupleGetItem(o_y_tuple, 0), - y: relax.TupleGetItem(o_y_tuple, 1), - } - - new_mod = {} - for gvar, func in mod.functions.items(): - if isinstance(func, relax.Function): - func = rewrite_bindings(ctx, rewriter, func) - new_mod[gvar] = func - - for gvar, func in mod.functions.items(): - if isinstance(func, tvm.tir.PrimFunc) and gvar not in new_mod: - new_mod[gvar] = func - - new_mod = tvm.IRModule(new_mod, mod.type_definitions, mod.attrs, mod.global_infos) - return new_mod + ) + new_o = relax.TupleGetItem(tuple_output, 0) + new_y = self.builder_.emit(relax.TupleGetItem(tuple_output, 1)) + self.set_var_remap(call.args[0].vid, new_y) + return new_o