
[Bug] MatMul operator in TVM seems fragile #16891

@shaoyuyoung

Description


TVM seems to impose strict shape restrictions on the MatMul operator: it cannot multiply tensors whose shapes differ in rank, even when those shapes are broadcastable.

Look at this simple graph. In PyTorch and ONNX, the model is correctly defined, and the input and output shapes are exactly as shown below.
The evidence is here: https://pytorch.org/docs/stable/generated/torch.matmul.html#torch.matmul
(screenshot: model graph with input and output shapes)
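For reference, a minimal sketch (plain PyTorch, no TVM involved) that prints the shape each matmul produces; all three follow the broadcasting rules documented at the link above:

import torch

x1, x2 = torch.randn(5, 1), torch.randn(1)
y1, y2 = torch.randn(5, 4, 5), torch.randn(5)

r1 = torch.matmul(x1, x2)    # (5, 1) @ (1,)    -> (5,)    matrix-vector product
r2 = torch.matmul(y1, y2)    # (5, 4, 5) @ (5,) -> (5, 4)  batched matrix-vector
out = torch.matmul(r1, r2)   # (5,) @ (5, 4)    -> (4,)    vector-matrix product
print(r1.shape, r2.shape, out.shape)  # torch.Size([5]) torch.Size([5, 4]) torch.Size([4])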

When I try to convert the ONNX model to TVM, I get an error indicating that the tensor shapes are inconsistent. However, when converting from PyTorch to TVM, everything is OK!

I guess one possible reason is that TorchScript plays a role here (the PyTorch frontend consumes a traced module), whereas the ONNX path does not.
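One way to probe this guess (a sketch, not a fix): tracing bakes concrete tensor shapes into the TorchScript graph as type annotations, which the PyTorch frontend can consult, whereas the ONNX frontend only sees the shapes declared in its shape dict:

import torch

def f(a, b):
    return torch.matmul(a, b)

traced = torch.jit.trace(f, (torch.randn(5, 1), torch.randn(1)))
print(traced.graph)  # inputs are typed e.g. Float(5, 1, ...), i.e. shapes are baked in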

Moreover, look at the last line of the error message. I wonder why T.int64(1) is used there; TVM's handling of int64 shape constants seems pretty fragile.


Expected behavior

Compilation should pass, since the model is well defined and produces results in both PyTorch and ONNX.
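A quick sanity check for the ONNX half of that claim (a sketch, assuming onnxruntime is installed; it is not part of the environment listed below):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("direct_matmul_model.onnx", providers=["CPUExecutionProvider"])
feeds = {
    "x1": np.random.randn(5, 1).astype(np.float32),
    "x2": np.random.randn(1).astype(np.float32),
    "y1": np.random.randn(5, 4, 5).astype(np.float32),
    "y2": np.random.randn(5).astype(np.float32),
}
(out,) = sess.run(None, feeds)
print(out.shape)  # expected: (4,), matching PyTorch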

Actual behavior

Compilation failure

Traceback (most recent call last):
  18: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}>(tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  17: tvm::transform::Pass::operator()(tvm::IRModule) const
  16: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  15: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  14: _ZN3tvm7runtime13PackedFun
  13: tvm::runtime::TypedPackedFunc<tvm::relay::Function (tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::DynamicToStatic()::{lambda(tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relay::transform::DynamicToStatic()::{lambda(tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  12: tvm::relay::DynamicToStatic(tvm::relay::Function, tvm::IRModule)
  11: tvm::relay::DynamicToStaticMutator::PrepareInput(tvm::RelayExpr const&)
  10: tvm::transform::Pass::operator()(tvm::IRModule) const
  9: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  8: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  7: tvm::transform::Pass::operator()(tvm::IRModule) const
  6: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  5: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  4: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1}>(tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  3: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
  2: tvm::relay::TypeSolver::Solve()
  1: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<bool (tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>(bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  0: bool tvm::relay::BatchMatmulRel<tvm::relay::BatchMatmulAttrs>(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)
  File "/root/anaconda3/conda-bld/tvm-package_1701590675822/work/src/relay/op/nn/nn.h", line 212
InternalError: Check failed: (reporter->AssertEQ(xk, yk)) is false: BatchDot: shapes of x and y is inconsistent,  x shape=[T.int64(1), 5, 5], y shape=[5, 5, 4]

Environment

Operating System: Ubuntu 18
TVM: 0.15
Torch: 2.1.1
ONNX: 1.15.0

Steps to reproduce

Here is the script:

import torch
import torch.nn as nn
import tvm
from tvm import relay
import onnx

class DirectMatMulModel(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x1, x2, y1, y2):
        result1 = torch.matmul(x1, x2)                 # (5, 1) @ (1,)    -> (5,)
        result2 = torch.matmul(y1, y2)                 # (5, 4, 5) @ (5,) -> (5, 4)
        final_result = torch.matmul(result1, result2)  # (5,) @ (5, 4)    -> (4,)
        return final_result


torch_model = DirectMatMulModel().eval()

# Shapes chosen so that every matmul relies on broadcasting:
x1 = torch.randn(5, 1)
x2 = torch.randn(1)
y1 = torch.randn(5, 4, 5)
y2 = torch.randn(5)

scripted_model = torch.jit.trace(torch_model, (x1, x2, y1, y2))

torch.onnx.export(torch_model,
                  (x1, x2, y1, y2),
                  "direct_matmul_model.onnx",
                  export_params=True,
                  opset_version=12,
                  do_constant_folding=True,
                  input_names=['x1', 'x2', 'y1', 'y2'],
                  output_names=['output'])

onnx_model = onnx.load("direct_matmul_model.onnx")
onnx.checker.check_model(onnx_model)

def compile_onnx():
    mod_from_onnx, params_onnx = relay.frontend.from_onnx(onnx_model, shape={'x1': [5, 1], 'x2': [1], 'y1': [5, 4, 5], 'y2': [5]})
    with tvm.transform.PassContext(opt_level=4):
        executor = relay.build_module.create_executor(
            'graph', mod_from_onnx, tvm.cpu(), 'llvm', params_onnx
        ).evaluate()

def compile_torch():
    mod_from_torch, params_torch = relay.frontend.from_pytorch(scripted_model, input_infos=[('x1', [5, 1]), ('x2', [1]), ('y1', [5, 4, 5]), ('y2', [5])])
    with tvm.transform.PassContext(opt_level=4):
        executor = relay.build_module.create_executor(
            'graph', mod_from_torch, tvm.cpu(), 'llvm', params_torch
        ).evaluate()

try:
    compile_torch()
except Exception as e:
    print(f"torch fail\n {e}")

try:
    compile_onnx()
except Exception as e:
    print(f"onnx fail\n {e}")

Triage

  • needs-triage
