diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 1895119e79f4..d3707d20bf59 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1997,6 +1997,12 @@ def _getitem(self, node: fx.Node) -> relax.Var:
                 return self.block_builder.emit(relax.TupleGetItem(x, node.args[1]))
 
             assert isinstance(x.struct_info, relax.TensorStructInfo)
+            if isinstance(node.args[1], int):
+                return x
+            if not isinstance(node.args[1], (list, tuple)):
+                indices = [node.args[1]]
+            else:
+                indices = node.args[1]
             take_indices = []
             take_axes = []
             stride_begin = []
@@ -2007,10 +2013,10 @@ def _getitem(self, node: fx.Node) -> relax.Var:
             i = 0
             shape = self.shape_of(x)
             non_ellipsis_cnt = 0
-            for index in node.args[1]:
+            for index in indices:
                 if isinstance(index, (int, slice, torch.fx.Node)):
                     non_ellipsis_cnt += 1
-            for index in node.args[1]:
+            for index in indices:
                 if isinstance(index, int):
                     stride_begin.append(index)
                     stride_end.append(index + 1)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 3cf07effecaa..c9c55eb8d61a 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -231,6 +231,166 @@ def _upsample_bicubic2d(self, node: fx.node) -> relax.Var:
             align_corners=align_corners,
         )
 
