
[Bug] [Relay] [ONNX] Incorrect shape inference of Squeeze in DynamicToStatic #17050

@shaoyuyoung

Description

This torch model has only two ops: ReflectionPad3d and squeeze.

First, I exported the torch model to ONNX and got the graph below.

[Image: the exported ONNX model]

The ONNX exporter transforms the model in its own way: because of the squeeze operator, the exported graph is dynamic and contains an If branch.
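For reference, `torch.squeeze(x, dim=1)` removes dimension 1 only when its size is 1 and otherwise returns the tensor unchanged; this data-dependent choice is presumably what the exporter encodes as the If branch. The sketch below models only that shape rule in plain Python. `squeeze_shape` is a hypothetical helper for illustration, not a PyTorch or ONNX API.

```python
def squeeze_shape(shape, dim):
    """Mimic the shape effect of torch.squeeze(x, dim) (hypothetical helper).

    The dimension is dropped only if its size is 1; otherwise the shape is
    returned unchanged. The two return paths are assumed to correspond to
    the two branches of the exported If node.
    """
    dim = dim % len(shape)  # normalize negative indices like PyTorch does
    if shape[dim] == 1:
        # size-1 dimension: remove it
        return tuple(shape[:dim]) + tuple(shape[dim + 1:])
    # any other size: pass the shape through unchanged
    return tuple(shape)


print(squeeze_shape((13, 1, 1, 1), 1))    # (13, 1, 1)
print(squeeze_shape((13, 47, 44, 1), 1))  # (13, 47, 44, 1)
```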

ONNX considers this model valid.
However, when I used Relay to convert the model, I got a shape-mismatch error. The correct shape should be Tensor[(13, 1, 1, 1), float32], but TVM inferred Tensor[(13, 13, 1, 1), float32].
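The expected shape can be checked by hand. `nn.ReflectionPad3d` takes its padding as `(left, right, top, bottom, front, back)`, applied to the last three dimensions (W, H, then D), so a 4-D input is treated as (C, D, H, W) and a negative entry crops that side. The plain-Python sketch below applies this arithmetic to the input from the report; `pad3d_shape` is a hypothetical helper, not part of PyTorch.

```python
def pad3d_shape(shape, padding):
    """Output shape of a 3-D pad over the last three dims (hypothetical helper).

    `padding` follows the PyTorch order (left, right, top, bottom, front, back),
    i.e. pairs for W, H and D; negative values crop instead of pad.
    """
    left, right, top, bottom, front, back = padding
    *lead, d, h, w = shape
    return tuple(lead) + (d + front + back, h + top + bottom, w + left + right)


out = pad3d_shape((13, 47, 44, 1), (0, 0, -43, 0, 0, -46))
print(out)  # (13, 1, 1, 1)
```

The result agrees with the shape the report expects, Tensor[(13, 1, 1, 1), float32].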

I suspect there is a bug in TVM's DynamicToStatic pass :(

Code

import onnx
import torch
import torch.nn as nn
import torch.onnx
from tvm import relay


def get_onnx_shape(onnx_model):
    """Map each graph input name to a concrete shape.

    Unknown or symbolic dimensions have dim_value == 0 and are replaced with 1.
    """
    input_shapes = {}
    for graph_input in onnx_model.graph.input:
        shape = []
        for dim in graph_input.type.tensor_type.shape.dim:
            if dim.dim_value > 0:
                shape.append(dim.dim_value)
            else:
                shape.append(1)
        input_shapes[graph_input.name] = shape
    return input_shapes


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # Padding order is (left, right, top, bottom, front, back); negative values crop
        self.pad = nn.ReflectionPad3d((0, 0, -43, 0, 0, -46))

    def forward(self, x):
        x = self.pad(x)
        # Removes dim 1 only if its size is 1; this is what yields the If branch in ONNX
        x = torch.squeeze(x, dim=1)
        return x


model = Model()

# A 4-D input is treated as (C, D, H, W) by ReflectionPad3d
input_tensor = torch.randn(13, 47, 44, 1)

onnx_file_path = "model.onnx"
torch.onnx.export(model,
                  input_tensor,
                  onnx_file_path,
                  export_params=True,
                  opset_version=14,
                  do_constant_folding=False,
                  input_names=['input'],
                  output_names=['output']
                  )

onnx_model = onnx.load(onnx_file_path)
shape_dict = get_onnx_shape(onnx_model)

# Raises the type-checker error below; the traceback points to DynamicToStatic
mod, params = relay.frontend.from_onnx(
    onnx_model, shape_dict, freeze_params=True
)

Error Log

TVMError: Traceback (most recent call last):
  20: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}>(tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  19: tvm::transform::Pass::operator()(tvm::IRModule) const
  18: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  17: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  16: _ZN3tvm7runtime13PackedFun
  15: tvm::runtime::TypedPackedFunc<tvm::relay::Function (tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::DynamicToStatic()::{lambda(tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relay::transform::DynamicToStatic()::{lambda(tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  14: tvm::relay::DynamicToStatic(tvm::relay::Function, tvm::IRModule)
  13: tvm::relay::DynamicToStaticMutator::PrepareInput(tvm::RelayExpr const&)
  12: tvm::transform::Pass::operator()(tvm::IRModule) const
  11: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  10: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  9: tvm::transform::Pass::operator()(tvm::IRModule) const
  8: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  7: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  6: _ZN3tvm7runtime13PackedFun
  5: tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1}>(tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  4: tvm::DiagnosticContext::Render()
  3: tvm::DiagnosticRenderer::Render(tvm::DiagnosticContext const&)
  2: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::DiagnosticContext)>::AssignTypedLambda<tvm::TerminalRenderer(std::ostream&)::{lambda(tvm::DiagnosticContext const&)#1}>(tvm::TerminalRenderer(std::ostream&)::{lambda(tvm::DiagnosticContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  1: tvm::ReportAt(tvm::DiagnosticContext const&, std::ostream&, tvm::Span const&, tvm::Diagnostic const&)
  0: _ZN3tvm7runtime6deta
  File "/workspace/tvm/src/ir/diagnostic.cc", line 264
TVMError: The source maps are not populated for this module. Please use `tvm.relay.transform.AnnotateSpans` to attach source maps for error reporting.
Error: The Relay type checker is unable to show the following types match:
  Tensor[(13, 13, 1, 1), float32]
  Tensor[(13, 1, 1, 1), float32]
In particular:
  dimension 1 conflicts: 13 does not match 1.
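The last lines of the log come from comparing the two tensor types dimension by dimension. As an illustration only (a toy model, not TVM's actual type solver), the hypothetical helper `first_dim_conflict` below finds the first mismatching dimension and reproduces the reported conflict.

```python
def first_dim_conflict(lhs, rhs):
    """Return (index, lhs_dim, rhs_dim) for the first mismatch, or None.

    Toy stand-in for a shape-unification check (hypothetical helper);
    it assumes both shapes are fully static and of equal rank.
    """
    if len(lhs) != len(rhs):
        raise ValueError(f"rank mismatch: {len(lhs)} vs {len(rhs)}")
    for i, (a, b) in enumerate(zip(lhs, rhs)):
        if a != b:
            return i, a, b
    return None


conflict = first_dim_conflict((13, 13, 1, 1), (13, 1, 1, 1))
print(f"dimension {conflict[0]} conflicts: {conflict[1]} does not match {conflict[2]}.")
# dimension 1 conflicts: 13 does not match 1.
```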

Environment

TVM d1ac1c0
Ubuntu 20

cc @KJlaccHoeUM9l @shingjan

Metadata


Assignees: No one assigned

Labels: needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug

Type: No type

Projects: No projects

Milestone: No milestone

Relationships: None yet

Development: No branches or pull requests
