10 changes: 8 additions & 2 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1997,6 +1997,12 @@ def _getitem(self, node: fx.Node) -> relax.Var:
                return self.block_builder.emit(relax.TupleGetItem(x, node.args[1]))

            assert isinstance(x.struct_info, relax.TensorStructInfo)
            if isinstance(node.args[1], int):
                return x
            if not isinstance(node.args[1], (list, tuple)):
                indices = [node.args[1]]
            else:
                indices = node.args[1]
Comment on lines +2002 to +2005
Contributor (critical):
The new logic for handling integer indexing on tensors is incorrect. It returns the tensor itself, which breaks the semantics of tensor indexing. For a tensor x, x[0] should return the first slice along axis 0, which has a reduced rank. The current implementation returns x unmodified. Since _getitem is a general-purpose function, this change can cause incorrect behavior for other operators that rely on it.

This seems to be a workaround for an incomplete _lstm implementation. The correct fix should be in the _lstm operator implementation to return a proper tuple output, and this logic should be removed from _getitem.
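
For illustration only, if integer indexing were to be handled here at all, a rank-reducing lowering could look roughly like the sketch below (not part of the PR; it reuses the relax.op.take-with-scalar-index pattern that the PR's _lstm code uses elsewhere to drop an axis):

            # Hypothetical sketch: index along axis 0 and drop that axis,
            # matching PyTorch's x[0] semantics, instead of returning x unchanged.
            if isinstance(node.args[1], int):
                return self.block_builder.emit(
                    relax.op.take(x, relax.const(node.args[1], "int64"), axis=0)
                )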

            take_indices = []
            take_axes = []
            stride_begin = []
@@ -2007,10 +2013,10 @@ def _getitem(self, node: fx.Node) -> relax.Var:
            i = 0
            shape = self.shape_of(x)
            non_ellipsis_cnt = 0
-            for index in node.args[1]:
+            for index in indices:
                if isinstance(index, (int, slice, torch.fx.Node)):
                    non_ellipsis_cnt += 1
-            for index in node.args[1]:
+            for index in indices:
                if isinstance(index, int):
                    stride_begin.append(index)
                    stride_end.append(index + 1)
161 changes: 161 additions & 0 deletions python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -231,6 +231,166 @@ def _upsample_bicubic2d(self, node: fx.node) -> relax.Var:
            align_corners=align_corners,
        )

    def _lstm(self, node: fx.Node) -> relax.Var:
        args = self.retrieve_args(node)
        input_tensor = args[0]
        hx = args[1] if len(args) > 1 else None
        params = args[2] if len(args) > 2 else None
        has_biases = args[3] if len(args) > 3 else True
        num_layers = args[4] if len(args) > 4 else 1
        _dropout = args[5] if len(args) > 5 else 0.0  # Not used in inference
        _train = args[6] if len(args) > 6 else False  # Not used in inference
        bidirectional = args[7] if len(args) > 7 else False
        batch_first = args[8] if len(args) > 8 else False
        if bidirectional:
            raise NotImplementedError("Bidirectional LSTM is not yet supported")
        if num_layers > 1:
            raise NotImplementedError("Multi-layer LSTM is not yet supported")
        input_shape = self.shape_of(input_tensor)
        if batch_first:
            # Input shape: (batch, seq_len, input_size)
            batch_size, seq_len, input_size = input_shape
        else:
            # Input shape: (seq_len, batch, input_size)
            seq_len, batch_size, input_size = input_shape

        if isinstance(seq_len, tvm.tir.IntImm):
            seq_len = seq_len.value
        if isinstance(batch_size, tvm.tir.IntImm):
            batch_size = batch_size.value
        if isinstance(input_size, tvm.tir.IntImm):
            input_size = input_size.value
        # Extract hidden size from the LSTM parameters
        # The parameters are: [weight_ih, weight_hh, bias_ih, bias_hh]
        # weight_ih shape: (4 * hidden_size, input_size)
        # weight_hh shape: (4 * hidden_size, hidden_size)
        if params and len(params) >= 2:
            weight_ih = params[0]
            weight_hh = params[1]
            # Extract hidden size from weight dimensions
            # weight_ih has shape (4 * hidden_size, input_size)
            weight_ih_shape = self.shape_of(weight_ih)
            hidden_size = weight_ih_shape[0] // 4  # 4 gates: input, forget, cell, output
        else:
            # Fallback to a default hidden size
            hidden_size = 16
Comment on lines +274 to +276
Contributor (high):
The fallback logic for when LSTM parameters are not provided is problematic. It defaults to hidden_size = 16. This can lead to silent correctness issues and hard-to-debug errors. It would be better to raise a ValueError if the parameters are not available to determine hidden_size, as a valid LSTM layer must have weights.

Suggested change
-        else:
-            # Fallback to a default hidden size
-            hidden_size = 16
+        else:
+            raise ValueError("Cannot determine hidden_size. LSTM params (weights) are required.")

        # Implement actual LSTM computation using Relax operations
        # LSTM equations:
        # i_t = sigmoid(W_ii * x_t + b_ii + W_hi * h_{t-1} + b_hi)
        # f_t = sigmoid(W_if * x_t + b_if + W_hf * h_{t-1} + b_hf)
        # g_t = tanh(W_ig * x_t + b_ig + W_hg * h_{t-1} + b_hg)
        # o_t = sigmoid(W_io * x_t + b_io + W_ho * h_{t-1} + b_ho)
        # c_t = f_t * c_{t-1} + i_t * g_t
        # h_t = o_t * tanh(c_t)
        dtype = input_tensor.struct_info.dtype
        if params and len(params) >= 4:
            weight_ih = params[0]  # (4 * hidden_size, input_size)
            weight_hh = params[1]  # (4 * hidden_size, hidden_size)
            bias_ih = params[2] if has_biases else None  # (4 * hidden_size,)
            bias_hh = params[3] if has_biases else None  # (4 * hidden_size,)
        else:
            # Fallback: create zero weights
            weight_ih = self.block_builder.emit(
                relax.op.zeros(relax.ShapeExpr((4 * hidden_size, input_size)), dtype)
            )
            weight_hh = self.block_builder.emit(
                relax.op.zeros(relax.ShapeExpr((4 * hidden_size, hidden_size)), dtype)
            )
            bias_ih = None
            bias_hh = None
Comment on lines +291 to +300
Contributor (high):
Creating zero-tensors for weights as a fallback is problematic. This can lead to silent correctness issues where the model compiles but produces incorrect (zero) outputs. It's better to raise an error if weights are not provided, as they are essential for a functional LSTM layer.

Suggested change
-        else:
-            # Fallback: create zero weights
-            weight_ih = self.block_builder.emit(
-                relax.op.zeros(relax.ShapeExpr((4 * hidden_size, input_size)), dtype)
-            )
-            weight_hh = self.block_builder.emit(
-                relax.op.zeros(relax.ShapeExpr((4 * hidden_size, hidden_size)), dtype)
-            )
-            bias_ih = None
-            bias_hh = None
+        else:
+            raise ValueError("LSTM params (weights) are required.")

        # Initialize hidden and cell states
        if hx is not None and len(hx) >= 2:
            h_0 = hx[0]  # (num_layers, batch_size, hidden_size)
            c_0 = hx[1]  # (num_layers, batch_size, hidden_size)
            # Extract the first layer's hidden state
            h_prev = self.block_builder.emit(
                relax.op.take(h_0, relax.const(0, "int64"), axis=0, mode="clip")
            )
            c_prev = self.block_builder.emit(
                relax.op.take(c_0, relax.const(0, "int64"), axis=0, mode="clip")
            )
        else:
            h_prev = self.block_builder.emit(
                relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
            )
            c_prev = self.block_builder.emit(
                relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
            )
        # Reshape input for processing
        if batch_first:
            # Input: (batch, seq_len, input_size) -> (seq_len, batch, input_size)
            input_reshaped = self.block_builder.emit(
                relax.op.permute_dims(input_tensor, axes=[1, 0, 2])
            )
        else:
            input_reshaped = input_tensor
        weight_ih_t = self.block_builder.emit(relax.op.permute_dims(weight_ih, axes=[1, 0]))
        weight_hh_t = self.block_builder.emit(relax.op.permute_dims(weight_hh, axes=[1, 0]))
        outputs = []
        for t in range(seq_len):
            # Get input at time t: (batch_size, input_size)
            x_t = self.block_builder.emit(
                relax.op.take(input_reshaped, relax.const(t, "int64"), axis=0, mode="clip")
            )
            # Compute gates: W_ih * x_t + W_hh * h_{t-1} + bias
            # Input-to-hidden: (batch_size, input_size) @ (4*hidden_size, input_size).T
            ih_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(x_t, weight_ih_t))

            # Hidden-to-hidden: (batch_size, hidden_size) @ (4*hidden_size, hidden_size).T
            hh_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(h_prev, weight_hh_t))
            # Add biases if present
            if bias_ih is not None and bias_hh is not None:
                gates = self.block_builder.emit(
                    relax.op.add(relax.op.add(relax.op.add(ih_gates, bias_ih), hh_gates), bias_hh)
                )
            elif bias_ih is not None:
                gates = self.block_builder.emit(
                    relax.op.add(relax.op.add(ih_gates, bias_ih), hh_gates)
                )
            elif bias_hh is not None:
                gates = self.block_builder.emit(
                    relax.op.add(relax.op.add(ih_gates, hh_gates), bias_hh)
                )
            else:
                gates = self.block_builder.emit(relax.op.add(ih_gates, hh_gates))
Comment on lines +342 to +355
Contributor (medium):
The logic for adding biases is quite verbose with multiple if/elif/else branches. This can be simplified for better readability and maintainability. You can calculate the total gates first, and then conditionally add the biases.

Suggested change
-            if bias_ih is not None and bias_hh is not None:
-                gates = self.block_builder.emit(
-                    relax.op.add(relax.op.add(relax.op.add(ih_gates, bias_ih), hh_gates), bias_hh)
-                )
-            elif bias_ih is not None:
-                gates = self.block_builder.emit(
-                    relax.op.add(relax.op.add(ih_gates, bias_ih), hh_gates)
-                )
-            elif bias_hh is not None:
-                gates = self.block_builder.emit(
-                    relax.op.add(relax.op.add(ih_gates, hh_gates), bias_hh)
-                )
-            else:
-                gates = self.block_builder.emit(relax.op.add(ih_gates, hh_gates))
+            gates = self.block_builder.emit(relax.op.add(ih_gates, hh_gates))
+            if bias_ih is not None:
+                gates = self.block_builder.emit(relax.op.add(gates, bias_ih))
+            if bias_hh is not None:
+                gates = self.block_builder.emit(relax.op.add(gates, bias_hh))

            # Split gates: (batch_size, 4 * hidden_size) -> 4 x (batch_size, hidden_size)
            gate_size = hidden_size
            i_gate = self.block_builder.emit(
                relax.op.strided_slice(gates, axes=[1], begin=[0], end=[gate_size])
            )
            f_gate = self.block_builder.emit(
                relax.op.strided_slice(gates, axes=[1], begin=[gate_size], end=[2 * gate_size])
            )
            g_gate = self.block_builder.emit(
                relax.op.strided_slice(gates, axes=[1], begin=[2 * gate_size], end=[3 * gate_size])
            )
            o_gate = self.block_builder.emit(
                relax.op.strided_slice(gates, axes=[1], begin=[3 * gate_size], end=[4 * gate_size])
            )
Comment on lines +358 to +369
Contributor (medium):
The four gates (input, forget, cell, output) are split from the concatenated gates tensor using four separate strided_slice operations. This can be done more efficiently and concisely using a single relax.op.split operation, which would also improve readability.

            gate_tuple = self.block_builder.emit(relax.op.split(gates, 4, axis=1))
            i_gate = self.block_builder.emit(relax.TupleGetItem(gate_tuple, 0))
            f_gate = self.block_builder.emit(relax.TupleGetItem(gate_tuple, 1))
            g_gate = self.block_builder.emit(relax.TupleGetItem(gate_tuple, 2))
            o_gate = self.block_builder.emit(relax.TupleGetItem(gate_tuple, 3))

            # Apply activations
            i_t = self.block_builder.emit(relax.op.sigmoid(i_gate))
            f_t = self.block_builder.emit(relax.op.sigmoid(f_gate))
            g_t = self.block_builder.emit(relax.op.tanh(g_gate))
            o_t = self.block_builder.emit(relax.op.sigmoid(o_gate))
            # Update cell state: c_t = f_t * c_{t-1} + i_t * g_t
            c_t = self.block_builder.emit(
                relax.op.add(relax.op.multiply(f_t, c_prev), relax.op.multiply(i_t, g_t))
            )
            # Update hidden state: h_t = o_t * tanh(c_t)
            h_t = self.block_builder.emit(relax.op.multiply(o_t, relax.op.tanh(c_t)))
            # Store output
            outputs.append(h_t)
            # Update for next iteration
            h_prev = h_t
            c_prev = c_t
        # Stack outputs: (seq_len, batch_size, hidden_size)
        output = self.block_builder.emit(relax.op.stack(outputs, axis=0))
        # Reshape back to batch_first if needed
        if batch_first:
            # (seq_len, batch_size, hidden_size) -> (batch_size, seq_len, hidden_size)
            output = self.block_builder.emit(relax.op.permute_dims(output, axes=[1, 0, 2]))
        return output
Contributor (high):
The _lstm implementation is incomplete. It only returns the output sequence but not the final hidden and cell states, which are part of the standard torch.nn.LSTM output (output, (h_n, c_n)). This will lead to incorrect behavior for models that use these states. The function should be updated to return a tuple containing the output sequence and the final hidden/cell states to fully match the PyTorch operator's behavior.
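
For illustration, a minimal sketch of such a tuple return (not part of the PR; it assumes the aten::lstm (output, h_n, c_n) result layout with h_n and c_n shaped (num_layers * num_directions, batch_size, hidden_size), and that relax.op.expand_dims and relax.Tuple are available in this frontend):

        # Hypothetical sketch: also return the final hidden and cell states,
        # adding back the leading num_layers * num_directions axis.
        h_n = self.block_builder.emit(relax.op.expand_dims(h_prev, axis=0))
        c_n = self.block_builder.emit(relax.op.expand_dims(c_prev, axis=0))
        return self.block_builder.emit(relax.Tuple([output, h_n, c_n]))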


    ########## Manipulation ##########

    def _narrow(self, node: fx.Node) -> relax.Var:
@@ -491,6 +651,7 @@ def create_convert_map(
            "instance_norm.default": self._instance_norm,
            "layer_norm.default": self._layer_norm,
            "linear.default": self._linear,
            "lstm.input": self._lstm,
            "max_pool1d.default": self._max_pool1d,
            "max_pool2d.default": self._max_pool2d,
            "max_pool3d.default": self._max_pool3d,
73 changes: 72 additions & 1 deletion tests/python/relax/test_frontend_from_exported_program.py
@@ -17,6 +17,7 @@
import operator
import pytest
import torch
import numpy as np
from torch import nn
from torch.nn import Module
from torch.export import export
@@ -5940,6 +5941,76 @@ def main(
    verify_model(MatrixMultiply(), example_args, {}, Expected)


def test_lstm():
    class BasicLSTM(nn.Module):
        def __init__(self):
            super().__init__()
            self.lstm = nn.LSTM(
                input_size=4,
                hidden_size=8,
                num_layers=1,
                batch_first=True,
                bidirectional=False,
            )

        def forward(self, x):
            y, _ = self.lstm(x)
            return y

    torch.manual_seed(42)
    x = torch.randn(2, 3, 4, dtype=torch.float32)
    model = BasicLSTM()
    with torch.no_grad():
        pytorch_output = model(x)
    exported_program = export(model, args=(x,))
    mod = from_exported_program(exported_program)
    target = tvm.target.Target("llvm")
    ex = relax.build(mod, target)
    vm = relax.VirtualMachine(ex, tvm.cpu())
    x_tvm = tvm.runtime.tensor(x.numpy())
    tvm_output = vm["main"](x_tvm)
    if hasattr(tvm_output, "numpy"):
        tvm_output_np = tvm_output.numpy()
    else:
        tvm_output_np = tvm_output[0].numpy()
    assert (
        pytorch_output.shape == tvm_output_np.shape
    ), f"Shape mismatch: PyTorch {pytorch_output.shape} vs TVM {tvm_output_np.shape}"
    np.testing.assert_allclose(pytorch_output.numpy(), tvm_output_np, rtol=1e-4, atol=1e-5)

    class SeqFirstLSTM(nn.Module):
        def __init__(self):
            super().__init__()
            self.lstm = nn.LSTM(
                input_size=3,
                hidden_size=6,
                num_layers=1,
                batch_first=False,
                bidirectional=False,
            )

        def forward(self, x):
            y, _ = self.lstm(x)
            return y

    torch.manual_seed(43)
    x2 = torch.randn(4, 2, 3, dtype=torch.float32)
    model2 = SeqFirstLSTM()
    with torch.no_grad():
        pytorch_output2 = model2(x2)
    exported_program2 = export(model2, args=(x2,))
    mod2 = from_exported_program(exported_program2)
    ex2 = relax.build(mod2, target)
    vm2 = relax.VirtualMachine(ex2, tvm.cpu())
    x2_tvm = tvm.runtime.tensor(x2.numpy())
    tvm_output2 = vm2["main"](x2_tvm)
    if hasattr(tvm_output2, "numpy"):
        tvm_output2_np = tvm_output2.numpy()
    else:
        tvm_output2_np = tvm_output2[0].numpy()
    assert pytorch_output2.shape == tvm_output2_np.shape
    np.testing.assert_allclose(pytorch_output2.numpy(), tvm_output2_np, rtol=1e-4, atol=1e-5)
Comment on lines +5944 to +6012
Contributor (medium):
The test_lstm function contains a significant amount of duplicated code for testing the batch_first=True and batch_first=False cases. This can be refactored into a helper function to improve readability and maintainability. The helper function could take the model and input tensor as arguments and perform the verification logic.
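
For illustration, one possible shape for such a helper (a sketch only; the helper name _run_and_check_lstm and its exact structure are not part of the PR, and it only reuses calls already present in the test above):

def _run_and_check_lstm(model, x, rtol=1e-4, atol=1e-5):
    # Hypothetical shared verification path for both LSTM test cases.
    with torch.no_grad():
        pytorch_output = model(x)
    exported_program = export(model, args=(x,))
    mod = from_exported_program(exported_program)
    ex = relax.build(mod, tvm.target.Target("llvm"))
    vm = relax.VirtualMachine(ex, tvm.cpu())
    tvm_output = vm["main"](tvm.runtime.tensor(x.numpy()))
    tvm_output_np = tvm_output.numpy() if hasattr(tvm_output, "numpy") else tvm_output[0].numpy()
    assert pytorch_output.shape == tvm_output_np.shape
    np.testing.assert_allclose(pytorch_output.numpy(), tvm_output_np, rtol=rtol, atol=atol)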



if __name__ == "__main__":
    tvm.testing.main()