[Bug] Int64 BroadCast-ArgMax triggers assertion error at graph runtime #11794

@ganler

Description

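The following minimized code snippet (broadcast_to followed by argmax) triggers an error in the graph runtime: the program compiles, but the resulting runtime module fails an assertion when it is run. At compile time, te/schedule/bound.cc:119 also emits a "not in feed graph consumer" warning. Could this be a codegen bug? One hint: Relay types the argmax output as int32 (see the IR below), yet the generated kernel asserts that its output buffer is int64. Together with the (int64)-1 identity element in the warning, this suggests that the int64 shape constant leaks into the index dtype of the kernel.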

cc te's code owners: @junrushao1994 @vinx13 @masahi

import tvm
from tvm import relay
import tvm.testing  # provides tvm.testing.assert_allclose used below
import numpy as np

# Relay IR of the original, non-minimized program (for reference only):
"""
def @main(%x: Tensor[(11, 41, 1, 1), float32] /* ty=Tensor[(11, 41, 1, 1), float32] */) -> Tensor[(11i64, 1i64, 1i64), int64] {
  %0 = broadcast_to(%x, shape=[11i64, 41i64, 1i64, 1i64]) /* ty=Tensor[(11i64, 41i64, 1i64, 1i64), float32] */;
  argmax(%0, axis=[1]) /* ty=Tensor[(11i64, 1i64, 1i64), int32] */;
}
"""
# Minimized repro: broadcast_to with an int64 shape constant, then argmax.
x_shape = (1, 1)
broadcast_shape = [1, 1]
x = relay.var("data", relay.TensorType(x_shape, "float32"))
broadcast_to = relay.op.broadcast_to(x, relay.const(broadcast_shape, dtype="int64"))
argmax = relay.op.argmax(broadcast_to, axis=[1])

f = relay.Function([x], argmax)

# NumPy reference result.
x = np.zeros(x_shape).astype("float32")
ref_res = np.broadcast_to(x, broadcast_shape).argmax(axis=1)

# Compiles fine, but the assertion fires when the graph executor runs the module.
op_res = relay.create_executor(
    'graph', device=tvm.cpu(), target='llvm').evaluate(f)(x)
tvm.testing.assert_allclose(op_res.numpy(), ref_res)
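For context, the compute printed in the bound.cc warning under "Actual behavior" lowers argmax to a commutative reducer over (index, value) pairs. Below is a pure-Python transcription of that combiner, for illustration only (it is not TVM code; `combine` and `argmax_rows` are hypothetical names). The identity element (-1, -FLT_MAX) and the int64 index result mirror what the warning shows; the index dtype is exactly what the runtime assertion later complains about.

```python
import numpy as np

# Identity element from the warning: index (int64)-1 and -3.40282e+38f (lowest float32).
IDENTITY = (-1, float(np.finfo(np.float32).min))


def combine(lhs, rhs):
    """argmax combiner, transcribed from the comm_reducer in the warning."""
    lhs_idx, lhs_val = lhs
    rhs_idx, rhs_val = rhs
    # result[0]: keep the lhs index if its value is larger, or equal with a smaller index.
    if lhs_val > rhs_val or (lhs_val == rhs_val and lhs_idx < rhs_idx):
        idx = lhs_idx
    else:
        idx = rhs_idx
    # result[1]: keep the lhs value only if it is strictly larger.
    val = lhs_val if lhs_val > rhs_val else rhs_val
    return idx, val


def argmax_rows(data):
    """Reduce over axis 1 (reduce axis k1) for every row (axis ax0)."""
    out = []
    for row in data:
        acc = IDENTITY
        for k1, value in enumerate(row):
            acc = combine(acc, (k1, value))
        out.append(acc[0])
    # Index dtype follows the int64 reduce axis, as in the generated kernel.
    return np.array(out, dtype=np.int64)


data = np.broadcast_to(np.zeros((1, 1), dtype="float32"), [1, 1])
print(argmax_rows(data))  # [0]
```

The tie-break on the smaller index makes the result agree with NumPy's first-occurrence semantics, so `argmax_rows` should match `ref_res` in the repro.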

Expected behavior

The compiled module runs, and its output matches the NumPy reference ref_res.

Actual behavior

[22:47:22] /home/ganler/Documents/tvm-pr-2/src/te/schedule/bound.cc:119: not in feed graph consumer = compute(T_broadcast_to_red_temp, body=[reduce(combiner=comm_reducer(result=[select(((argmax_lhs_1 > argmax_rhs_1) || ((argmax_lhs_1 == argmax_rhs_1) && (argmax_lhs_0 < argmax_rhs_0))), argmax_lhs_0, argmax_rhs_0), select((argmax_lhs_1 > argmax_rhs_1), argmax_lhs_1, argmax_rhs_1)], lhs=[argmax_lhs_0, argmax_lhs_1], rhs=[argmax_rhs_0, argmax_rhs_1], identity_element=[(int64)-1, -3.40282e+38f]), source=[k1, placeholder[ax0, k1]], init=[], axis=[iter_var(k1, range(min=0, ext=(int64)1))], where=(bool)1, value_index=0), reduce(combiner=comm_reducer(result=[select(((argmax_lhs_1 > argmax_rhs_1) || ((argmax_lhs_1 == argmax_rhs_1) && (argmax_lhs_0 < argmax_rhs_0))), argmax_lhs_0, argmax_rhs_0), select((argmax_lhs_1 > argmax_rhs_1), argmax_lhs_1, argmax_rhs_1)], lhs=[argmax_lhs_0, argmax_lhs_1], rhs=[argmax_rhs_0, argmax_rhs_1], identity_element=[(int64)-1, -3.40282e+38f]), source=[k1, placeholder[ax0, k1]], init=[], axis=[iter_var(k1, range(min=0, ext=(int64)1))], where=(bool)1, value_index=1)], axis=[iter_var(ax0, range(min=0, ext=(int64)1))], reduce_axis=[iter_var(k1, range(min=0, ext=(int64)1))], tag=comm_reduce_idx, attrs={})
Traceback (most recent call last):
  File "test.py", line 52, in <module>
    op_res = relay.create_executor(
  File "/home/ganler/Documents/tvm-pr-2/python/tvm/relay/build_module.py", line 589, in _graph_wrapper
    gmodule.run()
  File "/home/ganler/Documents/tvm-pr-2/python/tvm/contrib/graph_executor.py", line 208, in run
    self._run()
  File "/home/ganler/Documents/tvm-pr-2/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  13: TVMFuncCall
        at /home/ganler/Documents/tvm-pr-2/src/runtime/c_runtime_api.cc:477
  12: tvm::runtime::PackedFuncObj::CallPacked(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
        at /home/ganler/Documents/tvm-pr-2/include/tvm/runtime/packed_func.h:1217
  11: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::GraphExecutor::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::$_12> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
        at /home/ganler/Documents/tvm-pr-2/include/tvm/runtime/packed_func.h:1213
  10: tvm::runtime::GraphExecutor::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::$_12::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
        at /home/ganler/Documents/tvm-pr-2/src/runtime/graph_executor/graph_executor.cc:599
  9: tvm::runtime::GraphExecutor::Run()
        at /home/ganler/Documents/tvm-pr-2/src/runtime/graph_executor/graph_executor.cc:62
  8: std::function<void ()>::operator()() const
        at /usr/bin/../lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10/bits/std_function.h:622
  7: std::_Function_handler<void (), tvm::runtime::GraphExecutor::CreateTVMOp(tvm::runtime::TVMOpParam const&, std::vector<DLTensor, std::allocator<DLTensor> > const&)::$_2>::_M_invoke(std::_Any_data const&)
        at /usr/bin/../lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10/bits/std_function.h:291
  6: _ZSt10__invoke_rIvRZN3tvm7runtime13GraphExecutor11CreateTVMOpERKNS1_10TVMOpParamERKSt6vectorI8DLTensorSaIS7_EEE3$_2JEENSt9enable_ifIXsr6__and_ISt7is_voidIT_ESt14__is_invocableIT0_JDpT1_EEEE5valueESG_E4typeEOSJ_DpOSK_
        at /usr/bin/../lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10/bits/invoke.h:153
  5: void std::__invoke_impl<void, tvm::runtime::GraphExecutor::CreateTVMOp(tvm::runtime::TVMOpParam const&, std::vector<DLTensor, std::allocator<DLTensor> > const&)::$_2&>(std::__invoke_other, tvm::runtime::GraphExecutor::CreateTVMOp(tvm::runtime::TVMOpParam const&, std::vector<DLTensor, std::allocator<DLTensor> > const&)::$_2&)
        at /usr/bin/../lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10/bits/invoke.h:60
  4: tvm::runtime::GraphExecutor::CreateTVMOp(tvm::runtime::TVMOpParam const&, std::vector<DLTensor, std::allocator<DLTensor> > const&)::$_2::operator()() const
        at /home/ganler/Documents/tvm-pr-2/src/runtime/graph_executor/graph_executor.cc:537
  3: tvm::runtime::PackedFunc::CallPacked(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
        at /home/ganler/Documents/tvm-pr-2/include/tvm/runtime/packed_func.h:1221
  2: tvm::runtime::PackedFuncObj::CallPacked(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
        at /home/ganler/Documents/tvm-pr-2/include/tvm/runtime/packed_func.h:1217
  1: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::WrapPackedFunc(int (*)(TVMValue*, int*, int, TVMValue*, int*, void*), tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::$_0> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
        at /home/ganler/Documents/tvm-pr-2/include/tvm/runtime/packed_func.h:1213
  0: tvm::runtime::WrapPackedFunc(int (*)(TVMValue*, int*, int, TVMValue*, int*, void*), tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::$_0::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
        at /home/ganler/Documents/tvm-pr-2/src/runtime/library_module.cc:80
  File "/home/ganler/Documents/tvm-pr-2/src/runtime/library_module.cc", line 80
TVMError: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------

  Check failed: ret == 0 (-1 vs. 0) : Assert fail: (((tir.tvm_struct_get(arg.T_broadcast_to_red, 0, 5) == (uint8)0) && (tir.tvm_struct_get(arg.T_broadcast_to_red, 0, 6) == (uint8)64)) && (tir.tvm_struct_get(arg.T_broadcast_to_red, 0, 7) == (uint16)1)), arg.T_broadcast_to_red.dtype is expected to be int64
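The failing assertion reads three dtype fields of the output DLTensor via tvm_struct_get (field indices 5, 6 and 7, which, as I understand TVM's TVMStructFieldKind enum, are the type code, bits and lanes). A minimal sketch of how that triple decodes, assuming DLPack's type codes (0 = int, 1 = uint, 2 = float); `decode_dtype` and `check_output_dtype` are hypothetical helpers, not TVM APIs. If the executor allocated the output as int32, as the Relay argmax type annotation suggests, the check would fail exactly as reported.

```python
# DLPack type codes (DLDataTypeCode): kDLInt = 0, kDLUInt = 1, kDLFloat = 2.
DL_TYPE_CODES = {0: "int", 1: "uint", 2: "float"}


def decode_dtype(code: int, bits: int, lanes: int) -> str:
    """Hypothetical helper: render a DLDataType triple as a dtype string."""
    name = f"{DL_TYPE_CODES[code]}{bits}"
    # Vector types carry a lane-count suffix, e.g. "float32x4".
    return name if lanes == 1 else f"{name}x{lanes}"


def check_output_dtype(actual, expected=(0, 64, 1)):
    """Mimic the generated assertion: code, bits and lanes must all match."""
    if tuple(actual) != tuple(expected):
        raise AssertionError(
            f"dtype is expected to be {decode_dtype(*expected)}, "
            f"got {decode_dtype(*actual)}"
        )


# What the kernel checks for: (code=0, bits=64, lanes=1).
print(decode_dtype(0, 64, 1))

# An int32 output buffer (code=0, bits=32, lanes=1) fails the same check.
try:
    check_output_dtype((0, 32, 1))
except AssertionError as err:
    print(err)
```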

Environment

Ubuntu 20.04, TVM commit 9bba758
