Closed
Labels: needs-triage, type: bug
Description
Expected behavior
An unbind that produces a single-element tuple should be represented as a tuple in Relax IR, not misinterpreted as a tensor.
The frontend should either:
- Correctly lower it to a Tuple with one element (a single tensor of shape (3,) in this case), or
- Gracefully reject it with a clear Python exception rather than an internal assertion failure.
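For reference, the shape rule PyTorch applies is simple: unbinding along an axis of size n yields a tuple of n tensors with that axis removed, so n == 1 still yields a one-element tuple. The sketch below encodes that rule in plain Python to make the expected result concrete. `expected_unbind_shapes` is a hypothetical helper, not part of PyTorch or TVM; its ValueError branch only illustrates what a clear Python-level rejection of an invalid axis could look like.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def expected_unbind_shapes(shape: Shape, dim: int = 0) -> Tuple[Shape, ...]:
    """Return the shapes of the tensors that unbind(dim) produces.

    Hypothetical helper: it models only the shape rule, not any TVM API.
    The result is always a tuple, even when shape[dim] == 1.
    """
    rank = len(shape)
    if rank == 0:
        raise ValueError("unbind() cannot be applied to a 0-d tensor")
    if not -rank <= dim < rank:
        raise ValueError(f"unbind(): dim {dim} is out of range for a tensor of rank {rank}")
    dim %= rank  # normalize negative axes
    # Every output drops the unbound axis; there is one output per index along it.
    element_shape = shape[:dim] + shape[dim + 1:]
    return tuple(element_shape for _ in range(shape[dim]))


if __name__ == "__main__":
    print(expected_unbind_shapes((1, 3), 0))  # ((3,),)
    print(expected_unbind_shapes((2, 3), 0))  # ((3,), (3,))
```

The (1, 3) case from this report maps to a one-element tuple holding a (3,) tensor, which is exactly the structure the importer should emit.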
Actual behavior
When importing a torch.export-ed PyTorch program into TVM Relax, if the model applies unbind(dim=0) to a tensor whose size along that axis is 1, the frontend crashes with:
Check failed: (opt) is false: The struct info of Tuple must be TupleStructInfo,
but expression lv3 has struct info R.Tensor((1, 3), dtype="float32")
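This indicates a mismatch between the expected TupleStructInfo and the actual TensorStructInfo: a tensor of shape (1, 3), the same shape as the torch.max output in the model, is being used where a tuple is expected.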
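The report does not show which frontend code path produced lv3, so the following is only a guess at the failure mode. It is written as a toy model in plain Python, not against TVM's real API. It assumes unbind is lowered as a split into one piece per index followed by tuple indexing, and that a one-way split gets folded back into its input tensor. Under those assumptions, indexing into the result trips the same kind of tuple check seen in the error above, while always wrapping the split result in a tuple avoids it. `TensorSInfo`, `TupleSInfo`, `split`, `unbind` and the `fold_single_section` flag are all hypothetical.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class TensorSInfo:
    """Toy stand-in for tensor struct info: only the static shape is tracked."""
    shape: Tuple[int, ...]


@dataclass(frozen=True)
class TupleSInfo:
    """Toy stand-in for tuple struct info."""
    fields: Tuple["SInfo", ...]


SInfo = Union[TensorSInfo, TupleSInfo]


def tuple_get_item(sinfo: SInfo, index: int) -> SInfo:
    # Same idea as the failing check: indexing requires tuple struct info.
    if not isinstance(sinfo, TupleSInfo):
        raise TypeError(f"struct info of Tuple must be TupleStructInfo, got {sinfo}")
    return sinfo.fields[index]


def split(x: TensorSInfo, sections: int, axis: int, fold_single_section: bool) -> SInfo:
    if sections <= 0 or x.shape[axis] % sections != 0:
        raise ValueError("axis length must be a positive multiple of the number of sections")
    if sections == 1 and fold_single_section:
        return x  # hypothetical bug: a one-way split is replaced by its input tensor
    part = x.shape[:axis] + (x.shape[axis] // sections,) + x.shape[axis + 1:]
    return TupleSInfo(tuple(TensorSInfo(part) for _ in range(sections)))


def unbind(x: TensorSInfo, dim: int, fold_single_section: bool) -> TupleSInfo:
    n = x.shape[dim]
    if n == 0:
        return TupleSInfo(())  # unbinding an empty axis yields an empty tuple
    pieces = split(x, n, dim, fold_single_section)
    outs = []
    for i in range(n):
        piece = tuple_get_item(pieces, i)  # raises if `pieces` is not a tuple
        outs.append(TensorSInfo(piece.shape[:dim] + piece.shape[dim + 1:]))  # squeeze dim
    return TupleSInfo(tuple(outs))


if __name__ == "__main__":
    x = TensorSInfo((1, 3))
    print(unbind(x, 0, fold_single_section=False))
    # TupleSInfo(fields=(TensorSInfo(shape=(3,)),))
    try:
        unbind(x, 0, fold_single_section=True)
    except TypeError as err:
        print("buggy lowering:", err)
```

With a size-2 axis both variants behave identically, which would explain why only the size-1 case surfaces the error.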
[INFO] start importing exported program into TVM Relax...
[REPRODUCED] Caught exception while importing:
Check failed: (opt) is false: The struct info of Tuple must be TupleStructInfo,
but expression lv3 has struct info R.Tensor((1, 3), dtype="float32")
tvm.error.InternalError: Check failed: (opt) is false: The struct info of Tuple must be TupleStructInfo, but expression lv3 has struct info R.Tensor((1, 3), dtype="float32")
[...]/src/relax/ir/block_builder.cc:64: Warning: BlockBuilder destroyed with remaining blocks!
Environment
- OS: Ubuntu 22.04.4 LTS (x86_64)
- TVM version: release v0.21.0
- Python: 3.10.16
- LLVM: 17.0.6
Steps to reproduce
import torch
import torch.nn as nn


# Minimal model: gather -> max -> unbind(0)
class MyModel(nn.Module):
    def forward(self, x, y):
        # x: (2, 3), y: (1, 3)
        indices = torch.zeros((1,) + x.size()[1:], dtype=torch.long, device=x.device)
        x_gathered = torch.gather(x, 0, indices)  # (1, 3)
        compared = torch.max(x_gathered, y)  # elementwise maximum, (1, 3)
        outs = compared.unbind(0)  # tuple of length 1, each element (3,)
        return outs


def get_inputs():
    torch.manual_seed(0)
    x = torch.randn(2, 3)
    y = torch.randn(1, 3)
    return x, y


if __name__ == "__main__":
    from torch.export import export as torch_export
    from tvm.relax.frontend.torch import from_exported_program

    m = MyModel().eval()
    args = get_inputs()
    ep = torch_export(m, args)
    print("[INFO] start importing exported program into TVM Relax...")
    try:
        mod = from_exported_program(ep)
        print("[UNEXPECTED] Import succeeded (no error).")
    except Exception as e:
        print("[REPRODUCED] Caught exception while importing:")
        print(e)
        raise
Triage
- needs-triage
- bug
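The shape annotations in the reproducer's comments can be checked by hand with a small pure-Python trace. This sketch assumes the documented shape rules for torch.gather (the output takes the shape of the index tensor) and for elementwise torch.max (standard right-aligned broadcasting). `broadcast_pair` and `trace_repro_shapes` are illustrative helpers, not library functions.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def broadcast_pair(a: Shape, b: Shape) -> Shape:
    """Broadcast two shapes using the usual right-aligned rule."""
    rank = max(len(a), len(b))
    a = (1,) * (rank - len(a)) + a
    b = (1,) * (rank - len(b)) + b
    out = []
    for da, db in zip(a, b):
        if da != db and da != 1 and db != 1:
            raise ValueError(f"shapes {a} and {b} are not broadcastable")
        out.append(db if da == 1 else da)
    return tuple(out)


def trace_repro_shapes(x_shape: Shape, y_shape: Shape) -> Dict[str, object]:
    """Follow the reproducer's shapes: gather -> max -> unbind(0)."""
    indices = (1,) + x_shape[1:]                   # torch.zeros((1,) + x.size()[1:])
    gathered = indices                             # gather output has the index shape
    compared = broadcast_pair(gathered, y_shape)   # elementwise torch.max(a, b)
    outs = tuple(compared[1:] for _ in range(compared[0]))  # unbind(0)
    return {"indices": indices, "gathered": gathered, "compared": compared, "outs": outs}


if __name__ == "__main__":
    trace = trace_repro_shapes((2, 3), (1, 3))
    for name, shape in trace.items():
        print(f"{name}: {shape}")
```

For the inputs in get_inputs(), the trace gives compared = (1, 3) and outs = ((3,),), confirming that the tensor reaching unbind(0) has size 1 along dim 0, the case that triggers the crash.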