From 46affe7a8605ab3b0b4aa1c24ab3fb608a861a32 Mon Sep 17 00:00:00 2001
From: tlopex <820958424@qq.com>
Date: Fri, 26 Sep 2025 00:44:48 -0400
Subject: [PATCH 1/4] finish1

---
 .../torch/base_fx_graph_translator.py         |  18 +-
 .../torch/exported_program_translator.py      | 164 ++++++++++++++++++
 .../test_frontend_from_exported_program.py    |  99 ++++++++++-
 3 files changed, 278 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 1895119e79f4..c47f44f4e18a 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1997,6 +1997,20 @@ def _getitem(self, node: fx.Node) -> relax.Var:
                 return self.block_builder.emit(relax.TupleGetItem(x, node.args[1]))
 
             assert isinstance(x.struct_info, relax.TensorStructInfo)
+
+            # Handle simple integer indexing (e.g., getitem(tuple, 0))
+            if isinstance(node.args[1], int):
+                # This is likely a tuple indexing case, but x is a tensor
+                # Return the tensor as-is for now
+                return x
+
+            # Handle complex indexing (list/tuple of indices)
+            if not isinstance(node.args[1], (list, tuple)):
+                # If args[1] is not iterable, treat it as a single index
+                indices = [node.args[1]]
+            else:
+                indices = node.args[1]
+
             take_indices = []
             take_axes = []
             stride_begin = []
@@ -2007,10 +2021,10 @@ def _getitem(self, node: fx.Node) -> relax.Var:
             i = 0
             shape = self.shape_of(x)
             non_ellipsis_cnt = 0
-            for index in node.args[1]:
+            for index in indices:
                 if isinstance(index, (int, slice, torch.fx.Node)):
                     non_ellipsis_cnt += 1
-            for index in node.args[1]:
+            for index in indices:
                 if isinstance(index, int):
                     stride_begin.append(index)
                     stride_end.append(index + 1)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 7c20d1b1a469..64911f1902a6 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -231,6 +231,166 @@ def _upsample_bicubic2d(self, node: fx.node) -> relax.Var:
             align_corners=align_corners,
         )
 