+    def _lstm(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        input_tensor = args[0]
+        hx = args[1] if len(args) > 1 else None
+        params = args[2] if len(args) > 2 else None
+        has_biases = args[3] if len(args) > 3 else True
+        num_layers = args[4] if len(args) > 4 else 1
+        _dropout = args[5] if len(args) > 5 else 0.0  # Not used in inference
+        _train = args[6] if len(args) > 6 else False  # Not used in inference
+        bidirectional = args[7] if len(args) > 7 else False
+        batch_first = args[8] if len(args) > 8 else False
+        if bidirectional:
+            raise NotImplementedError("Bidirectional LSTM is not yet supported")
+        if num_layers > 1:
+            raise NotImplementedError("Multi-layer LSTM is not yet supported")
+        input_shape = self.shape_of(input_tensor)
+        if batch_first:
+            # Input shape: (batch, seq_len, input_size)
+            batch_size, seq_len, input_size = input_shape
+        else:
+            # Input shape: (seq_len, batch, input_size)
+            seq_len, batch_size, input_size = input_shape
+
+        if isinstance(seq_len, tvm.tir.IntImm):
+            seq_len = seq_len.value
+        if isinstance(batch_size, tvm.tir.IntImm):
+            batch_size = batch_size.value
+        if isinstance(input_size, tvm.tir.IntImm):
+            input_size = input_size.value
+        # Extract hidden size from the LSTM parameters
+        # The parameters are: [weight_ih, weight_hh, bias_ih, bias_hh]
+        # weight_ih shape: (4 * hidden_size, input_size)
+        # weight_hh shape: (4 * hidden_size, hidden_size)
+        if params and len(params) >= 2:
+            weight_ih = params[0]
+            weight_hh = params[1]
+            # Extract hidden size from weight dimensions
+            # weight_ih has shape (4 * hidden_size, input_size)
+            weight_ih_shape = self.shape_of(weight_ih)
+            hidden_size = weight_ih_shape[0] // 4  # 4 gates: input, forget, cell, output
+        else:
+            # Fallback to a default hidden size
+            hidden_size = 16
+        # Implement actual LSTM computation using Relax operations
+        # LSTM equations:
+        # i_t = sigmoid(W_ii * x_t + b_ii + W_hi * h_{t-1} + b_hi)
+        # f_t = sigmoid(W_if * x_t + b_if + W_hf * h_{t-1} + b_hf)
+        # g_t = tanh(W_ig * x_t + b_ig + W_hg * h_{t-1} + b_hg)
+        # o_t = sigmoid(W_io * x_t + b_io + W_ho * h_{t-1} + b_ho)
+        # c_t = f_t * c_{t-1} + i_t * g_t
+        # h_t = o_t * tanh(c_t)
+        dtype = input_tensor.struct_info.dtype
+        if params and len(params) >= 4:
+            weight_ih = params[0]  # (4 * hidden_size, input_size)
+            weight_hh = params[1]  # (4 * hidden_size, hidden_size)
+            bias_ih = params[2] if has_biases else None  # (4 * hidden_size,)
+            bias_hh = params[3] if has_biases else None  # (4 * hidden_size,)
+        else:
+            # Fallback: create zero weights
+            weight_ih = self.block_builder.emit(
+                relax.op.zeros(relax.ShapeExpr((4 * hidden_size, input_size)), dtype)
+            )
+            weight_hh = self.block_builder.emit(
+                relax.op.zeros(relax.ShapeExpr((4 * hidden_size, hidden_size)), dtype)
+            )
+            bias_ih = None
+            bias_hh = None
+        # Initialize hidden and cell states
+        if hx is not None and len(hx) >= 2:
+            h_0 = hx[0]  # (num_layers, batch_size, hidden_size)
+            c_0 = hx[1]  # (num_layers, batch_size, hidden_size)
+            # Extract the first layer's hidden state
+            h_prev = self.block_builder.emit(
+                relax.op.take(h_0, relax.const(0, "int64"), axis=0, mode="clip")
+            )
+            c_prev = self.block_builder.emit(
+                relax.op.take(c_0, relax.const(0, "int64"), axis=0, mode="clip")
+            )
+        else:
+            h_prev = self.block_builder.emit(
+                relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
+            )
+            c_prev = self.block_builder.emit(
+                relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
+            )
+        # Reshape input for processing
+        if batch_first:
+            # Input: (batch, seq_len, input_size) -> (seq_len, batch, input_size)
+            input_reshaped = self.block_builder.emit(
+                relax.op.permute_dims(input_tensor, axes=[1, 0, 2])
+            )
+        else:
+            input_reshaped = input_tensor
+        weight_ih_t = self.block_builder.emit(relax.op.permute_dims(weight_ih, axes=[1, 0]))
+        weight_hh_t = self.block_builder.emit(relax.op.permute_dims(weight_hh, axes=[1, 0]))
+        outputs = []
+        for t in range(seq_len):
+            # Get input at time t: (batch_size, input_size)
+            x_t = self.block_builder.emit(
+                relax.op.take(input_reshaped, relax.const(t, "int64"), axis=0, mode="clip")
+            )
+            # Compute gates: W_ih * x_t + W_hh * h_{t-1} + bias
+            # Input-to-hidden: (batch_size, input_size) @ (4*hidden_size, input_size).T
+            ih_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(x_t, weight_ih_t))
+
+            # Hidden-to-hidden: (batch_size, hidden_size) @ (4*hidden_size, hidden_size).T
+            hh_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(h_prev, weight_hh_t))
+            # Add biases if present
+            if bias_ih is not None and bias_hh is not None:
+                gates = self.block_builder.emit(
+                    relax.op.add(relax.op.add(relax.op.add(ih_gates, bias_ih), hh_gates), bias_hh)
+                )
+            elif bias_ih is not None:
+                gates = self.block_builder.emit(
+                    relax.op.add(relax.op.add(ih_gates, bias_ih), hh_gates)
+                )
+            elif bias_hh is not None:
+                gates = self.block_builder.emit(
+                    relax.op.add(relax.op.add(ih_gates, hh_gates), bias_hh)
+                )
+            else:
+                gates = self.block_builder.emit(relax.op.add(ih_gates, hh_gates))
+            # Split gates: (batch_size, 4 * hidden_size) -> 4 x (batch_size, hidden_size)
+            gate_size = hidden_size
+            i_gate = self.block_builder.emit(
+                relax.op.strided_slice(gates, axes=[1], begin=[0], end=[gate_size])
+            )
+            f_gate = self.block_builder.emit(
+                relax.op.strided_slice(gates, axes=[1], begin=[gate_size], end=[2 * gate_size])
+            )
+            g_gate = self.block_builder.emit(
+                relax.op.strided_slice(gates, axes=[1], begin=[2 * gate_size], end=[3 * gate_size])
+            )
+            o_gate = self.block_builder.emit(
+                relax.op.strided_slice(gates, axes=[1], begin=[3 * gate_size], end=[4 * gate_size])
+            )
+            # Apply activations
+            i_t = self.block_builder.emit(relax.op.sigmoid(i_gate))
+            f_t = self.block_builder.emit(relax.op.sigmoid(f_gate))
+            g_t = self.block_builder.emit(relax.op.tanh(g_gate))
+            o_t = self.block_builder.emit(relax.op.sigmoid(o_gate))
+            # Update cell state: c_t = f_t * c_{t-1} + i_t * g_t
+            c_t = self.block_builder.emit(
+                relax.op.add(relax.op.multiply(f_t, c_prev), relax.op.multiply(i_t, g_t))
+            )
+            # Update hidden state: h_t = o_t * tanh(c_t)
+            h_t = self.block_builder.emit(relax.op.multiply(o_t, relax.op.tanh(c_t)))
+            # Store output
+            outputs.append(h_t)
+            # Update for next iteration
+            h_prev = h_t
+            c_prev = c_t
+        # Stack outputs: (seq_len, batch_size, hidden_size)
+        output = self.block_builder.emit(relax.op.stack(outputs, axis=0))
+        # Reshape back to batch_first if needed
+        if batch_first:
+            # (seq_len, batch_size, hidden_size) -> (batch_size, seq_len, hidden_size)
+            output = self.block_builder.emit(relax.op.permute_dims(output, axes=[1, 0, 2]))
+        return output
+
     ########## Manipulation ##########
 
     def _narrow(self, node: fx.Node) -> relax.Var:
