[Feature Request][Relax][Torch] Add support for nn.LSTM (lstm.input) in from_exported_program #18340

@tinywisdom

Description

Converting a PyTorch model exported with torch.export into TVM Relax via from_exported_program fails when the model contains nn.LSTM. The frontend reports:

AssertionError: Unsupported function types ['lstm.input']

This indicates that lstm.input (the ATen overload aten.lstm.input, which torch.export emits for nn.LSTM) has no converter in the TVM Relax PyTorch frontend.
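The assertion in the traceback amounts to a set difference: every op name found in the exported graph must have a registered converter, and any leftovers are reported. The sketch below is a simplified, hypothetical re-creation of that idea, not TVM's actual code. The names `SUPPORTED_OPS`, `find_missing_ops`, and `check_unsupported` are illustrative, and the sketch works on plain strings rather than on a real torch.fx graph.

```python
# Hypothetical, simplified re-creation of the importer's unsupported-op check.
# Not TVM's implementation: op names are plain strings here, whereas the real
# frontend inspects the nodes of the exported torch.fx graph.

# Illustrative subset of op names a translator might have converters for.
SUPPORTED_OPS = {"linear.default", "relu.default", "add.Tensor", "sigmoid.default"}


def find_missing_ops(graph_op_names, supported=SUPPORTED_OPS):
    """Return the sorted list of op names that have no registered converter."""
    return sorted(set(graph_op_names) - set(supported))


def check_unsupported(graph_op_names):
    """Fail the same way the traceback shows when any op is unsupported."""
    missing = find_missing_ops(graph_op_names)
    assert not missing, f"Unsupported function types {missing}"


if __name__ == "__main__":
    try:
        check_unsupported(["linear.default", "lstm.input"])
    except AssertionError as err:
        print(err)  # Unsupported function types ['lstm.input']
```

Under this reading, adding an "lstm.input" entry with a working converter is what would make the check pass.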

Expected behavior

  • The Relax Torch frontend should lower nn.LSTM (emitted as lstm.input in torch.export) to a supported Relax representation:

    • Either a high-level RNN/LSTM composite (if available), or

    • A lower-level decomposition into primitive ops (matmul/elementwise/activations) wrapped as a Relax subgraph / call_tir where appropriate.

  • If certain LSTM configurations are not yet supported (e.g., bidirectional, multi-layer, projections), the importer should:

    • Accept supported subsets, and

    • Raise a clear Python exception for unsupported variants that states which option is not supported.
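To illustrate the decomposition option, here is a minimal NumPy sketch of a single-layer, unidirectional, batch_first LSTM built only from matmul, elementwise ops, sigmoid, and tanh. It assumes PyTorch's documented parameter layout (weight_ih_l0 of shape (4*hidden, input), weight_hh_l0 of shape (4*hidden, hidden), gates stacked in i, f, g, o order). It also raises NotImplementedError for configurations outside that subset, mirroring the "accept supported subsets" behavior requested above. The function `lstm_forward` and its arguments are hypothetical. This is not a Relax implementation, only the math a converter would have to emit.

```python
import numpy as np


def _sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def lstm_forward(x, w_ih, w_hh, b_ih, b_hh, h0=None, c0=None,
                 num_layers=1, bidirectional=False, proj_size=0):
    """Hypothetical single-layer LSTM decomposed into primitive ops.

    x: (B, T, C) input (batch_first=True); w_ih: (4H, C); w_hh: (4H, H);
    b_ih, b_hh: (4H,). Gates are stacked in PyTorch's (i, f, g, o) order.
    Returns (y, (h_n, c_n)) with y: (B, T, H) and h_n, c_n: (B, H).
    """
    # Accept only the supported subset; reject the rest with a clear message.
    if num_layers != 1:
        raise NotImplementedError("lstm: only num_layers=1 is supported in this sketch")
    if bidirectional:
        raise NotImplementedError("lstm: bidirectional=True is not supported in this sketch")
    if proj_size:
        raise NotImplementedError("lstm: proj_size > 0 is not supported in this sketch")

    batch, steps, _ = x.shape
    hidden = w_hh.shape[1]
    h = np.zeros((batch, hidden), dtype=x.dtype) if h0 is None else h0
    c = np.zeros((batch, hidden), dtype=x.dtype) if c0 is None else c0

    # The input projection does not depend on h, so it is hoisted out of the
    # time loop as one batched matmul: (B, T, C) @ (C, 4H) -> (B, T, 4H).
    x_proj = x @ w_ih.T + b_ih

    outputs = []
    for t in range(steps):
        gates = x_proj[:, t, :] + h @ w_hh.T + b_hh  # (B, 4H)
        i, f, g, o = np.split(gates, 4, axis=1)      # four (B, H) chunks
        i, f, o = _sigmoid(i), _sigmoid(f), _sigmoid(o)
        g = np.tanh(g)
        c = f * c + i * g          # new cell state
        h = o * np.tanh(c)         # new hidden state
        outputs.append(h)

    y = np.stack(outputs, axis=1)  # (B, T, H)
    return y, (h, c)


if __name__ == "__main__":
    # Same sizes as the repro below: B=2, T=4, C=8, H=16.
    rng = np.random.default_rng(0)
    B, T, C, H = 2, 4, 8, 16
    x = rng.standard_normal((B, T, C)).astype(np.float32)
    w_ih = 0.1 * rng.standard_normal((4 * H, C)).astype(np.float32)
    w_hh = 0.1 * rng.standard_normal((4 * H, H)).astype(np.float32)
    b_ih = np.zeros(4 * H, dtype=np.float32)
    b_hh = np.zeros(4 * H, dtype=np.float32)
    y, (h_n, c_n) = lstm_forward(x, w_ih, w_hh, b_ih, b_hh)
    print(y.shape)  # (2, 4, 16)
```

Hoisting the input projection turns T small matmuls into one large one. Only the recurrent `h @ w_hh.T` term has to stay inside the sequential loop.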

Actual behavior

PyTorch eager OK, y.shape = (2, 4, 16)
ExportedProgram created.
Traceback (most recent call last):
  ...
  File ".../base_fx_graph_translator.py", line 116, in _check_unsupported_func_type
    assert not missing_func_types, f"Unsupported function types {missing_func_types}"
AssertionError: Unsupported function types ['lstm.input']

Environment

  • OS: Ubuntu 22.04.4 LTS (x86_64)
  • TVM version: release v0.21.0
  • Python: 3.10.16
  • LLVM: 17.0.6

Steps to reproduce

import torch
import torch.nn as nn
from torch.export import export as torch_export
from tvm.relax.frontend.torch import from_exported_program

class M(nn.Module):
    def __init__(self, input_size=8, hidden_size=16, num_layers=1, bidirectional=False):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
        )

    def forward(self, x):
        # Only return the output sequence; drop (h_n, c_n)
        y, _ = self.lstm(x)
        return y

def main():
    torch.manual_seed(0)
    m = M().eval()

    # Minimal input: B=2, T=4, C=8
    x = torch.randn(2, 4, 8, dtype=torch.float32)

    # 1) Sanity check in eager
    with torch.inference_mode():
        y = m(x)
    print("PyTorch eager OK, y.shape =", tuple(y.shape))

    # 2) Export to ExportedProgram
    ep = torch_export(m, (x,))
    print("ExportedProgram created.")

    # 3) Import into TVM Relax; this currently raises the unsupported function type error
    _ = from_exported_program(ep)

if __name__ == "__main__":
    main()

Triage

  • needs-triage
  • bug

cc @junrushao @shingjan

Labels: needs-triage, type: bug