Description
Converting a PyTorch program exported with `torch.export` into TVM Relax with `from_exported_program` fails when the model contains `nn.LSTM`. The frontend reports:
AssertionError: Unsupported function types ['lstm.input']
This indicates that `lstm.input` (the op that `torch.export` emits for `nn.LSTM`) is not currently supported by the TVM Relax PyTorch frontend.
Expected behavior
The Relax Torch frontend should lower `nn.LSTM` (emitted as `lstm.input` by `torch.export`) to a supported Relax representation, either:
- a high-level RNN/LSTM composite (if available), or
- a lower-level decomposition into primitive ops (matmul, elementwise, activations), wrapped as a Relax subgraph or `call_tir` where appropriate.
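As a reference for the decomposition option, here is a minimal NumPy sketch of a single-layer, unidirectional, `batch_first` LSTM built only from matmuls, elementwise ops, sigmoid and tanh. It assumes PyTorch's parameter layout (`weight_ih_l0` of shape `(4*hidden_size, input_size)`, gates stacked in the order i, f, g, o). `lstm_forward` is a hypothetical helper, not a TVM or PyTorch API, and it returns `h_n`/`c_n` without the leading `num_layers * num_directions` dimension that `nn.LSTM` uses.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def lstm_forward(x, w_ih, w_hh, b_ih, b_hh, h0=None, c0=None):
    """Hypothetical reference LSTM: single layer, unidirectional, batch_first.

    x:    (B, T, C)
    w_ih: (4H, C), w_hh: (4H, H), b_ih / b_hh: (4H,)
    Assumes PyTorch's gate order (i, f, g, o) within the 4H rows.
    Returns y (B, T, H), h_n (B, H), c_n (B, H).
    """
    batch, steps, _ = x.shape
    hidden = w_hh.shape[1]
    h = np.zeros((batch, hidden), dtype=x.dtype) if h0 is None else h0
    c = np.zeros((batch, hidden), dtype=x.dtype) if c0 is None else c0
    outputs = []
    for t in range(steps):
        # Two matmuls plus bias adds produce all four gate pre-activations: (B, 4H)
        gates = x[:, t, :] @ w_ih.T + b_ih + h @ w_hh.T + b_hh
        i, f, g, o = np.split(gates, 4, axis=1)
        i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)
        g = np.tanh(g)
        # Elementwise cell and hidden state updates
        c = f * c + i * g
        h = o * np.tanh(c)
        outputs.append(h)
    return np.stack(outputs, axis=1), h, c
```

With a static sequence length such as T=4 in the repro, this loop could be unrolled into a straight-line sequence of the same ops; that is one possible lowering, not necessarily the one TVM would choose.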
If certain LSTM configurations are not yet supported (e.g., bidirectional, multi-layer, projections), the importer should:
- accept the supported subsets, and
- raise a clear Python exception with guidance for unsupported variants.
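A guard of the kind described above might look like the following sketch. It assumes, purely for illustration, that the supported subset is a single-layer, unidirectional LSTM without projections. `check_lstm_supported` is a hypothetical name, and it takes the configuration as plain Python arguments; how an importer would extract these values from the `lstm.input` node is left out.

```python
def check_lstm_supported(num_layers=1, bidirectional=False, proj_size=0):
    """Hypothetical importer-side guard for nn.LSTM (lstm.input).

    Assumes the supported subset is num_layers == 1, unidirectional,
    no projection; every other configuration raises with guidance.
    """
    problems = []
    if num_layers != 1:
        problems.append(f"num_layers={num_layers} (only 1 is assumed supported)")
    if bidirectional:
        problems.append("bidirectional=True")
    if proj_size:
        problems.append(f"proj_size={proj_size}")
    if problems:
        # Collect every unsupported setting so the user sees them all at once
        raise NotImplementedError(
            "Relax frontend: unsupported nn.LSTM configuration (lstm.input): "
            + "; ".join(problems)
            + ". As a workaround, consider expressing the layer with supported "
            "ops (e.g. stacked single-layer, unidirectional LSTMs or an explicit "
            "matmul/sigmoid/tanh cell) before calling torch.export."
        )
```

Reporting all offending settings in one exception, rather than failing on the first, saves the user a round trip per setting.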
Actual behavior
```
PyTorch eager OK, y.shape = (2, 4, 16)
ExportedProgram created.
Traceback (most recent call last):
  ...
  File ".../base_fx_graph_translator.py", line 116, in _check_unsupported_func_type
    assert not missing_func_types, f"Unsupported function types {missing_func_types}"
AssertionError: Unsupported function types ['lstm.input']
```
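To see which ops an exported graph contains before calling `from_exported_program`, a small diagnostic can walk the graph's `call_function` nodes and count their target names. This is a sketch, not TVM code: it assumes, based on the error text, that the frontend identifies ops by the target's `__name__` (for the `torch.ops.aten.lstm.input` overload this is `lstm.input`). The `supported` set passed to `unsupported_ops` is a stand-in, not TVM's real conversion table.

```python
from collections import Counter


def call_function_op_names(nodes):
    """Count call_function target names in an FX graph (e.g. ep.graph_module.graph.nodes)."""
    counts = Counter()
    for node in nodes:
        if node.op == "call_function":
            target = node.target
            # Fall back to str() for targets without a __name__ attribute
            counts[getattr(target, "__name__", str(target))] += 1
    return counts


def unsupported_ops(nodes, supported):
    """Sorted op names present in the graph but absent from the (hypothetical) supported set."""
    return sorted(set(call_function_op_names(nodes)) - set(supported))


# Usage with the `ep` from the repro below:
#   print(call_function_op_names(ep.graph_module.graph.nodes))
```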
Environment
- OS: Ubuntu 22.04.4 LTS (x86_64)
- TVM version: release v0.21.0
- Python: 3.10.16
- LLVM: 17.0.6
Steps to reproduce
```python
import torch
import torch.nn as nn
from torch.export import export as torch_export
from tvm.relax.frontend.torch import from_exported_program


class M(nn.Module):
    def __init__(self, input_size=8, hidden_size=16, num_layers=1, bidirectional=False):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
        )

    def forward(self, x):
        # Only return the output sequence; drop (h_n, c_n)
        y, _ = self.lstm(x)
        return y


def main():
    torch.manual_seed(0)
    m = M().eval()

    # Minimal input: B=2, T=4, C=8
    x = torch.randn(2, 4, 8, dtype=torch.float32)

    # 1) Sanity check in eager
    with torch.inference_mode():
        y = m(x)
    print("PyTorch eager OK, y.shape =", tuple(y.shape))

    # 2) Export to ExportedProgram
    ep = torch_export(m, (x,))
    print("ExportedProgram created.")

    # 3) Import into TVM Relax: raises AssertionError (unsupported function type)
    _ = from_exported_program(ep)


if __name__ == "__main__":
    main()
```