Description
I came across an unexpected crash when executing the following script. The crash is only triggered by a specific sequence of transforms, i.e., [ToMixedPrecision, LegalizeOps, AnnotateTIROpPattern, FuseOps, FuseTIR]; removing any one of these passes makes the bug no longer reproducible (as sketched below).
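For reference, the pass-sensitivity can be checked by dropping one pass at a time and rerunning the rest of the pipeline. This is only an illustrative sketch, not part of the original reproducer; it assumes the Module defined in the full test script below:

from tvm import relax

# Illustrative sketch only: re-run the pipeline with one pass removed at a time.
# Per the report, the crash does not reproduce when any single pass is skipped.
passes = [
    relax.transform.ToMixedPrecision(),
    relax.transform.LegalizeOps(),
    relax.transform.AnnotateTIROpPattern(),
    relax.transform.FuseOps(),
    relax.transform.FuseTIR(),
]
for skip in range(len(passes)):
    trial = Module  # the IRModule from the full test script below
    for i, p in enumerate(passes):
        if i != skip:
            trial = p(trial)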
Actual behavior
Traceback (most recent call last):
File "test.py", line 92, in <module>
mod = relax.transform.FuseTIR()(mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/software/tvm-lunder/python/tvm/ir/transform.py", line 238, in __call__
return _ffi_transform_api.RunPass(self, mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/software/tvm-lunder/python/tvm/_ffi/_ctypes/packed_func.py", line 240, in __call__
raise_last_ffi_error()
File "/software/tvm-lunder/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
raise py_err
tvm.error.InternalError: Traceback (most recent call last):
22: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
21: tvm::transform::Pass::operator()(tvm::IRModule) const
20: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
19: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
18: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
17: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
16: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_8IRModuleES5_NS_9transform11PassContextEEE17AssignTypedLambdaIZNS_5relax9transform7FuseTIREvEUlS5_S7_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SF_SJ_
15: tvm::relax::FuseTIR(tvm::IRModule)
14: tvm::relax::TIRFuseMutator::Transform(tvm::IRModule)
13: tvm::relax::FusedTIRConstructor::GetFusedTIR(tvm::IRModule const&, tvm::GlobalVar const&)
12: tvm::relax::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
11: tvm::relax::FusedTIRConstructor::VisitExpr_(tvm::relax::FunctionNode const*)
10: tvm::relax::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
9: tvm::relax::ExprVisitor::VisitExpr_(tvm::relax::SeqExprNode const*)
8: tvm::relax::ExprVisitor::VisitBindingBlock(tvm::relax::BindingBlock const&)
7: tvm::relax::ExprVisitor::VisitBindingBlock_(tvm::relax::DataflowBlockNode const*)
6: tvm::relax::ExprVisitor::VisitBinding(tvm::relax::Binding const&)
5: tvm::relax::ExprVisitor::VisitBinding_(tvm::relax::VarBindingNode const*)
4: _ZN3tvm5relax11ExprVisitor13VisitBinding_EPKNS0_14Va
3: tvm::relax::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
2: tvm::relax::RelaxToTIRVarMapCollector::VisitExpr_(tvm::relax::CallNode const*)
1: tvm::relax::RelaxToTIRVarMapCollector::CollectVarMapping(tvm::relax::CallNode const*, tvm::RelayExpr const&, bool)
0: tvm::relax::RelaxToTIRVarMapCollector::CollectVarMapping(tvm::relax::CallNode const*, tvm::RelayExpr const&, bool)::{lambda(tvm::tir::Buffer, tvm::RelayExpr)#1}::operator()(tvm::tir::Buffer, tvm::RelayExpr) const
File "/software/tvm-lunder/src/relax/transform/fuse_tir.cc", line 442
InternalError: Check failed: (StructuralEqual()((*it).second, new_buf)) is false: Inconsistent buffers compute and lv mapped to the same relax var: lv11
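The failing check in fuse_tir.cc compares, with structural equality, the TIR buffers that different fused calls bind to the same relax var (lv11 here) and aborts when they disagree. A minimal standalone illustration of that comparison, using hypothetical buffers rather than the ones actually produced by the passes:

import tvm
from tvm import tir

# Hypothetical buffers that differ only in dtype, e.g. as a mixed-precision pass might produce.
buf_compute = tir.decl_buffer((16, 32, 32, 16), "float16", name="compute")
buf_lv = tir.decl_buffer((16, 32, 32, 16), "float32", name="lv")

# FuseTIR's RelaxToTIRVarMapCollector fails its check when two buffers mapped to
# the same relax var are not structurally equal, as in this comparison.
print(tvm.ir.structural_equal(buf_compute, buf_lv))  # False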
Steps to reproduce
Full test script
import tvm
from tvm import relax
import numpy as np
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R
@I.ir_module
class Module:
    @T.prim_func(private=True)
    def conv2d2(data: T.Buffer((T.int64(16), T.int64(32), T.int64(32), T.int64(16)), "float16"), weight1: T.Buffer((T.int64(16), T.int64(3), T.int64(3), T.int64(16)), "float16"), conv2d_nhwc: T.Buffer((T.int64(16), T.int64(32), T.int64(32), T.int64(16)), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        pad_temp = T.alloc_buffer((T.int64(16), T.int64(34), T.int64(34), T.int64(16)), "float16")
        for i0, i1, i2, i3 in T.grid(T.int64(16), T.int64(34), T.int64(34), T.int64(16)):
            with T.block("pad_temp"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(data[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1), v_i3])
                T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
                pad_temp[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(T.int64(1) <= v_i1 and v_i1 < T.int64(33) and T.int64(1) <= v_i2 and v_i2 < T.int64(33), data[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1), v_i3], T.float16(0))
        for nn, yy, xx, ff, ry, rx, rc in T.grid(T.int64(16), T.int64(32), T.int64(32), T.int64(16), T.int64(3), T.int64(3), T.int64(16)):
            with T.block("conv2d_nhwc"):
                v_nn, v_yy, v_xx, v_ff, v_ry, v_rx, v_rc = T.axis.remap("SSSSRRR", [nn, yy, xx, ff, ry, rx, rc])
                T.reads(pad_temp[v_nn, v_yy + v_ry, v_xx + v_rx, v_rc], weight1[v_ff, v_ry, v_rx, v_rc])
                T.writes(conv2d_nhwc[v_nn, v_yy, v_xx, v_ff])
                with T.init():
                    conv2d_nhwc[v_nn, v_yy, v_xx, v_ff] = T.float16(0)
                conv2d_nhwc[v_nn, v_yy, v_xx, v_ff] = conv2d_nhwc[v_nn, v_yy, v_xx, v_ff] + pad_temp[v_nn, v_yy + v_ry, v_xx + v_rx, v_rc] * weight1[v_ff, v_ry, v_rx, v_rc]

    @T.prim_func(private=True)
    def layer_norm(conv2: T.Buffer((T.int64(16), T.int64(32), T.int64(32), T.int64(16)), "float16"), gamma: T.Buffer((T.int64(16),), "float16"), beta: T.Buffer((T.int64(16),), "float16"), T_layer_norm: T.Buffer((T.int64(16), T.int64(32), T.int64(32), T.int64(16)), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        conv2_red_temp_v0 = T.alloc_buffer((T.int64(16), T.int64(32), T.int64(32)))
        conv2_red_temp_v1 = T.alloc_buffer((T.int64(16), T.int64(32), T.int64(32)))
        for ax0, ax1, ax2, k3 in T.grid(T.int64(16), T.int64(32), T.int64(32), T.int64(16)):
            with T.block("conv2_red_temp"):
                v_ax0, v_ax1, v_ax2, v_k3 = T.axis.remap("SSSR", [ax0, ax1, ax2, k3])
                T.reads(conv2[v_ax0, v_ax1, v_ax2, v_k3])
                T.writes(conv2_red_temp_v0[v_ax0, v_ax1, v_ax2], conv2_red_temp_v1[v_ax0, v_ax1, v_ax2])
                with T.init():
                    conv2_red_temp_v0[v_ax0, v_ax1, v_ax2] = T.float32(0)
                    conv2_red_temp_v1[v_ax0, v_ax1, v_ax2] = T.float32(0)
                v_conv2_red_temp_v0: T.float32 = conv2_red_temp_v0[v_ax0, v_ax1, v_ax2] + T.Cast("float32", conv2[v_ax0, v_ax1, v_ax2, v_k3])
                v_conv2_red_temp_v1: T.float32 = conv2_red_temp_v1[v_ax0, v_ax1, v_ax2] + T.Cast("float32", conv2[v_ax0, v_ax1, v_ax2, v_k3]) * T.Cast("float32", conv2[v_ax0, v_ax1, v_ax2, v_k3])
                conv2_red_temp_v0[v_ax0, v_ax1, v_ax2] = v_conv2_red_temp_v0
                conv2_red_temp_v1[v_ax0, v_ax1, v_ax2] = v_conv2_red_temp_v1
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(16), T.int64(32), T.int64(32), T.int64(16)):
            with T.block("T_layer_norm"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(conv2[v_ax0, v_ax1, v_ax2, v_ax3], conv2_red_temp_v0[v_ax0, v_ax1, v_ax2], conv2_red_temp_v1[v_ax0, v_ax1, v_ax2], gamma[v_ax3], beta[v_ax3])
                T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3])
                T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3] = T.Cast("float16", (T.Cast("float32", conv2[v_ax0, v_ax1, v_ax2, v_ax3]) - conv2_red_temp_v0[v_ax0, v_ax1, v_ax2] * T.float32(0.0625)) * T.rsqrt(conv2_red_temp_v1[v_ax0, v_ax1, v_ax2] * T.float32(0.0625) - conv2_red_temp_v0[v_ax0, v_ax1, v_ax2] * T.float32(0.0625) * (conv2_red_temp_v0[v_ax0, v_ax1, v_ax2] * T.float32(0.0625)) + T.float32(1.0000000000000001e-05))) * gamma[v_ax3] + beta[v_ax3]

    @T.prim_func(private=True)
    def relu(lv: T.Buffer((T.int64(16), T.int64(32), T.int64(32), T.int64(16)), "float16"), compute: T.Buffer((T.int64(16), T.int64(32), T.int64(32), T.int64(16)), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1, i2, i3 in T.grid(T.int64(16), T.int64(32), T.int64(32), T.int64(16)):
            with T.block("compute"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(lv[v_i0, v_i1, v_i2, v_i3])
                T.writes(compute[v_i0, v_i1, v_i2, v_i3])
                compute[v_i0, v_i1, v_i2, v_i3] = T.max(lv[v_i0, v_i1, v_i2, v_i3], T.float16(0))

    @R.function
    def main(data: R.Tensor((16, 32, 32, 16), dtype="float16"), weight1: R.Tensor((16, 3, 3, 16), dtype="float16"), weight2: R.Tensor((16, 3, 3, 16), dtype="float16"), weight3: R.Tensor((16, 3, 3, 16), dtype="float16"), gamma: R.Tensor((16,), dtype="float16"), beta: R.Tensor((16,), dtype="float16")) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            lv = R.call_tir(cls.conv2d2, (data, weight1), out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"))
            conv1 = R.call_tir(cls.relu, (lv,), out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"))
            lv2 = R.call_tir(cls.conv2d2, (conv1, weight2), out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"))
            conv2 = R.call_tir(cls.relu, (lv2,), out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"))
            ln = R.call_tir(cls.layer_norm, (conv2, gamma, beta), out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"))
            lv5 = R.call_tir(cls.conv2d2, (ln, weight3), out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"))
            conv3 = R.call_tir(cls.relu, (lv5,), out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"))
            R.output(conv3)
        return conv3
mod = Module
mod = relax.transform.ToMixedPrecision()(mod)
mod = relax.transform.LegalizeOps()(mod)
mod = relax.transform.AnnotateTIROpPattern()(mod)
mod = relax.transform.FuseOps()(mod)
mod = relax.transform.FuseTIR()(mod)
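For anyone triaging this, printing the module just before the failing pass makes it possible to inspect the fused Relax function that FuseTIR rejects. A hedged debugging sketch, reusing the Module and passes from the script above:

# Debugging sketch: stop right before FuseTIR and inspect the fused function.
mod_dbg = Module
for p in [
    relax.transform.ToMixedPrecision(),
    relax.transform.LegalizeOps(),
    relax.transform.AnnotateTIROpPattern(),
    relax.transform.FuseOps(),
]:
    mod_dbg = p(mod_dbg)
print(mod_dbg.script())  # the error message points at the binding of lv11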