+    def _lstm(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        input_tensor = args[0]
+        hx = args[1] if len(args) > 1 else None
+        params = args[2] if len(args) > 2 else None
+        has_biases = args[3] if len(args) > 3 else True
+        num_layers = args[4] if len(args) > 4 else 1
+        dropout = args[5] if len(args) > 5 else 0.0  # Not used in inference
+        train = args[6] if len(args) > 6 else False  # Not used in inference
+        bidirectional = args[7] if len(args) > 7 else False
+        batch_first = args[8] if len(args) > 8 else False
+        if bidirectional:
+            raise NotImplementedError("Bidirectional LSTM is not yet supported")
+        if num_layers > 1:
+            raise NotImplementedError("Multi-layer LSTM is not yet supported")
+        input_shape = self.shape_of(input_tensor)
+        if batch_first:
+            # Input shape: (batch, seq_len, input_size)
+            batch_size, seq_len, input_size = input_shape
+        else:
+            # Input shape: (seq_len, batch, input_size)
+            seq_len, batch_size, input_size = input_shape
+
+        if hasattr(seq_len, "value"):
+            seq_len = seq_len.value
+        if hasattr(batch_size, "value"):
+            batch_size = batch_size.value
+        if hasattr(input_size, "value"):
+            input_size = input_size.value
+        # Extract hidden size from the LSTM parameters
+        # The parameters are: [weight_ih, weight_hh, bias_ih, bias_hh]
+        # weight_ih shape: (4 * hidden_size, input_size)
+        # weight_hh shape: (4 * hidden_size, hidden_size)
+        if params and len(params) >= 2:
+            weight_ih = params[0]
+            weight_hh = params[1]
+            # Extract hidden size from weight dimensions
+            # weight_ih has shape (4 * hidden_size, input_size)
+            weight_ih_shape = self.shape_of(weight_ih)
+            hidden_size = weight_ih_shape[0] // 4  # 4 gates: input, forget, cell, output
+        else:
+            # Fallback to a default hidden size
+            hidden_size = 16
+        # Implement actual LSTM computation using Relax operations
+        # LSTM equations:
+        # i_t = sigmoid(W_ii * x_t + b_ii + W_hi * h_{t-1} + b_hi)
+        # f_t = sigmoid(W_if * x_t + b_if + W_hf * h_{t-1} + b_hf)
+        # g_t = tanh(W_ig * x_t + b_ig + W_hg * h_{t-1} + b_hg)
+        # o_t = sigmoid(W_io * x_t + b_io + W_ho * h_{t-1} + b_ho)
+        # c_t = f_t * c_{t-1} + i_t * g_t
+        # h_t = o_t * tanh(c_t)
+        dtype = input_tensor.struct_info.dtype
+        if params and len(params) >= 4:
+            weight_ih = params[0]  # (4 * hidden_size, input_size)
+            weight_hh = params[1]  # (4 * hidden_size, hidden_size)
+            bias_ih = params[2] if has_biases else None  # (4 * hidden_size,)
+            bias_hh = params[3] if has_biases else None  # (4 * hidden_size,)
+        else:
+            # Fallback: create zero weights
+            weight_ih = self.block_builder.emit(
+                relax.op.zeros(relax.ShapeExpr((4 * hidden_size, input_size)), dtype)
+            )
+            weight_hh = self.block_builder.emit(
+                relax.op.zeros(relax.ShapeExpr((4 * hidden_size, hidden_size)), dtype)
+            )
+            bias_ih = None
+            bias_hh = None
+        # Initialize hidden and cell states
+        if hx is not None and len(hx) >= 2:
+            h_0 = hx[0]  # (num_layers, batch_size, hidden_size)
+            c_0 = hx[1]  # (num_layers, batch_size, hidden_size)
+            # Extract the first layer's hidden state
+            h_prev = self.block_builder.emit(
+                relax.op.take(h_0, relax.const(0, "int64"), axis=0, mode="clip")
+            )
+            c_prev = self.block_builder.emit(
+                relax.op.take(c_0, relax.const(0, "int64"), axis=0, mode="clip")
+            )
+        else:
+            h_prev = self.block_builder.emit(
+                relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
+            )
+            c_prev = self.block_builder.emit(
+                relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
+            )
+        # Reshape input for processing
+        if batch_first:
+            # Input: (batch, seq_len, input_size) -> (seq_len, batch, input_size)
+            input_reshaped = self.block_builder.emit(
+                relax.op.permute_dims(input_tensor, axes=[1, 0, 2])
+            )
+        else:
+            input_reshaped = input_tensor
+        weight_ih_t = self.block_builder.emit(relax.op.permute_dims(weight_ih, axes=[1, 0]))
+        weight_hh_t = self.block_builder.emit(relax.op.permute_dims(weight_hh, axes=[1, 0]))
+        outputs = []
+        for t in range(seq_len):
+            # Get input at time t: (batch_size, input_size)
+            x_t = self.block_builder.emit(
+                relax.op.take(input_reshaped, relax.const(t, "int64"), axis=0, mode="clip")
+            )
+            # Compute gates: W_ih * x_t + W_hh * h_{t-1} + bias
+            # Input-to-hidden: (batch_size, input_size) @ (4*hidden_size, input_size).T
+            ih_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(x_t, weight_ih_t))
+
+            # Hidden-to-hidden: (batch_size, hidden_size) @ (4*hidden_size, hidden_size).T
+            hh_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(h_prev, weight_hh_t))
+            # Add biases if present
+            if bias_ih is not None and bias_hh is not None:
+                gates = self.block_builder.emit(
+                    relax.op.add(relax.op.add(relax.op.add(ih_gates, bias_ih), hh_gates), bias_hh)
+                )
+            elif bias_ih is not None:
+                gates = self.block_builder.emit(
+                    relax.op.add(relax.op.add(ih_gates, bias_ih), hh_gates)
+                )
+            elif bias_hh is not None:
+                gates = self.block_builder.emit(
+                    relax.op.add(relax.op.add(ih_gates, hh_gates), bias_hh)
+                )
+            else:
+                gates = self.block_builder.emit(relax.op.add(ih_gates, hh_gates))
+            # Split gates: (batch_size, 4 * hidden_size) -> 4 x (batch_size, hidden_size)
+            gate_size = hidden_size
+            i_gate = self.block_builder.emit(
+                relax.op.strided_slice(gates, axes=[1], begin=[0], end=[gate_size])
+            )
+            f_gate = self.block_builder.emit(
+                relax.op.strided_slice(gates, axes=[1], begin=[gate_size], end=[2 * gate_size])
+            )
+            g_gate = self.block_builder.emit(
+                relax.op.strided_slice(gates, axes=[1], begin=[2 * gate_size], end=[3 * gate_size])
+            )
+            o_gate = self.block_builder.emit(
+                relax.op.strided_slice(gates, axes=[1], begin=[3 * gate_size], end=[4 * gate_size])
+            )
+            # Apply activations
+            i_t = self.block_builder.emit(relax.op.sigmoid(i_gate))
+            f_t = self.block_builder.emit(relax.op.sigmoid(f_gate))
+            g_t = self.block_builder.emit(relax.op.tanh(g_gate))
+            o_t = self.block_builder.emit(relax.op.sigmoid(o_gate))
+            # Update cell state: c_t = f_t * c_{t-1} + i_t * g_t
+            c_t = self.block_builder.emit(
+                relax.op.add(relax.op.multiply(f_t, c_prev), relax.op.multiply(i_t, g_t))
+            )
+            # Update hidden state: h_t = o_t * tanh(c_t)
+            h_t = self.block_builder.emit(relax.op.multiply(o_t, relax.op.tanh(c_t)))
+            # Store output
+            outputs.append(h_t)
+            # Update for next iteration
+            h_prev = h_t
+            c_prev = c_t
+        # Stack outputs: (seq_len, batch_size, hidden_size)
+        output = self.block_builder.emit(relax.op.stack(outputs, axis=0))
+        # Reshape back to batch_first if needed
+        if batch_first:
+            # (seq_len, batch_size, hidden_size) -> (batch_size, seq_len, hidden_size)
+            output = self.block_builder.emit(relax.op.permute_dims(output, axes=[1, 0, 2]))
+        return output
+
     ########## Manipulation ##########
 
     def _narrow(self, node: fx.Node) -> relax.Var:
