-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Closed
Description
When using optimization level >= 3, LayoutTransform fails when converting C -> NCHW8c for an input of shape (1,), which breaks the next operation that is supposed to broadcast this input.
Code to reproduce (it works with opt_level < 3):
target = 'llvm'
target_host = 'llvm'
input = relay.var('input', shape=(1, 500, 500, 64), dtype='float32')
kernel = relay.var('kernel', shape=(3, 3, 64, 64), dtype='float32')
bias = relay.var('bias', shape=(64,), dtype='float32')
multiplier = relay.const([0.1])
x = relay.nn.conv2d(input, kernel, data_layout='NHWC', kernel_layout="HWIO", kernel_size=(3, 3))
x = relay.add(bias, x)
x = relay.nn.relu(x)
x = relay.multiply(multiplier, x)
fun = relay.Function([input, kernel, bias], x)
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(relay.Module({"main": fun}),
target=target,
target_host=target_host)
Compiled program with the issue:
fn (%input: Tensor[(1, 500, 500, 64), float32], %kernel: Tensor[(3, 3, 64, 64), float32], %bias: Tensor[(64), float32]) -> Tensor[(1, 498, 498, 64), float32] {
%0 = expand_dims(meta[relay.Constant][0], axis=0, num_newaxis=3);
%1 = layout_transform(%0, src_layout="NHWC", dst_layout="NCHW8c");
%2 = expand_dims(%bias, axis=0, num_newaxis=3);
%3 = layout_transform(%2, src_layout="NHWC", dst_layout="NCHW8c");
%4 = layout_transform(%input, src_layout="NHWC", dst_layout="NCHW8c");
%5 = layout_transform(%kernel, src_layout="HWIO", dst_layout="OIHW8i8o");
%6 = nn.contrib_conv2d_NCHWc(%4, %5, channels=64, kernel_size=[3, 3], data_layout="NCHW8c", kernel_layout="OIHW8i8o", out_layout="NCHW8c");
%7 = add(%3, %6);
%8 = nn.relu(%7);
%9 = multiply(%1, %8) Incompatible broadcast type TensorType([1, 0, 1, 1, 8], float32) and TensorType([1, 8, 498, 498, 8], float32); ;
layout_transform(%9, src_layout="NCHW8c", dst_layout="NHWC") an internal invariant was violated while typechecking your program [18:29:24] /incubator-tvm/src/relay/op/tensor/transform.cc:2382: Check failed: data != nullptr:
;
}
Full Stack Trace
an internal invariant was violated while typechecking your program [18:29:24] /incubator-tvm/src/relay/op/tensor/transform.cc:2382: Check failed: data != nullptr:
Stack trace:
[bt] (0) 1 libtvm.dylib 0x00000001109b3aa9 dmlc::LogMessageFatal::~LogMessageFatal() + 57
[bt] (1) 2 libtvm.dylib 0x0000000110f0b198 tvm::relay::LayoutTransformRel(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&) + 280
[bt] (2) 3 libtvm.dylib 0x0000000110d6676f void tvm::runtime::detail::unpack_call_dispatcher<bool, 0, 4, bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>::run<tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue>(bool (* const&)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&), tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&) + 95
[bt] (3) 4 libtvm.dylib 0x0000000110d666c9 std::__1::__function::__func<void tvm::runtime::TypedPackedFunc<bool (tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>(bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&))::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*), std::__1::allocator<void tvm::runtime::TypedPackedFunc<bool (tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>(bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&))::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 137
[bt] (4) 5 libtvm.dylib 0x00000001110b55b5 tvm::TypedEnvFunc<bool (tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>::operator()(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&) const + 325
[bt] (5) 6 libtvm.dylib 0x00000001110b4f2f tvm::relay::TypeSolver::Solve() + 1071
[bt] (6) 7 libtvm.dylib 0x000000011109962c tvm::relay::TypeInferencer::Infer(tvm::relay::Expr) + 108
[bt] (7) 8 libtvm.dylib 0x000000011109a482 tvm::relay::InferType(tvm::relay::Function const&, tvm::relay::Module const&, tvm::relay::GlobalVar const&) + 546
[bt] (8) 9 libtvm.dylib 0x0000000111184c18 tvm::relay::ModuleNode::Add(tvm::relay::GlobalVar const&, tvm::relay::Function const&, bool) + 1576
; ' should not has tab or newline.
tvm._ffi.base.TVMError: Traceback (most recent call last):
[bt] (8) 9 libtvm.dylib 0x0000000110fdfa06 tvm::relay::transform::PassNode::operator()(tvm::relay::Module const&) const + 54
[bt] (7) 8 libtvm.dylib 0x0000000111052f86 tvm::relay::transform::SequentialNode::operator()(tvm::relay::Module const&, tvm::relay::transform::PassContext const&) const + 838
[bt] (6) 7 libtvm.dylib 0x000000011105338c tvm::relay::transform::Pass::operator()(tvm::relay::Module const&, tvm::relay::transform::PassContext const&) const + 156
[bt] (5) 6 libtvm.dylib 0x0000000111051b97 tvm::relay::transform::FunctionPassNode::operator()(tvm::relay::Module const&, tvm::relay::transform::PassContext const&) const + 1607
[bt] (4) 5 libtvm.dylib 0x0000000111184c18 tvm::relay::ModuleNode::Add(tvm::relay::GlobalVar const&, tvm::relay::Function const&, bool) + 1576
[bt] (3) 4 libtvm.dylib 0x000000011109a482 tvm::relay::InferType(tvm::relay::Function const&, tvm::relay::Module const&, tvm::relay::GlobalVar const&) + 546
[bt] (2) 3 libtvm.dylib 0x0000000111099648 tvm::relay::TypeInferencer::Infer(tvm::relay::Expr) + 136
[bt] (1) 2 libtvm.dylib 0x0000000111161769 tvm::relay::ErrorReporter::RenderErrors(tvm::relay::Module const&, bool) + 5433
[bt] (0) 1 libtvm.dylib 0x00000001109b3aa9 dmlc::LogMessageFatal::~LogMessageFatal() + 57
[bt] (8) 9 libtvm.dylib 0x0000000111184c18 tvm::relay::ModuleNode::Add(tvm::relay::GlobalVar const&, tvm::relay::Function const&, bool) + 1576
[bt] (7) 8 libtvm.dylib 0x000000011109a482 tvm::relay::InferType(tvm::relay::Function const&, tvm::relay::Module const&, tvm::relay::GlobalVar const&) + 546
[bt] (6) 7 libtvm.dylib 0x000000011109962c tvm::relay::TypeInferencer::Infer(tvm::relay::Expr) + 108
[bt] (5) 6 libtvm.dylib 0x00000001110b4f2f tvm::relay::TypeSolver::Solve() + 1071
[bt] (4) 5 libtvm.dylib 0x00000001110b55b5 tvm::TypedEnvFunc<bool (tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>::operator()(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&) const + 325
[bt] (3) 4 libtvm.dylib 0x0000000110d666c9 std::__1::__function::__func<void tvm::runtime::TypedPackedFunc<bool (tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>(bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&))::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*), std::__1::allocator<void tvm::runtime::TypedPackedFunc<bool (tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>(bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&))::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 137
[bt] (2) 3 libtvm.dylib 0x0000000110d6676f void tvm::runtime::detail::unpack_call_dispatcher<bool, 0, 4, bool (*)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&)>::run<tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue, tvm::runtime::TVMArgValue>(bool (* const&)(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&), tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&, tvm::runtime::TVMArgValue&&) + 95
[bt] (1) 2 libtvm.dylib 0x0000000110f0b198 tvm::relay::LayoutTransformRel(tvm::Array<tvm::relay::Type, void> const&, int, tvm::Attrs const&, tvm::relay::TypeReporter const&) + 280
[bt] (0) 1 libtvm.dylib 0x00000001109b3aa9 dmlc::LogMessageFatal::~LogMessageFatal() + 57
File "incubator-tvm/src/relay/ir/error.cc", line 132
Metadata
Metadata
Assignees
Labels
No labels