
[Bug] Duplicated TIR buffer names when applying FuseTIR with duplicated blocks. #16433

@LeiWang1999

Description


The code to reproduce:

import tvm
import tvm.testing
from tvm import relax
from tvm.script import ir as I, relax as R, tir as T
from tvm import tir
from tvm.ir import IRModule
from tvm.ir.transform import PassContext, module_pass


@I.ir_module
class Before:
    @T.prim_func(private=True)
    def add(
        A: T.Buffer((T.int64(4096), T.int64(4096)), "float16"),
        Out: T.Buffer((T.int64(4096), T.int64(4096)), "float16"),
    ):
        for i, j in T.grid(T.int64(4096), T.int64(4096)):
            with T.block("add"):
                vi, vj = T.axis.remap("SS", [i, j])
                Out[vi, vj] = A[vi, vj] + T.float16(1.0)

    @T.prim_func(private=True)
    def take(
        A: T.Buffer((T.int64(4096), T.int64(4096)), "float16"),
        B: T.Buffer((T.int64(1),), "int32"),
        T_take: T.Buffer((T.int64(1), T.int64(4096)), "float16"),
    ):
        for ax0, ax1 in T.grid(T.int64(1), T.int64(4096)):
            with T.block("T_take"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T_take[v_ax0, v_ax1] = A[B[v_ax0], v_ax1]

    @T.prim_func(private=True)
    def add1(
        A: T.Buffer((T.int64(1), T.int64(4096)), "float16"),
        Out: T.Buffer((T.int64(1), T.int64(4096)), "float16"),
    ):
        for i, j in T.grid(T.int64(1), T.int64(4096)):
            with T.block("add"):
                vi, vj = T.axis.remap("SS", [i, j])
                Out[vi, vj] = A[vi, vj] + T.float16(2.0)

    @R.function
    def main(
        input_ids: R.Tensor((1,), dtype="int32"),
        input_embeds: R.Tensor((4096, 4096), dtype="float16"),
    ) -> R.Tensor((1, 4096), dtype="float16"):
        cls = Before
        with R.dataflow():
            gv: R.Tensor((1, 4096), dtype="float16") = cls.fused_func(input_ids, input_embeds)
            R.output(gv)
        return gv

    @R.function(private=True)
    def fused_func(
        input_ids: R.Tensor((1,), dtype="int32"),
        input_embeds: R.Tensor((4096, 4096), dtype="float16"),
    ) -> R.Tensor((1, 4096), dtype="float16"):
        R.func_attr({"Primitive": 1})
        cls = Before
        with R.dataflow():
            lv = R.call_tir(
                cls.add, (input_embeds,), out_sinfo=R.Tensor((4096, 4096), dtype="float16")
            )
            lv1 = R.call_tir(
                cls.take, (lv, input_ids), out_sinfo=R.Tensor((1, 4096), dtype="float16")
            )
            gv = R.call_tir(cls.add1, (lv1,), out_sinfo=R.Tensor((1, 4096), dtype="float16"))
            R.output(gv)
        return gv


relax_mod = Before
relax_mod = relax.transform.FuseTIR()(relax_mod)
print(relax_mod)


@module_pass(opt_level=0, name="ApplyFastTuning")
class Traverse_mod:  # pylint: disable=too-few-public-methods
    """A IRModule pass that applies a list of ScheduleRules to all PrimFuncs in the module."""

    def __init__(
        self,
    ):
        pass

    def transform_module(  # pylint: disable=missing-function-docstring
        self,
        mod: IRModule,
        _: PassContext,
    ) -> IRModule:
        for g_var, func in mod.functions_items():
            if isinstance(func, tir.PrimFunc):
                sch = tvm.tir.Schedule(func)
                print("implement sch.get_sref(sch.get_block('add')).stmt")
                print(sch.get_sref(sch.get_block("add")).stmt)
                print("implement sch.get_sref(sch.get_block('add_1')).stmt")
                print(sch.get_sref(sch.get_block("add_1")).stmt)
        return mod


relax_mod = Traverse_mod()(relax_mod)

From the result we can observe that, even though the printed Relax module is correct (block add and block add_1 write to different buffers), the buffers in the blocks' srefs still carry the same name.

    @T.prim_func(private=True)
    def fused_func(input_ids: T.Buffer((T.int64(1),), "int32"), input_embeds: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), Out_handle_intermediate: T.Buffer((T.int64(1), T.int64(4096)), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        Out_handle_intermediate_1 = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16")
        T_take_handle_intermediate = T.alloc_buffer((T.int64(1), T.int64(4096)), "float16")
        for i, j in T.grid(T.int64(4096), T.int64(4096)):
            with T.block("add"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(input_embeds[vi, vj])
                T.writes(Out_handle_intermediate_1[vi, vj])
                Out_handle_intermediate_1[vi, vj] = input_embeds[vi, vj] + T.float16(1)
        for ax0, ax1 in T.grid(T.int64(1), T.int64(4096)):
            with T.block("T_take"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(Out_handle_intermediate_1[input_ids[v_ax0], v_ax1], input_ids[v_ax0])
                T.writes(T_take_handle_intermediate[v_ax0, v_ax1])
                T_take_handle_intermediate[v_ax0, v_ax1] = Out_handle_intermediate_1[input_ids[v_ax0], v_ax1]
        for i, j in T.grid(T.int64(1), T.int64(4096)):
            with T.block("add_1"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(T_take_handle_intermediate[vi, vj])
                T.writes(Out_handle_intermediate[vi, vj])
                Out_handle_intermediate[vi, vj] = T_take_handle_intermediate[vi, vj] + T.float16(2)


implement sch.get_sref(sch.get_block('add')).stmt
with T.block("add", no_realize=True):
    vi = T.axis.spatial(T.int64(4096))
    vj = T.axis.spatial(T.int64(4096))
    input_embeds = T.Buffer((T.int64(4096), T.int64(4096)), "float16")
    T.reads(input_embeds[vi, vj])
    Out_handle_intermediate = T.Buffer((T.int64(4096), T.int64(4096)), "float16")
    T.writes(Out_handle_intermediate[vi, vj])
    Out_handle_intermediate[vi, vj] = input_embeds[vi, vj] + T.float16(1)

implement sch.get_sref(sch.get_block('add_1')).stmt
with T.block("add_1", no_realize=True):
    vi = T.axis.spatial(T.int64(1))
    vj = T.axis.spatial(T.int64(4096))
    T_take_handle_intermediate = T.Buffer((T.int64(1), T.int64(4096)), "float16")
    T.reads(T_take_handle_intermediate[vi, vj])
    Out_handle_intermediate = T.Buffer((T.int64(1), T.int64(4096)), "float16")
    T.writes(Out_handle_intermediate[vi, vj])
    Out_handle_intermediate[vi, vj] = T_take_handle_intermediate[vi, vj] + T.float16(2)
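One way to make the duplication visible without relying on the TVMScript printer (which renames colliding buffers on output) is to compare the write buffers of the two blocks directly. The sketch below is not part of the original report; it assumes the fused PrimFunc is the only tir.PrimFunc left in the module after FuseTIR and reuses relax_mod from the script above.

# Hedged sketch: inspect the write buffers of both blocks directly.
fused_gv = next(
    gv for gv, f in relax_mod.functions_items() if isinstance(f, tir.PrimFunc)
)
sch = tvm.tir.Schedule(relax_mod[fused_gv])
buf_add = sch.get_sref(sch.get_block("add")).stmt.writes[0].buffer
buf_add_1 = sch.get_sref(sch.get_block("add_1")).stmt.writes[0].buffer
# Expected if the bug is present: two distinct Buffer objects that nevertheless
# share the same name, even though the printed module shows
# "Out_handle_intermediate_1" and "Out_handle_intermediate".
print(buf_add.name, buf_add_1.name, buf_add.same_as(buf_add_1))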
