[Bug] [Relax] Inconsistent buffers compute and lv mapped to the same relax var #17254

@Cookiee235

Description

I came across an unexpected crash when executing the script below. The crash is only triggered by the specific sequence of transforms [ToMixedPrecision, LegalizeOps, AnnotateTIROpPattern, FuseOps, FuseTIR]; removing any one of these passes makes the crash no longer reproducible.
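
A quick way to confirm that every pass in the list is required is to drop one pass at a time and re-run the pipeline. The sketch below is illustrative only (it assumes the `Module` defined under "Steps to reproduce" is in scope); only the full five-pass sequence should hit the InternalError.

import tvm
from tvm import relax

PASSES = [
    relax.transform.ToMixedPrecision,
    relax.transform.LegalizeOps,
    relax.transform.AnnotateTIROpPattern,
    relax.transform.FuseOps,
    relax.transform.FuseTIR,
]

def run_passes(pass_ctors, mod):
    # Apply each pass constructor in order, exactly as the repro script does.
    for ctor in pass_ctors:
        mod = ctor()(mod)
    return mod

for skip in range(len(PASSES)):
    # Drop one pass and check whether the crash still occurs.
    subset = PASSES[:skip] + PASSES[skip + 1:]
    try:
        run_passes(subset, Module)
        print(f"without {PASSES[skip].__name__}: no crash")
    except tvm.error.InternalError:
        print(f"without {PASSES[skip].__name__}: crash")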

Actual behavior

Traceback (most recent call last):
  File "test.py", line 92, in <module>
    mod = relax.transform.FuseTIR()(mod)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm-lunder/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm-lunder/python/tvm/_ffi/_ctypes/packed_func.py", line 240, in __call__
    raise_last_ffi_error()
  File "/software/tvm-lunder/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm.error.InternalError: Traceback (most recent call last):
  22: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  21: tvm::transform::Pass::operator()(tvm::IRModule) const
  20: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  19: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  18: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  17: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  16: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_8IRModuleES5_NS_9transform11PassContextEEE17AssignTypedLambdaIZNS_5relax9transform7FuseTIREvEUlS5_S7_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SF_SJ_
  15: tvm::relax::FuseTIR(tvm::IRModule)
  14: tvm::relax::TIRFuseMutator::Transform(tvm::IRModule)
  13: tvm::relax::FusedTIRConstructor::GetFusedTIR(tvm::IRModule const&, tvm::GlobalVar const&)
  12: tvm::relax::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
  11: tvm::relax::FusedTIRConstructor::VisitExpr_(tvm::relax::FunctionNode const*)
  10: tvm::relax::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
  9: tvm::relax::ExprVisitor::VisitExpr_(tvm::relax::SeqExprNode const*)
  8: tvm::relax::ExprVisitor::VisitBindingBlock(tvm::relax::BindingBlock const&)
  7: tvm::relax::ExprVisitor::VisitBindingBlock_(tvm::relax::DataflowBlockNode const*)
  6: tvm::relax::ExprVisitor::VisitBinding(tvm::relax::Binding const&)
  5: tvm::relax::ExprVisitor::VisitBinding_(tvm::relax::VarBindingNode const*)
  4: _ZN3tvm5relax11ExprVisitor13VisitBinding_EPKNS0_14Va
  3: tvm::relax::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
  2: tvm::relax::RelaxToTIRVarMapCollector::VisitExpr_(tvm::relax::CallNode const*)
  1: tvm::relax::RelaxToTIRVarMapCollector::CollectVarMapping(tvm::relax::CallNode const*, tvm::RelayExpr const&, bool)
  0: tvm::relax::RelaxToTIRVarMapCollector::CollectVarMapping(tvm::relax::CallNode const*, tvm::RelayExpr const&, bool)::{lambda(tvm::tir::Buffer, tvm::RelayExpr)#1}::operator()(tvm::tir::Buffer, tvm::RelayExpr) const
  File "/software/tvm-lunder/src/relax/transform/fuse_tir.cc", line 442
InternalError: Check failed: (StructuralEqual()((*it).second, new_buf)) is false: Inconsistent buffers compute and lv mapped to the same relax var: lv11
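
From the trace, the check fires in RelaxToTIRVarMapCollector::CollectVarMapping (fuse_tir.cc:442), which appears to require that each relax var is associated with a single TIR buffer up to structural equality. A rough Python model of that invariant (illustrative only, not TVM's actual code; all names are made up):

# Illustrative model of the invariant the failing check enforces.
relax_var_to_buffer = {}

def collect_var_mapping(relax_var, new_buf, structurally_equal=lambda a, b: a == b):
    existing = relax_var_to_buffer.get(relax_var)
    if existing is not None and not structurally_equal(existing, new_buf):
        # Mirrors the error above: buffers named "compute" and "lv" both
        # claimed for the relax var "lv11".
        raise RuntimeError(
            f"Inconsistent buffers {existing} and {new_buf} "
            f"mapped to the same relax var: {relax_var}"
        )
    relax_var_to_buffer[relax_var] = new_buf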

Steps to reproduce

Full test script
import tvm
from tvm import relax
import numpy as np
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R


@I.ir_module
class Module:
    @T.prim_func(private=True)
    def conv2d2(data: T.Buffer((T.int64(16), T.int64(32), T.int64(32), T.int64(16)), "float16"), weight1: T.Buffer((T.int64(16), T.int64(3), T.int64(3), T.int64(16)), "float16"), conv2d_nhwc: T.Buffer((T.int64(16), T.int64(32), T.int64(32), T.int64(16)), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        pad_temp = T.alloc_buffer((T.int64(16), T.int64(34), T.int64(34), T.int64(16)), "float16")
        for i0, i1, i2, i3 in T.grid(T.int64(16), T.int64(34), T.int64(34), T.int64(16)):
            with T.block("pad_temp"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(data[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1), v_i3])
                T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
                pad_temp[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(T.int64(1) <= v_i1 and v_i1 < T.int64(33) and T.int64(1) <= v_i2 and v_i2 < T.int64(33), data[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1), v_i3], T.float16(0))
        for nn, yy, xx, ff, ry, rx, rc in T.grid(T.int64(16), T.int64(32), T.int64(32), T.int64(16), T.int64(3), T.int64(3), T.int64(16)):
            with T.block("conv2d_nhwc"):
                v_nn, v_yy, v_xx, v_ff, v_ry, v_rx, v_rc = T.axis.remap("SSSSRRR", [nn, yy, xx, ff, ry, rx, rc])
                T.reads(pad_temp[v_nn, v_yy + v_ry, v_xx + v_rx, v_rc], weight1[v_ff, v_ry, v_rx, v_rc])
                T.writes(conv2d_nhwc[v_nn, v_yy, v_xx, v_ff])
                with T.init():
                    conv2d_nhwc[v_nn, v_yy, v_xx, v_ff] = T.float16(0)
                conv2d_nhwc[v_nn, v_yy, v_xx, v_ff] = conv2d_nhwc[v_nn, v_yy, v_xx, v_ff] + pad_temp[v_nn, v_yy + v_ry, v_xx + v_rx, v_rc] * weight1[v_ff, v_ry, v_rx, v_rc]

    @T.prim_func(private=True)
    def layer_norm(conv2: T.Buffer((T.int64(16), T.int64(32), T.int64(32), T.int64(16)), "float16"), gamma: T.Buffer((T.int64(16),), "float16"), beta: T.Buffer((T.int64(16),), "float16"), T_layer_norm: T.Buffer((T.int64(16), T.int64(32), T.int64(32), T.int64(16)), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        conv2_red_temp_v0 = T.alloc_buffer((T.int64(16), T.int64(32), T.int64(32)))
        conv2_red_temp_v1 = T.alloc_buffer((T.int64(16), T.int64(32), T.int64(32)))
        for ax0, ax1, ax2, k3 in T.grid(T.int64(16), T.int64(32), T.int64(32), T.int64(16)):
            with T.block("conv2_red_temp"):
                v_ax0, v_ax1, v_ax2, v_k3 = T.axis.remap("SSSR", [ax0, ax1, ax2, k3])
                T.reads(conv2[v_ax0, v_ax1, v_ax2, v_k3])
                T.writes(conv2_red_temp_v0[v_ax0, v_ax1, v_ax2], conv2_red_temp_v1[v_ax0, v_ax1, v_ax2])
                with T.init():
                    conv2_red_temp_v0[v_ax0, v_ax1, v_ax2] = T.float32(0)
                    conv2_red_temp_v1[v_ax0, v_ax1, v_ax2] = T.float32(0)
                v_conv2_red_temp_v0: T.float32 = conv2_red_temp_v0[v_ax0, v_ax1, v_ax2] + T.Cast("float32", conv2[v_ax0, v_ax1, v_ax2, v_k3])
                v_conv2_red_temp_v1: T.float32 = conv2_red_temp_v1[v_ax0, v_ax1, v_ax2] + T.Cast("float32", conv2[v_ax0, v_ax1, v_ax2, v_k3]) * T.Cast("float32", conv2[v_ax0, v_ax1, v_ax2, v_k3])
                conv2_red_temp_v0[v_ax0, v_ax1, v_ax2] = v_conv2_red_temp_v0
                conv2_red_temp_v1[v_ax0, v_ax1, v_ax2] = v_conv2_red_temp_v1
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(16), T.int64(32), T.int64(32), T.int64(16)):
            with T.block("T_layer_norm"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(conv2[v_ax0, v_ax1, v_ax2, v_ax3], conv2_red_temp_v0[v_ax0, v_ax1, v_ax2], conv2_red_temp_v1[v_ax0, v_ax1, v_ax2], gamma[v_ax3], beta[v_ax3])
                T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3])
                T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3] = T.Cast("float16", (T.Cast("float32", conv2[v_ax0, v_ax1, v_ax2, v_ax3]) - conv2_red_temp_v0[v_ax0, v_ax1, v_ax2] * T.float32(0.0625)) * T.rsqrt(conv2_red_temp_v1[v_ax0, v_ax1, v_ax2] * T.float32(0.0625) - conv2_red_temp_v0[v_ax0, v_ax1, v_ax2] * T.float32(0.0625) * (conv2_red_temp_v0[v_ax0, v_ax1, v_ax2] * T.float32(0.0625)) + T.float32(1.0000000000000001e-05))) * gamma[v_ax3] + beta[v_ax3]

    @T.prim_func(private=True)
    def relu(lv: T.Buffer((T.int64(16), T.int64(32), T.int64(32), T.int64(16)), "float16"), compute: T.Buffer((T.int64(16), T.int64(32), T.int64(32), T.int64(16)), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1, i2, i3 in T.grid(T.int64(16), T.int64(32), T.int64(32), T.int64(16)):
            with T.block("compute"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(lv[v_i0, v_i1, v_i2, v_i3])
                T.writes(compute[v_i0, v_i1, v_i2, v_i3])
                compute[v_i0, v_i1, v_i2, v_i3] = T.max(lv[v_i0, v_i1, v_i2, v_i3], T.float16(0))

    @R.function
    def main(data: R.Tensor((16, 32, 32, 16), dtype="float16"), weight1: R.Tensor((16, 3, 3, 16), dtype="float16"), weight2: R.Tensor((16, 3, 3, 16), dtype="float16"), weight3: R.Tensor((16, 3, 3, 16), dtype="float16"), gamma: R.Tensor((16,), dtype="float16"), beta: R.Tensor((16,), dtype="float16")) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            lv = R.call_tir(cls.conv2d2, (data, weight1), out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"))
            conv1 = R.call_tir(cls.relu, (lv,), out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"))
            lv2 = R.call_tir(cls.conv2d2, (conv1, weight2), out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"))
            conv2 = R.call_tir(cls.relu, (lv2,), out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"))
            ln = R.call_tir(cls.layer_norm, (conv2, gamma, beta), out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"))
            lv5 = R.call_tir(cls.conv2d2, (ln, weight3), out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"))
            conv3 = R.call_tir(cls.relu, (lv5,), out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"))
            R.output(conv3)
        return conv3


mod = Module

mod = relax.transform.ToMixedPrecision()(mod)
mod = relax.transform.LegalizeOps()(mod)
mod = relax.transform.AnnotateTIROpPattern()(mod)
mod = relax.transform.FuseOps()(mod)
mod = relax.transform.FuseTIR()(mod)
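
For reference, the same sequence can presumably also be expressed with tvm.ir.transform.Sequential; under the default PassContext this should behave the same as the per-pass calls above:

seq = tvm.ir.transform.Sequential(
    [
        relax.transform.ToMixedPrecision(),
        relax.transform.LegalizeOps(),
        relax.transform.AnnotateTIROpPattern(),
        relax.transform.FuseOps(),
        relax.transform.FuseTIR(),
    ]
)
mod = seq(Module)  # expected to crash in FuseTIR the same way, assuming Sequential runs all five passes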

cc @Lunderberg @junrushao