@@ -491,6 +651,7 @@ def create_convert_map(
             "instance_norm.default": self._instance_norm,
             "layer_norm.default": self._layer_norm,
             "linear.default": self._linear,
+            "lstm.input": self._lstm,
             "max_pool1d.default": self._max_pool1d,
             "max_pool2d.default": self._max_pool2d,
             "max_pool3d.default": self._max_pool3d,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index ead341de287a..61a04f322332 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -17,6 +17,7 @@
 import operator
 import pytest
 import torch
+import numpy as np
 from torch import nn
 from torch.nn import Module
 from torch.export import export
@@ -5940,6 +5941,76 @@ def main(
     verify_model(MatrixMultiply(), example_args, {}, Expected)
 
 
+def test_lstm():
+    class BasicLSTM(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.lstm = nn.LSTM(
+                input_size=4,
+                hidden_size=8,
+                num_layers=1,
+                batch_first=True,
+                bidirectional=False,
+            )
+
+        def forward(self, x):
+            y, _ = self.lstm(x)
+            return y
+
+    torch.manual_seed(42)
+    x = torch.randn(2, 3, 4, dtype=torch.float32)
+    model = BasicLSTM()
+    with torch.no_grad():
+        pytorch_output = model(x)
+    exported_program = export(model, args=(x,))
+    mod = from_exported_program(exported_program)
+    target = tvm.target.Target("llvm")
+    ex = relax.build(mod, target)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    x_tvm = tvm.runtime.tensor(x.numpy())
+    tvm_output = vm["main"](x_tvm)
+    if hasattr(tvm_output, "numpy"):
+        tvm_output_np = tvm_output.numpy()
+    else:
+        tvm_output_np = tvm_output[0].numpy()
+    assert (
+        pytorch_output.shape == tvm_output_np.shape
+    ), f"Shape mismatch: PyTorch {pytorch_output.shape} vs TVM {tvm_output_np.shape}"
+    np.testing.assert_allclose(pytorch_output.numpy(), tvm_output_np, rtol=1e-4, atol=1e-5)
+
+    class SeqFirstLSTM(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.lstm = nn.LSTM(
+                input_size=3,
+                hidden_size=6,
+                num_layers=1,
+                batch_first=False,
+                bidirectional=False,
+            )
+
+        def forward(self, x):
+            y, _ = self.lstm(x)
+            return y
+
+    torch.manual_seed(43)
+    x2 = torch.randn(4, 2, 3, dtype=torch.float32)
+    model2 = SeqFirstLSTM()
+    with torch.no_grad():
+        pytorch_output2 = model2(x2)
+    exported_program2 = export(model2, args=(x2,))
+    mod2 = from_exported_program(exported_program2)
+    ex2 = relax.build(mod2, target)
+    vm2 = relax.VirtualMachine(ex2, tvm.cpu())
+    x2_tvm = tvm.runtime.tensor(x2.numpy())
+    tvm_output2 = vm2["main"](x2_tvm)
+    if hasattr(tvm_output2, "numpy"):
+        tvm_output2_np = tvm_output2.numpy()
+    else:
+        tvm_output2_np = tvm_output2[0].numpy()
+    assert pytorch_output2.shape == tvm_output2_np.shape
+    np.testing.assert_allclose(pytorch_output2.numpy(), tvm_output2_np, rtol=1e-4, atol=1e-5)
+
+
 if __name__ == "__main__":
     tvm.testing.main()