Skip to content

[ONNX] ssd-mobilenetv1 fail to build #8284

@kevinLu1114

Description

@kevinLu1114

The discussion I saw is at https://discuss.tvm.apache.org/t/failures-using-many-of-onnx-model-zoo-models/10268

I used a script based on https://gist.github.com/masahi/9348db919edb105912b94b84792dd7d3 (included below) to build ssd-mobilenetv1, and compilation failed with the error below.

TVM version: commit 1fac10b
LLVM version: 12.0.1
OS info: Ubuntu 20.10 (Groovy Gorilla)

Error message:

==> https://github.com/onnx/models/raw/master/vision/object_detection_segmentation/ssd-mobilenetv1/model/ssd_mobilenet_v1_10.tar.gz <==
Loading ssd_mobilenet_v1/ssd_mobilenet_v1.onnx ...
Input shapes: {'image_tensor:0': (1, 383, 640, 3)}
Importing graph from ONNX to TVM Relay IR ...
/home/chlu/tvm/python/tvm/relay/frontend/onnx.py:2572: UserWarning: 
                Using scan outputs in a loop with strided slice
                currently may cause errors during compilation.
                
  warnings.warn(
[14:48:48] ../src/runtime/threading_backend.cc:217: Warning: more than two frequencies detected!
Compiling graph from Relay IR to llvm ...
Caught an exception Traceback (most recent call last):
  37: TVMFuncCall
  36: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::relay::vm::VMCompiler::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  35: tvm::relay::vm::VMCompiler::Lower(tvm::IRModule, tvm::runtime::Map<tvm::Integer, tvm::Target, void, void> const&, tvm::Target const&)
  34: tvm::relay::vm::VMCompiler::OptimizeModule(tvm::IRModule, tvm::runtime::Map<tvm::Integer, tvm::Target, void, void> const&, tvm::Target const&)
  33: tvm::transform::Pass::operator()(tvm::IRModule) const
  32: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  31: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  30: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  29: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  28: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::relay::Function (tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::AlterOpLayout()::{lambda(tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relay::transform::AlterOpLayout()::{lambda(tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  27: tvm::relay::alter_op_layout::AlterOpLayout(tvm::RelayExpr const&)
  26: tvm::relay::ForwardRewrite(tvm::RelayExpr const&, tvm::runtime::TypedPackedFunc<tvm::RelayExpr (tvm::relay::Call const&, tvm::runtime::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)> const&, std::function<tvm::runtime::ObjectRef (tvm::relay::Call const&)>, std::function<tvm::RelayExpr (tvm::RelayExpr const&)>)
  25: tvm::relay::MixedModeMutator::VisitExpr(tvm::RelayExpr const&)
  24: tvm::relay::MixedModeMutator::VisitLeaf(tvm::RelayExpr const&)
  23: _ZN3tvm5relay16MixedModeMutator17DispatchVisitExprERKNS_9RelayExp
  22: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  21: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  20: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_
  19: tvm::relay::ExprMutator::VisitExpr_(tvm::relay::FunctionNode const*)
  18: tvm::relay::MixedModeMutator::VisitExpr(tvm::RelayExpr const&)
  17: tvm::relay::MixedModeMutator::VisitLeaf(tvm::RelayExpr const&)
  16: _ZN3tvm5relay16MixedModeMutator17DispatchVisitExprERKNS_9RelayExp
  15: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  14: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  13: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_
  12: tvm::relay::ExprMutator::VisitExpr_(tvm::relay::LetNode const*)
  11: tvm::relay::MixedModeMutator::VisitExpr(tvm::RelayExpr const&)
  10: tvm::relay::MixedModeMutator::VisitLeaf(tvm::RelayExpr const&)
  9: _ZN3tvm5relay16MixedModeMutator17DispatchVisitExprERKNS_9RelayExp
  8: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  7: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  6: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_
  5: tvm::relay::MixedModeMutator::VisitExpr_(tvm::relay::CallNode const*)
  4: tvm::relay::ForwardRewriter::Rewrite_(tvm::relay::CallNode const*, tvm::RelayExpr const&)
  3: tvm::runtime::TypedPackedFunc<tvm::RelayExpr (tvm::relay::Call const&, tvm::runtime::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)>::AssignTypedLambda<tvm::RelayExpr (*)(tvm::relay::Call const&, tvm::runtime::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)>(tvm::RelayExpr (*)(tvm::relay::Call const&, tvm::runtime::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  2: tvm::RelayExpr tvm::relay::LayoutRewriter<tvm::relay::alter_op_layout::AlterTransformMemorizer>(tvm::relay::Call const&, tvm::runtime::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)
  1: tvm::relay::alter_op_layout::AlterTransformMemorizer::CallWithNewLayouts(tvm::relay::Call const&, std::vector<tvm::RelayExpr, std::allocator<tvm::RelayExpr> > const&)
  0: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), TVMFuncCreateFromCFunc::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) [clone .cold]
  File "/home/chlu/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 81, in cfun
    rv = local_pyfunc(*pyargs)
  File "/home/chlu/tvm/python/tvm/relay/op/nn/_nn.py", line 195, in alter_op_layout_conv2d
    return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, out_type)
  File "<decorator-gen-58>", line 2, in conv2d_alter_layout
  File "/home/chlu/tvm/python/tvm/target/generic_func.py", line 275, in dispatch_func
    return dispatch_dict[k](*args, **kwargs)
  File "/home/chlu/tvm/python/tvm/topi/x86/conv2d_alter_op.py", line 60, in _alter_conv2d_layout
    impl, outs = relay.backend.compile_engine.select_implementation(
  File "/home/chlu/tvm/python/tvm/relay/backend/compile_engine.py", line 219, in select_implementation
    outs = impl.compute(attrs, inputs, out_type)
  File "/home/chlu/tvm/python/tvm/relay/op/op.py", line 125, in compute
    return _OpImplementationCompute(self, attrs, inputs, out_type)
  File "/home/chlu/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
  3: TVMFuncCall
  2: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::relay::__mk_TVM6::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  1: tvm::relay::OpImplementation::Compute(tvm::Attrs const&, tvm::runtime::Array<tvm::te::Tensor, void> const&, tvm::Type const&)
  0: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), TVMFuncCreateFromCFunc::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) [clone .cold]
  File "/home/chlu/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 81, in cfun
    rv = local_pyfunc(*pyargs)
  File "/home/chlu/tvm/python/tvm/relay/op/strategy/generic.py", line 240, in _compute_conv2d
    return [topi_compute(*args)]
  File "/home/chlu/tvm/python/tvm/topi/x86/conv2d.py", line 129, in conv2d_nchw
    packed_out = conv2d_NCHWc(data, kernel, strides, padding, dilation, layout, layout, out_dtype)
  File "/home/chlu/tvm/python/tvm/autotvm/task/topi_integration.py", line 165, in wrapper
    node = topi_compute(cfg, *args)
  File "/home/chlu/tvm/python/tvm/topi/x86/conv2d.py", line 191, in conv2d_NCHWc
    oh = (ih - kernel_height + pt + pb) // sh + 1
TypeError: unsupported operand type(s) for -: 'Any' and 'int'

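Script (note: `target` is set to `"cuda"` below, while the log above shows compilation to `llvm`, so the reported run used `target = "llvm"`):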

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# See:
# - https://tvm.apache.org/docs/tutorials/frontend/from_onnx.html
# - https://github.com/apache/tvm/blob/main/tutorials/frontend/from_onnx.py
# - https://github.com/onnx/models


import subprocess
import os
import sys
import posixpath
from six.moves.urllib.request import urlretrieve
import glob

import onnx
from onnx import numpy_helper
import numpy as np
import tvm
import tvm.relay as relay
from tvm.contrib import graph_executor
from tvm.runtime.vm import VirtualMachine


def get_value_info_shape(value_info):
    """Build a concrete shape tuple from an ONNX value-info entry.

    Each dimension's ``dim_value`` is used as-is when it is greater than 1.
    Any non-positive value (for instance an unset/zero value) is clamped to 1,
    so the result always holds positive ints.
    """
    dims = value_info.type.tensor_type.shape.dim
    shape = []
    for dim in dims:
        extent = dim.dim_value
        shape.append(extent if extent > 1 else 1)
    return tuple(shape)

# Model archives to download, import, compile and run.
urls = [
    'https://github.com/onnx/models/raw/master/vision/object_detection_segmentation/ssd-mobilenetv1/model/ssd_mobilenet_v1_10.tar.gz',
]

# Compilation target. Defaults to "cuda"; pass a different target as the first
# command-line argument (e.g. `python script.py llvm`) to reproduce failures
# on another backend without editing the script.
target = sys.argv[1] if len(sys.argv) > 1 else "cuda"

# Device on which the input tensors are allocated.
ctx = tvm.device(target, 0)

import tarfile

# (result, url) pairs, printed at the end of the script.
summary = []
for url in urls:
    print(f'==> {url} <==')

    # Download the archive once; later runs reuse the cached local copy.
    archive = posixpath.basename(url)
    if not os.path.exists(archive):
        print(f'Downloading {url} ...')
        urlretrieve(url, archive)
        assert os.path.exists(archive)

    # Reset per archive so values from a previous iteration are never reused.
    model_file = None
    name = None
    # `with` guarantees the archive handle is closed, including on `continue`.
    with tarfile.open(archive, 'r:gz') as tar:
        # The first .onnx member is the model; its directory is searched for
        # test data below.
        for member_name in tar.getnames():
            if member_name.endswith('.onnx'):
                model_file = member_name
                name = os.path.dirname(member_name)
                break

        if model_file is None:
            print(f'No .onnx model found in {archive}')
            summary.append(('not ok', url))
            print()
            continue

        if not os.path.exists(model_file):
            print(f'Extracting {archive} ...')
            # NOTE(review): extractall() trusts member paths in the archive;
            # only run this on archives from a trusted source.
            tar.extractall()
            assert os.path.exists(model_file)

    print(f'Loading {model_file} ...')
    onnx_model = onnx.load(model_file)
    graph = onnx_model.graph

    # Initializers are stored weights, not runtime inputs; they are skipped below.
    initializers = {initializer.name for initializer in graph.initializer}

    # glob() returns matches in arbitrary order; sort so the same data set is
    # chosen on every run.
    test_data_sets = sorted(glob.glob(os.path.join(name, 'test_data_set_*')))
    if not test_data_sets:
        print(f'No test_data_set_* directory found for {model_file}')
        summary.append(('not ok', url))
        print()
        continue
    test_data_set = test_data_sets[0]

    input_values = []
    shape_dict = {}
    inputs = {}
    for graph_input in graph.input:
        if graph_input.name in initializers:
            continue
        # Assumes input_<i>.pb is numbered in the order of the non-initializer
        # graph inputs.
        input_path = os.path.join(test_data_set, f'input_{len(input_values)}.pb')
        tensor = onnx.TensorProto()
        with open(input_path, 'rb') as input_file:
            tensor.ParseFromString(input_file.read())
        x = numpy_helper.to_array(tensor)
        input_values.append(x)
        shape_dict[graph_input.name] = x.shape
        inputs[graph_input.name] = tvm.nd.array(x, ctx)

    print(f'Input shapes: {shape_dict}')

    try:
        print('Importing graph from ONNX to TVM Relay IR ...')
        mod, params = relay.frontend.from_onnx(onnx_model, shape_dict, freeze_params=True)
        mod = relay.transform.DynamicToStatic()(mod)

        print(f'Compiling graph from Relay IR to {target} ...')
        with tvm.transform.PassContext(opt_level=3):
            vm_exec = relay.vm.compile(mod, target, params=params)

        # Run on the same device the inputs were allocated on.
        vm = VirtualMachine(vm_exec, ctx)
        vm.set_input("main", **inputs)

        print("Running inference...")
        vm.run()
    except KeyboardInterrupt:
        raise
    except Exception as ex:
        # Record the failure and move on to the next model.
        print(f'Caught an exception {ex}')
        result = 'not ok'
    else:
        print('Succeeded!')
        result = 'ok'
    summary.append((result, url))
    print()

# Report one "<result>\t- <url>" line per processed model archive.
print('Summary:')
for entry in summary:
    status, model_url = entry
    print('{}\t- {}'.format(status, model_url))

Metadata

Metadata

Assignees

No one assigned

    Labels

    frontend:onnx (python/tvm/relay/frontend/onnx.py), relay:op (src/relay/op)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions