
Broadcasting is broken with LayoutTransform #4508

@pyalex

Description

When using optimization level >= 3, LayoutTransform fails when converting C -> NCHW8c for an input of shape (1,). This breaks the next operation, which is supposed to broadcast this input.
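The odd shape `TensorType([1, 0, 1, 1, 8], float32)` reported in the compiled program can be reproduced with plain integer arithmetic. The sketch below is a simplified model, not TVM's implementation: it assumes the NHWC -> NCHW8c transform splits the channel axis C into `C // 8` outer chunks plus an inner axis of size 8, and `nhwc_to_nchwc_shape` is a hypothetical helper. Under that assumption, a channel extent of 1 floors to 0 outer chunks.

```python
# Hypothetical helper, not a TVM API: models how NHWC -> NCHW{block}c splits
# the channel axis C into C // block outer chunks and an inner axis of `block`.
def nhwc_to_nchwc_shape(shape, block=8):
    n, h, w, c = shape
    # Floor division: any C smaller than `block` collapses to 0 outer chunks.
    return (n, c // block, h, w, block)


# The shape-(1,) constant is first expanded to (1, 1, 1, 1), matching
# expand_dims(..., axis=0, num_newaxis=3) in the compiled program.
print(nhwc_to_nchwc_shape((1, 1, 1, 1)))       # (1, 0, 1, 1, 8)
print(nhwc_to_nchwc_shape((1, 498, 498, 64)))  # (1, 8, 498, 498, 8)
```

Both results match the two shapes in the `multiply` error, which suggests the C=1 constant is split as if it had a full block of channels.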

Code to reproduce (it builds successfully with opt_level < 3):

from tvm import relay

target = 'llvm'
target_host = 'llvm'

input = relay.var('input', shape=(1, 500, 500, 64), dtype='float32')
kernel = relay.var('kernel', shape=(3, 3, 64, 64), dtype='float32')
bias = relay.var('bias', shape=(64,), dtype='float32')

# Shape-(1,) constant that should broadcast across all channels
multiplier = relay.const([0.1])

x = relay.nn.conv2d(input, kernel, data_layout='NHWC', kernel_layout="HWIO", kernel_size=(3, 3))
x = relay.add(bias, x)
x = relay.nn.relu(x)

x = relay.multiply(multiplier, x)

fun = relay.Function([input, kernel, bias], x)

# Fails at opt_level=3; builds with opt_level < 3
with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(relay.Module({"main": fun}),
                                     target=target,
                                     target_host=target_host)
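For reference, the graph is well-typed in its original NHWC layout. The pure-Python sketch below (hypothetical helpers, not TVM APIs) assumes conv2d's default stride 1 and zero padding, and uses NumPy-style broadcasting rules. It recovers the function's declared return type `Tensor[(1, 498, 498, 64), float32]`, showing that the shape-(1,) multiplier broadcasts cleanly before any layout transform is applied.

```python
# Hypothetical helpers (not TVM APIs) modelling the repro graph's shapes in NHWC.

def conv2d_valid_out(size, kernel, stride=1):
    # No padding ("valid" convolution): out = (size - kernel) // stride + 1
    return (size - kernel) // stride + 1


def broadcast_shape(a, b):
    """NumPy-style broadcasting: align from the right; dims must match or be 1."""
    ndim = max(len(a), len(b))
    a = (1,) * (ndim - len(a)) + tuple(a)
    b = (1,) * (ndim - len(b)) + tuple(b)
    out = []
    for x, y in zip(a, b):
        if x == y or y == 1:
            out.append(x)
        elif x == 1:
            out.append(y)
        else:
            raise ValueError(f"incompatible dims {x} and {y}")
    return tuple(out)


h = w = conv2d_valid_out(500, 3)              # 500 - 3 + 1 = 498
conv_out = (1, h, w, 64)                      # NHWC output of conv2d
after_bias = broadcast_shape((64,), conv_out)  # add(bias, x)
after_mul = broadcast_shape((1,), after_bias)  # multiply(multiplier, x)
print(after_mul)  # (1, 498, 498, 64)
```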

Compiled program showing the issue

fn (%input: Tensor[(1, 500, 500, 64), float32], %kernel: Tensor[(3, 3, 64, 64), float32], %bias: Tensor[(64), float32]) -> Tensor[(1, 498, 498, 64), float32] {
  %0 = expand_dims(meta[relay.Constant][0], axis=0, num_newaxis=3);
  %1 = layout_transform(%0, src_layout="NHWC", dst_layout="NCHW8c");
  %2 = expand_dims(%bias, axis=0, num_newaxis=3);
  %3 = layout_transform(%2, src_layout="NHWC", dst_layout="NCHW8c");
  %4 = layout_transform(%input, src_layout="NHWC", dst_layout="NCHW8c");
  %5 = layout_transform(%kernel, src_layout="HWIO", dst_layout="OIHW8i8o");
  %6 = nn.contrib_conv2d_NCHWc(%4, %5, channels=64, kernel_size=[3, 3], data_layout="NCHW8c", kernel_layout="OIHW8i8o", out_layout="NCHW8c");
  %7 = add(%3, %6);
  %8 = nn.relu(%7);
  %9 = multiply(%1, %8) Incompatible broadcast type TensorType([1, 0, 1, 1, 8], float32) and TensorType([1, 8, 498, 498, 8], float32); ;
  layout_transform(%9, src_layout="NCHW8c", dst_layout="NHWC") an internal invariant was violated while typechecking your program [18:29:24] /incubator-tvm/src/relay/op/tensor/transform.cc:2382: Check failed: data != nullptr: 
; 
}
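The `multiply` error follows from standard broadcasting rules. Compared right to left, each pair of dimensions must be equal or contain a 1, and axis 1 pairs 0 with 8. The later `layout_transform` failure (`Check failed: data != nullptr` in `LayoutTransformRel`) is most likely a knock-on effect of `%9` having no valid tensor type. The helper below is hypothetical, not a TVM API. The final line is an assumption for comparison only: it checks whether a shape keeping the size-1 channel as 1 in both channel axes would broadcast. It does not describe how TVM resolved the issue.

```python
# Hypothetical helper (not a TVM API): NumPy-style broadcast compatibility,
# comparing dimensions right-to-left; each pair must be equal or contain a 1.
def broadcastable(a, b):
    for x, y in zip(reversed(a), reversed(b)):
        if x != y and x != 1 and y != 1:
            return False
    return True


lhs = (1, 0, 1, 1, 8)      # transformed multiplier, from the error message
rhs = (1, 8, 498, 498, 8)  # transformed relu output, from the error message
print(broadcastable(lhs, rhs))  # False: axis 1 pairs 0 with 8

# Assumption, for comparison only: a size-1 channel kept as 1 in both the
# outer and inner channel axes would broadcast against the relu output.
print(broadcastable((1, 1, 1, 1, 1), rhs))  # True
```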

Full Stack Trace

an internal invariant was violated while typechecking your program [18:29:24] /incubator-tvm/src/relay/op/tensor/transform.cc:2382: Check failed: data != nullptr: 
Stack trace:
  [bt] (0) 1   libtvm.dylib                        0x00000001109b3aa9 dmlc::LogMessageFatal::~LogMessageFatal() + 57
  [bt] (1) 2   libtvm.dylib                        0x0000000110f0b198 tvm::relay::LayoutTransformRel(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&) + 280
  [bt] (2) 3   libtvm.dylib                        0x0000000110d6676f void tvm::runtime::detail::unpack_call_dispatcher<bool, 0, 4, bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>::run<tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue>(bool (* const&)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&), tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&) + 95
  [bt] (3) 4   libtvm.dylib                        0x0000000110d666c9 std::__1::__function::__func<void tvm::runtime::TypedPackedFunc<bool (tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>(bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&))::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*), std::__1::allocator<void tvm::runtime::TypedPackedFunc<bool (tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>(bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&))::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 137
  [bt] (4) 5   libtvm.dylib                        0x00000001110b55b5 tvm::TypedEnvFunc<bool (tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>::operator()(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&) const + 325
  [bt] (5) 6   libtvm.dylib                        0x00000001110b4f2f tvm::relay::TypeSolver::Solve() + 1071
  [bt] (6) 7   libtvm.dylib                        0x000000011109962c tvm::relay::TypeInferencer::Infer(tvm::relay::Expr) + 108
  [bt] (7) 8   libtvm.dylib                        0x000000011109a482 tvm::relay::InferType(tvm::relay::Function const&, tvm::relay::Module const&, tvm::relay::GlobalVar const&) + 546
  [bt] (8) 9   libtvm.dylib                        0x0000000111184c18 tvm::relay::ModuleNode::Add(tvm::relay::GlobalVar const&, tvm::relay::Function const&, bool) + 1576

; ' should not has tab or newline.

tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) 9   libtvm.dylib                        0x0000000110fdfa06 tvm::relay::transform::PassNode::operator()(tvm::relay::Module const&) const + 54
  [bt] (7) 8   libtvm.dylib                        0x0000000111052f86 tvm::relay::transform::SequentialNode::operator()(tvm::relay::Module const&, tvm::relay::transform::PassContext const&) const + 838
  [bt] (6) 7   libtvm.dylib                        0x000000011105338c tvm::relay::transform::Pass::operator()(tvm::relay::Module const&, tvm::relay::transform::PassContext const&) const + 156
  [bt] (5) 6   libtvm.dylib                        0x0000000111051b97 tvm::relay::transform::FunctionPassNode::operator()(tvm::relay::Module const&, tvm::relay::transform::PassContext const&) const + 1607
  [bt] (4) 5   libtvm.dylib                        0x0000000111184c18 tvm::relay::ModuleNode::Add(tvm::relay::GlobalVar const&, tvm::relay::Function const&, bool) + 1576
  [bt] (3) 4   libtvm.dylib                        0x000000011109a482 tvm::relay::InferType(tvm::relay::Function const&, tvm::relay::Module const&, tvm::relay::GlobalVar const&) + 546
  [bt] (2) 3   libtvm.dylib                        0x0000000111099648 tvm::relay::TypeInferencer::Infer(tvm::relay::Expr) + 136
  [bt] (1) 2   libtvm.dylib                        0x0000000111161769 tvm::relay::ErrorReporter::RenderErrors(tvm::relay::Module const&, bool) + 5433
  [bt] (0) 1   libtvm.dylib                        0x00000001109b3aa9 dmlc::LogMessageFatal::~LogMessageFatal() + 57
  [bt] (8) 9   libtvm.dylib                        0x0000000111184c18 tvm::relay::ModuleNode::Add(tvm::relay::GlobalVar const&, tvm::relay::Function const&, bool) + 1576
  [bt] (7) 8   libtvm.dylib                        0x000000011109a482 tvm::relay::InferType(tvm::relay::Function const&, tvm::relay::Module const&, tvm::relay::GlobalVar const&) + 546
  [bt] (6) 7   libtvm.dylib                        0x000000011109962c tvm::relay::TypeInferencer::Infer(tvm::relay::Expr) + 108
  [bt] (5) 6   libtvm.dylib                        0x00000001110b4f2f tvm::relay::TypeSolver::Solve() + 1071
  [bt] (4) 5   libtvm.dylib                        0x00000001110b55b5 tvm::TypedEnvFunc<bool (tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>::operator()(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&) const + 325
  [bt] (3) 4   libtvm.dylib                        0x0000000110d666c9 std::__1::__function::__func<void tvm::runtime::TypedPackedFunc<bool (tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>(bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&))::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*), std::__1::allocator<void tvm::runtime::TypedPackedFunc<bool (tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>(bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&))::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 137
  [bt] (2) 3   libtvm.dylib                        0x0000000110d6676f void tvm::runtime::detail::unpack_call_dispatcher<bool, 0, 4, bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>::run<tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue>(bool (* const&)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&), tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&) + 95
  [bt] (1) 2   libtvm.dylib                        0x0000000110f0b198 tvm::relay::LayoutTransformRel(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&) + 280
  [bt] (0) 1   libtvm.dylib                        0x00000001109b3aa9 dmlc::LogMessageFatal::~LogMessageFatal() + 57
  File "incubator-tvm/src/relay/ir/error.cc", line 132