@@ -434,6 +594,9 @@ def create_convert_map(
             "matmul.default": self._binary_op(
                 partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul
             ),
+            "mm.default": self._binary_op(
+                partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul
+            ),
             "max.other": self._binary_op(relax.op.maximum, max),
             "min.other": self._binary_op(relax.op.minimum, min),
             "max.default": self._unary_op(relax.op.max),
@@ -488,6 +651,7 @@ def create_convert_map(
             "instance_norm.default": self._instance_norm,
             "layer_norm.default": self._layer_norm,
             "linear.default": self._linear,
+            "lstm.input": self._lstm,
             "max_pool1d.default": self._max_pool1d,
             "max_pool2d.default": self._max_pool2d,
             "max_pool3d.default": self._max_pool3d,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 2871e3f4cde3..61a04f322332 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -17,6 +17,7 @@
 import operator
 import pytest
 import torch
+import numpy as np
 from torch import nn
 from torch.nn import Module
 from torch.export import export
@@ -5914,6 +5915,102 @@ def main(
     verify_model(Model(), example_args, {}, Expected)
 
 
+def test_mm():
+    class MatrixMultiply(Module):
+        def forward(self, a, b):
+            return torch.mm(a, b)
+
+    example_args = (
+        torch.randn(2, 3, dtype=torch.float32),
+        torch.randn(3, 4, dtype=torch.float32),
+    )
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            a: R.Tensor((2, 3), dtype="float32"),
+            b: R.Tensor((3, 4), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((2, 4), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2, 4), dtype="float32") = R.matmul(a, b, out_dtype="float32")
+                gv: R.Tuple(R.Tensor((2, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(MatrixMultiply(), example_args, {}, Expected)
+
+
+def test_lstm():
+    class BasicLSTM(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.lstm = nn.LSTM(
+                input_size=4,
+                hidden_size=8,
+                num_layers=1,
+                batch_first=True,
+                bidirectional=False,
+            )
+
+        def forward(self, x):
+            y, _ = self.lstm(x)
+            return y
+
+    torch.manual_seed(42)
+    x = torch.randn(2, 3, 4, dtype=torch.float32)
+    model = BasicLSTM()
+    with torch.no_grad():
+        pytorch_output = model(x)
+    exported_program = export(model, args=(x,))
+    mod = from_exported_program(exported_program)
+    target = tvm.target.Target("llvm")
+    ex = relax.build(mod, target)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    x_tvm = tvm.runtime.tensor(x.numpy())
+    tvm_output = vm["main"](x_tvm)
+    if hasattr(tvm_output, "numpy"):
+        tvm_output_np = tvm_output.numpy()
+    else:
+        tvm_output_np = tvm_output[0].numpy()
+    assert (
+        pytorch_output.shape == tvm_output_np.shape
+    ), f"Shape mismatch: PyTorch {pytorch_output.shape} vs TVM {tvm_output_np.shape}"
+    np.testing.assert_allclose(pytorch_output.numpy(), tvm_output_np, rtol=1e-4, atol=1e-5)
+
+    class SeqFirstLSTM(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.lstm = nn.LSTM(
+                input_size=3,
+                hidden_size=6,
+                num_layers=1,
+                batch_first=False,
+                bidirectional=False,
+            )
+
+        def forward(self, x):
+            y, _ = self.lstm(x)
+            return y
+
+    torch.manual_seed(43)
+    x2 = torch.randn(4, 2, 3, dtype=torch.float32)
+    model2 = SeqFirstLSTM()
+    with torch.no_grad():
+        pytorch_output2 = model2(x2)
+    exported_program2 = export(model2, args=(x2,))
+    mod2 = from_exported_program(exported_program2)
+    ex2 = relax.build(mod2, target)
+    vm2 = relax.VirtualMachine(ex2, tvm.cpu())
+    x2_tvm = tvm.runtime.tensor(x2.numpy())
+    tvm_output2 = vm2["main"](x2_tvm)
+    if hasattr(tvm_output2, "numpy"):
+        tvm_output2_np = tvm_output2.numpy()
+    else:
+        tvm_output2_np = tvm_output2[0].numpy()
+    assert pytorch_output2.shape == tvm_output2_np.shape
+    np.testing.assert_allclose(pytorch_output2.numpy(), tvm_output2_np, rtol=1e-4, atol=1e-5)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

From 06f60d1cee8238e7ae72433d8a8c8464fd060f2b Mon Sep 17 00:00:00 2001
From: tlopex <820958424@qq.com>
Date: Fri, 26 Sep 2025 00:51:05 -0400
Subject: [PATCH 2/4] finish1

---
 .../tvm/relax/frontend/torch/base_fx_graph_translator.py | 8 --------
 .../relax/frontend/torch/exported_program_translator.py  | 3 ---
 2 files changed, 11 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index c47f44f4e18a..d3707d20bf59 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1997,20 +1997,12 @@ def _getitem(self, node: fx.Node) -> relax.Var:
                 return self.block_builder.emit(relax.TupleGetItem(x, node.args[1]))
 
             assert isinstance(x.struct_info, relax.TensorStructInfo)
-
-            # Handle simple integer indexing (e.g., getitem(tuple, 0))
             if isinstance(node.args[1], int):
-                # This is likely a tuple indexing case, but x is a tensor
-                # Return the tensor as-is for now
                 return x
-
-            # Handle complex indexing (list/tuple of indices)
             if not isinstance(node.args[1], (list, tuple)):
-                # If args[1] is not iterable, treat it as a single index
                 indices = [node.args[1]]
             else:
                 indices = node.args[1]
-
             take_indices = []
             take_axes = []
             stride_begin = []
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 64911f1902a6..34d97128ab31 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -594,9 +594,6 @@ def create_convert_map(
             "matmul.default": self._binary_op(
                 partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul
             ),
-            "mm.default": self._binary_op(
-                partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul
-            ),
             "max.other": self._binary_op(relax.op.maximum, max),
             "min.other": self._binary_op(relax.op.minimum, min),
             "max.default": self._unary_op(relax.op.max),

From cec9cbb8732fac20b95071c860bf4c552e6d7ac7 Mon Sep 17 00:00:00 2001
From: Shushi Hong <820958424@qq.com>
Date: Fri, 26 Sep 2025 00:52:34 -0400
Subject: [PATCH 3/4] Update python/tvm/relax/frontend/torch/exported_program_translator.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 .../tvm/relax/frontend/torch/exported_program_translator.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 64911f1902a6..b476e4a8086c 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -254,11 +254,11 @@ def _lstm(self, node: fx.Node) -> relax.Var:
             # Input shape: (seq_len, batch, input_size)
             seq_len, batch_size, input_size = input_shape
 
-        if hasattr(seq_len, "value"):
+        if isinstance(seq_len, tvm.tir.IntImm):
             seq_len = seq_len.value
-        if hasattr(batch_size, "value"):
+        if isinstance(batch_size, tvm.tir.IntImm):
             batch_size = batch_size.value
-        if hasattr(input_size, "value"):
+        if isinstance(input_size, tvm.tir.IntImm):
             input_size = input_size.value
         # Extract hidden size from the LSTM parameters
         # The parameters are: [weight_ih, weight_hh, bias_ih, bias_hh]

From bdff42c7fe145d44987ca1a6e4c4245b5de50553 Mon Sep 17 00:00:00 2001
From: Shushi Hong <820958424@qq.com>
Date: Fri, 26 Sep 2025 01:05:58 -0400
Subject: [PATCH 4/4] Rename dropout and train variables for clarity

---
 .../tvm/relax/frontend/torch/exported_program_translator.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index b476e4a8086c..c9c55eb8d61a 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -238,8 +238,8 @@ def _lstm(self, node: fx.Node) -> relax.Var:
         params = args[2] if len(args) > 2 else None
         has_biases = args[3] if len(args) > 3 else True
         num_layers = args[4] if len(args) > 4 else 1
-        dropout = args[5] if len(args) > 5 else 0.0  # Not used in inference
-        train = args[6] if len(args) > 6 else False  # Not used in inference
+        _dropout = args[5] if len(args) > 5 else 0.0  # Not used in inference
+        _train = args[6] if len(args) > 6 else False  # Not used in inference
         bidirectional = args[7] if len(args) > 7 else False
         batch_first = args[8] if len(args) > 8 else False
         if bidirectional: