-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Description
TVM seems to impose strict restrictions on the MatMul operator: it cannot handle operands whose shapes differ in rank, even when they are broadcastable.
Look at this simple graph. In Pytorch and onnx, the model is correctly defined and the input and output shapes are exactly as shown below.
The evidence is here: https://pytorch.org/docs/stable/generated/torch.matmul.html#torch.matmul

When I try to convert ONNX to TVM, I get an error indicating that the tensor shapes are inconsistent. However, when converting PyTorch to TVM, everything is OK!
I guess one possible reason is that TorchScript plays a role in this but ONNX does not.
Moreover, look at the last line of the error message. I wonder why T.int64(1) appears there; TVM's handling of int64 shape values seems fragile.
Expected behavior
Compilation should pass, since the model produces correct results in both ONNX and PyTorch.
Actual behavior
Compilation failure
Traceback (most recent call last):
18: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}>(tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
17: tvm::transform::Pass::operator()(tvm::IRModule) const
16: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
15: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
14: _ZN3tvm7runtime13PackedFun
13: tvm::runtime::TypedPackedFunc<tvm::relay::Function (tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::DynamicToStatic()::{lambda(tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relay::transform::DynamicToStatic()::{lambda(tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
12: tvm::relay::DynamicToStatic(tvm::relay::Function, tvm::IRModule)
11: tvm::relay::DynamicToStaticMutator::PrepareInput(tvm::RelayExpr const&)
10: tvm::transform::Pass::operator()(tvm::IRModule) const
9: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
8: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
7: tvm::transform::Pass::operator()(tvm::IRModule) const
6: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
5: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
4: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1}>(tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
3: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
2: tvm::relay::TypeSolver::Solve()
1: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<bool (tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>(bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
0: bool tvm::relay::BatchMatmulRel<tvm::relay::BatchMatmulAttrs>(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)
File "/root/anaconda3/conda-bld/tvm-package_1701590675822/work/src/relay/op/nn/nn.h", line 212
InternalError: Check failed: (reporter->AssertEQ(xk, yk)) is false: BatchDot: shapes of x and y is inconsistent, x shape=[T.int64(1), 5, 5], y shape=[5, 5, 4]
Environment
Operating System: Ubuntu 18
TVM: 0.15
Torch: 2.1.1
ONNX: 1.15.0
Steps to reproduce
Here is the script:
import torch
import torch.nn as nn
import tvm
from tvm import relay
import onnx
class DirectMatMulModel(nn.Module):
    """Chains three torch.matmul calls with mixed-rank operands.

    For the inputs used in this repro (x1: (5, 1), x2: (1,), y1: (5, 4, 5),
    y2: (5,)), torch.matmul applies its 1-D vector promotion rules, which is
    exactly what trips TVM's BatchMatmulRel shape check.
    """

    # NOTE: the empty __init__ that only called super(DirectMatMulModel,
    # self).__init__() was redundant and has been removed; nn.Module's
    # default constructor runs automatically.

    def forward(self, x1, x2, y1, y2):
        # (5, 1) @ (1,) -> (5,): a 1-D rhs is treated as a column vector.
        result1 = torch.matmul(x1, x2)
        # (5, 4, 5) @ (5,) -> (5, 4): batched matrix-vector product.
        result2 = torch.matmul(y1, y2)
        # (5,) @ (5, 4) -> (4,): a 1-D lhs is treated as a row vector.
        final_result = torch.matmul(result1, result2)
        return final_result
# Build the model and the concrete inputs whose ranks differ (2-D, 1-D, 3-D, 1-D).
torch_model = DirectMatMulModel().eval()
x1 = torch.randn(5, 1)
x2 = torch.randn(1)
y1 = torch.randn(5, 4, 5)
y2 = torch.randn(5)
# Trace for the PyTorch->TVM path (this route compiles fine in TVM).
scripted_model = torch.jit.trace(torch_model, (x1, x2, y1, y2))
# Export the same model to ONNX for the failing ONNX->TVM path.
torch.onnx.export(torch_model,
(x1, x2, y1, y2),
"direct_matmul_model.onnx",
export_params=True,
opset_version=12,
do_constant_folding=True,
input_names=['x1', 'x2', 'y1', 'y2'],
output_names=['output'])
# Reload and validate the exported graph; the ONNX checker passes,
# so the model itself is well-formed before TVM sees it.
onnx_model = onnx.load("direct_matmul_model.onnx")
onnx.checker.check_model(onnx_model)
def compile_onnx():
    """Import the ONNX model into Relay and build it.

    This is the failing path: type inference rejects the batch_matmul
    produced by the frontend (BatchMatmulRel shape check).
    """
    mod_from_onnx, params_onnx = relay.frontend.from_onnx(
        onnx_model,
        shape={'x1': [5, 1], 'x2': [1], 'y1': [5, 4, 5], 'y2': [5]},
    )
    with tvm.transform.PassContext(opt_level=4):
        # The unused `executor =` binding was dropped; .evaluate() is
        # called purely to force compilation.
        relay.build_module.create_executor(
            'graph', mod_from_onnx, tvm.cpu(), 'llvm', params_onnx
        ).evaluate()
def compile_torch():
    """Import the traced TorchScript model into Relay and build it.

    This path succeeds, demonstrating the frontend inconsistency with
    compile_onnx() on the identical model and input shapes.
    """
    mod_from_torch, params_torch = relay.frontend.from_pytorch(
        scripted_model,
        input_infos=[('x1', [5, 1]), ('x2', [1]), ('y1', [5, 4, 5]), ('y2', [5])],
    )
    with tvm.transform.PassContext(opt_level=4):
        # The unused `executor =` binding was dropped; .evaluate() is
        # called purely to force compilation.
        relay.build_module.create_executor(
            'graph', mod_from_torch, tvm.cpu(), 'llvm', params_torch
        ).evaluate()
# Run both paths; only the ONNX route is expected to raise.
try:
    compile_torch()
except Exception as e:
    print(f"torch fail\n {e}")
try:
    compile_onnx()
except Exception as e:
    print(f"onnx fail\n {e}")

Triage
- needs-triage
