
[Bug][Relax][Torch] from_exported_program fails on unbind producing single-element tuple #18338

@tinywisdom

Description


Expected behavior

  • unbind producing a single-element tuple should be represented as a tuple in Relax IR, not misinterpreted as a tensor.

  • The frontend should either:

    • Correctly lower to a Tuple with one element (a tensor of shape (3,) in this case), or

    • Gracefully reject with a clear Python exception, not an internal assertion failure.
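For reference, the expected result follows from `unbind` semantics. Unbinding a tensor of shape `(n, ...)` along dim 0 yields an `n`-element tuple of tensors of shape `(...)`, so `n == 1` still yields a tuple. The shape-only sketch below spells this out. `unbind_output_shapes` is a hypothetical helper written for illustration. It is not part of PyTorch or TVM.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def unbind_output_shapes(shape: Shape, dim: int = 0) -> Tuple[Shape, ...]:
    """Return the output shapes of ``unbind(dim)`` for an input of ``shape``.

    Hypothetical helper for illustration only: it mirrors torch.unbind at
    the shape level (one output per index along ``dim``, with ``dim``
    removed from each output).
    """
    if not -len(shape) <= dim < len(shape):
        raise IndexError(f"dim {dim} out of range for shape {shape}")
    dim %= len(shape)  # normalize negative dims
    inner = shape[:dim] + shape[dim + 1:]
    # Always a tuple, even when shape[dim] == 1 (a one-element tuple).
    return tuple(inner for _ in range(shape[dim]))


# The case from this issue: (1, 3) unbound along dim 0.
print(unbind_output_shapes((1, 3), dim=0))  # ((3,),)
```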

Actual behavior

When importing a program exported with torch.export into TVM Relax, the frontend crashes if the model applies unbind(dim=0) to a tensor whose size along that axis is 1. The error is:

Check failed: (opt) is false: The struct info of Tuple must be TupleStructInfo,
but expression lv3 has struct info R.Tensor((1, 3), dtype="float32")

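This indicates a mismatch between the expected TupleStructInfo and the actual TensorStructInfo. An expression with tensor struct info (lv3, shape (1, 3)) is used where the frontend expects a tuple.

Output of the reproduction script: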

[INFO] start importing exported program into TVM Relax...
[REPRODUCED] Caught exception while importing:
Check failed: (opt) is false: The struct info of Tuple must be TupleStructInfo,
but expression lv3 has struct info R.Tensor((1, 3), dtype="float32")

tvm.error.InternalError: Check failed: (opt) is false: The struct info of Tuple must be TupleStructInfo, but expression lv3 has struct info R.Tensor((1, 3), dtype="float32")
[...]/src/relax/ir/block_builder.cc:64: Warning: BlockBuilder destroyed with remaining blocks!
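The two acceptable outcomes listed under Expected behavior can be modeled as a normalization step. This step runs on whatever a lowering routine returns for `unbind`. The plain-Python sketch below is hypothetical. `TensorInfo` and `TupleInfo` are simplified stand-ins for Relax's `TensorStructInfo` and `TupleStructInfo`, and `normalize_unbind_result` stands in for frontend logic. None of them are TVM APIs. The sketch handles a bare tensor in one of two ways. When the unbound axis has size 1, it returns the squeezed slice wrapped in a one-element tuple. Otherwise, it raises a descriptive `ValueError` instead of letting the mismatch reach an internal `Check failed` assertion.

```python
from dataclasses import dataclass
from typing import Tuple, Union

Shape = Tuple[int, ...]


@dataclass(frozen=True)
class TensorInfo:
    """Hypothetical stand-in for a tensor's struct info (shape only)."""
    shape: Shape


@dataclass(frozen=True)
class TupleInfo:
    """Hypothetical stand-in for a tuple's struct info."""
    fields: Tuple[TensorInfo, ...]


def normalize_unbind_result(
    input_shape: Shape, dim: int, result: Union[TensorInfo, TupleInfo]
) -> TupleInfo:
    """Coerce a lowered unbind result into a tuple, or fail with a clear error."""
    if not input_shape:
        raise ValueError("unbind requires an input with at least one dimension")
    dim %= len(input_shape)  # normalize negative dims
    if isinstance(result, TupleInfo):
        return result  # already the expected kind of struct info
    if input_shape[dim] == 1 and result.shape == input_shape:
        # Size-1 axis: the single slice is the input with `dim` squeezed,
        # wrapped in a one-element tuple.
        squeezed = input_shape[:dim] + input_shape[dim + 1:]
        return TupleInfo((TensorInfo(squeezed),))
    raise ValueError(
        f"unbind(dim={dim}) on shape {input_shape} must produce a tuple, "
        f"got a tensor of shape {result.shape}"
    )


# The situation from this issue: a (1, 3) tensor where a tuple is expected.
print(normalize_unbind_result((1, 3), 0, TensorInfo((1, 3))))
# TupleInfo(fields=(TensorInfo(shape=(3,)),))
```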

Environment

  • OS: Ubuntu 22.04.4 LTS (x86_64)
  • TVM version: release v0.21.0
  • Python: 3.10.16
  • LLVM: 17.0.6

Steps to reproduce

```python
import torch
import torch.nn as nn

# Minimal model: gather -> max -> unbind(0)
class MyModel(nn.Module):
    def forward(self, x, y):
        # x: (2, 3), y: (1, 3)
        indices = torch.zeros((1,) + x.size()[1:], dtype=torch.long, device=x.device)
        x_gathered = torch.gather(x, 0, indices)   # (1, 3)
        compared = torch.max(x_gathered, y)        # (1, 3)
        outs = compared.unbind(0)                  # tuple of length 1, each (3,)
        return outs

def get_inputs():
    torch.manual_seed(0)
    x = torch.randn(2, 3)
    y = torch.randn(1, 3)
    return x, y

if __name__ == "__main__":
    from torch.export import export as torch_export
    from tvm.relax.frontend.torch import from_exported_program

    m = MyModel().eval()
    args = get_inputs()

    ep = torch_export(m, args)

    print("[INFO] start importing exported program into TVM Relax...")
    try:
        mod = from_exported_program(ep)
        print("[UNEXPECTED] Import succeeded (no error).")
    except Exception as e:
        print("[REPRODUCED] Caught exception while importing:")
        print(e)
        raise
```

Triage

  • needs-triage
  • bug

cc @junrushao @shingjan
