-
Notifications
You must be signed in to change notification settings - Fork 3.7k
[Relax][PyTorch] Support lstm op for ExportedProgram importer #18346
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
46affe7
06f60d1
57d2062
cec9cbb
bdff42c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -231,6 +231,166 @@ def _upsample_bicubic2d(self, node: fx.node) -> relax.Var: | |||||||||||||||||||||||||||||||||||||||
| align_corners=align_corners, | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def _lstm(self, node: fx.Node) -> relax.Var: | ||||||||||||||||||||||||||||||||||||||||
| args = self.retrieve_args(node) | ||||||||||||||||||||||||||||||||||||||||
| input_tensor = args[0] | ||||||||||||||||||||||||||||||||||||||||
| hx = args[1] if len(args) > 1 else None | ||||||||||||||||||||||||||||||||||||||||
| params = args[2] if len(args) > 2 else None | ||||||||||||||||||||||||||||||||||||||||
| has_biases = args[3] if len(args) > 3 else True | ||||||||||||||||||||||||||||||||||||||||
| num_layers = args[4] if len(args) > 4 else 1 | ||||||||||||||||||||||||||||||||||||||||
| _dropout = args[5] if len(args) > 5 else 0.0 # Not used in inference | ||||||||||||||||||||||||||||||||||||||||
| _train = args[6] if len(args) > 6 else False # Not used in inference | ||||||||||||||||||||||||||||||||||||||||
| bidirectional = args[7] if len(args) > 7 else False | ||||||||||||||||||||||||||||||||||||||||
| batch_first = args[8] if len(args) > 8 else False | ||||||||||||||||||||||||||||||||||||||||
| if bidirectional: | ||||||||||||||||||||||||||||||||||||||||
| raise NotImplementedError("Bidirectional LSTM is not yet supported") | ||||||||||||||||||||||||||||||||||||||||
| if num_layers > 1: | ||||||||||||||||||||||||||||||||||||||||
| raise NotImplementedError("Multi-layer LSTM is not yet supported") | ||||||||||||||||||||||||||||||||||||||||
| input_shape = self.shape_of(input_tensor) | ||||||||||||||||||||||||||||||||||||||||
| if batch_first: | ||||||||||||||||||||||||||||||||||||||||
| # Input shape: (batch, seq_len, input_size) | ||||||||||||||||||||||||||||||||||||||||
| batch_size, seq_len, input_size = input_shape | ||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||
| # Input shape: (seq_len, batch, input_size) | ||||||||||||||||||||||||||||||||||||||||
| seq_len, batch_size, input_size = input_shape | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| if isinstance(seq_len, tvm.tir.IntImm): | ||||||||||||||||||||||||||||||||||||||||
| seq_len = seq_len.value | ||||||||||||||||||||||||||||||||||||||||
| if isinstance(batch_size, tvm.tir.IntImm): | ||||||||||||||||||||||||||||||||||||||||
| batch_size = batch_size.value | ||||||||||||||||||||||||||||||||||||||||
| if isinstance(input_size, tvm.tir.IntImm): | ||||||||||||||||||||||||||||||||||||||||
| input_size = input_size.value | ||||||||||||||||||||||||||||||||||||||||
| # Extract hidden size from the LSTM parameters | ||||||||||||||||||||||||||||||||||||||||
| # The parameters are: [weight_ih, weight_hh, bias_ih, bias_hh] | ||||||||||||||||||||||||||||||||||||||||
| # weight_ih shape: (4 * hidden_size, input_size) | ||||||||||||||||||||||||||||||||||||||||
| # weight_hh shape: (4 * hidden_size, hidden_size) | ||||||||||||||||||||||||||||||||||||||||
| if params and len(params) >= 2: | ||||||||||||||||||||||||||||||||||||||||
| weight_ih = params[0] | ||||||||||||||||||||||||||||||||||||||||
| weight_hh = params[1] | ||||||||||||||||||||||||||||||||||||||||
| # Extract hidden size from weight dimensions | ||||||||||||||||||||||||||||||||||||||||
| # weight_ih has shape (4 * hidden_size, input_size) | ||||||||||||||||||||||||||||||||||||||||
| weight_ih_shape = self.shape_of(weight_ih) | ||||||||||||||||||||||||||||||||||||||||
| hidden_size = weight_ih_shape[0] // 4 # 4 gates: input, forget, cell, output | ||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||
| # Fallback to a default hidden size | ||||||||||||||||||||||||||||||||||||||||
| hidden_size = 16 | ||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+274
to
+276
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The fallback logic for when LSTM parameters are not provided is problematic. It defaults to
Suggested change
|
||||||||||||||||||||||||||||||||||||||||
| # Implement actual LSTM computation using Relax operations | ||||||||||||||||||||||||||||||||||||||||
| # LSTM equations: | ||||||||||||||||||||||||||||||||||||||||
| # i_t = sigmoid(W_ii * x_t + b_ii + W_hi * h_{t-1} + b_hi) | ||||||||||||||||||||||||||||||||||||||||
| # f_t = sigmoid(W_if * x_t + b_if + W_hf * h_{t-1} + b_hf) | ||||||||||||||||||||||||||||||||||||||||
| # g_t = tanh(W_ig * x_t + b_ig + W_hg * h_{t-1} + b_hg) | ||||||||||||||||||||||||||||||||||||||||
| # o_t = sigmoid(W_io * x_t + b_io + W_ho * h_{t-1} + b_ho) | ||||||||||||||||||||||||||||||||||||||||
| # c_t = f_t * c_{t-1} + i_t * g_t | ||||||||||||||||||||||||||||||||||||||||
| # h_t = o_t * tanh(c_t) | ||||||||||||||||||||||||||||||||||||||||
| dtype = input_tensor.struct_info.dtype | ||||||||||||||||||||||||||||||||||||||||
| if params and len(params) >= 4: | ||||||||||||||||||||||||||||||||||||||||
| weight_ih = params[0] # (4 * hidden_size, input_size) | ||||||||||||||||||||||||||||||||||||||||
| weight_hh = params[1] # (4 * hidden_size, hidden_size) | ||||||||||||||||||||||||||||||||||||||||
| bias_ih = params[2] if has_biases else None # (4 * hidden_size,) | ||||||||||||||||||||||||||||||||||||||||
| bias_hh = params[3] if has_biases else None # (4 * hidden_size,) | ||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||
| # Fallback: create zero weights | ||||||||||||||||||||||||||||||||||||||||
| weight_ih = self.block_builder.emit( | ||||||||||||||||||||||||||||||||||||||||
| relax.op.zeros(relax.ShapeExpr((4 * hidden_size, input_size)), dtype) | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| weight_hh = self.block_builder.emit( | ||||||||||||||||||||||||||||||||||||||||
| relax.op.zeros(relax.ShapeExpr((4 * hidden_size, hidden_size)), dtype) | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| bias_ih = None | ||||||||||||||||||||||||||||||||||||||||
| bias_hh = None | ||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+291
to
+300
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Creating zero-tensors for weights as a fallback is problematic. This can lead to silent correctness issues where the model compiles but produces incorrect (zero) outputs. It's better to raise an error if weights are not provided, as they are essential for a functional LSTM layer.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||
| # Initialize hidden and cell states | ||||||||||||||||||||||||||||||||||||||||
| if hx is not None and len(hx) >= 2: | ||||||||||||||||||||||||||||||||||||||||
| h_0 = hx[0] # (num_layers, batch_size, hidden_size) | ||||||||||||||||||||||||||||||||||||||||
| c_0 = hx[1] # (num_layers, batch_size, hidden_size) | ||||||||||||||||||||||||||||||||||||||||
| # Extract the first layer's hidden state | ||||||||||||||||||||||||||||||||||||||||
| h_prev = self.block_builder.emit( | ||||||||||||||||||||||||||||||||||||||||
| relax.op.take(h_0, relax.const(0, "int64"), axis=0, mode="clip") | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| c_prev = self.block_builder.emit( | ||||||||||||||||||||||||||||||||||||||||
| relax.op.take(c_0, relax.const(0, "int64"), axis=0, mode="clip") | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||
| h_prev = self.block_builder.emit( | ||||||||||||||||||||||||||||||||||||||||
| relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype) | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| c_prev = self.block_builder.emit( | ||||||||||||||||||||||||||||||||||||||||
| relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype) | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| # Reshape input for processing | ||||||||||||||||||||||||||||||||||||||||
| if batch_first: | ||||||||||||||||||||||||||||||||||||||||
| # Input: (batch, seq_len, input_size) -> (seq_len, batch, input_size) | ||||||||||||||||||||||||||||||||||||||||
| input_reshaped = self.block_builder.emit( | ||||||||||||||||||||||||||||||||||||||||
| relax.op.permute_dims(input_tensor, axes=[1, 0, 2]) | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||
| input_reshaped = input_tensor | ||||||||||||||||||||||||||||||||||||||||
| weight_ih_t = self.block_builder.emit(relax.op.permute_dims(weight_ih, axes=[1, 0])) | ||||||||||||||||||||||||||||||||||||||||
| weight_hh_t = self.block_builder.emit(relax.op.permute_dims(weight_hh, axes=[1, 0])) | ||||||||||||||||||||||||||||||||||||||||
| outputs = [] | ||||||||||||||||||||||||||||||||||||||||
| for t in range(seq_len): | ||||||||||||||||||||||||||||||||||||||||
| # Get input at time t: (batch_size, input_size) | ||||||||||||||||||||||||||||||||||||||||
| x_t = self.block_builder.emit( | ||||||||||||||||||||||||||||||||||||||||
| relax.op.take(input_reshaped, relax.const(t, "int64"), axis=0, mode="clip") | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| # Compute gates: W_ih * x_t + W_hh * h_{t-1} + bias | ||||||||||||||||||||||||||||||||||||||||
| # Input-to-hidden: (batch_size, input_size) @ (4*hidden_size, input_size).T | ||||||||||||||||||||||||||||||||||||||||
| ih_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(x_t, weight_ih_t)) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # Hidden-to-hidden: (batch_size, hidden_size) @ (4*hidden_size, hidden_size).T | ||||||||||||||||||||||||||||||||||||||||
| hh_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(h_prev, weight_hh_t)) | ||||||||||||||||||||||||||||||||||||||||
| # Add biases if present | ||||||||||||||||||||||||||||||||||||||||
| if bias_ih is not None and bias_hh is not None: | ||||||||||||||||||||||||||||||||||||||||
| gates = self.block_builder.emit( | ||||||||||||||||||||||||||||||||||||||||
| relax.op.add(relax.op.add(relax.op.add(ih_gates, bias_ih), hh_gates), bias_hh) | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| elif bias_ih is not None: | ||||||||||||||||||||||||||||||||||||||||
| gates = self.block_builder.emit( | ||||||||||||||||||||||||||||||||||||||||
| relax.op.add(relax.op.add(ih_gates, bias_ih), hh_gates) | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| elif bias_hh is not None: | ||||||||||||||||||||||||||||||||||||||||
| gates = self.block_builder.emit( | ||||||||||||||||||||||||||||||||||||||||
| relax.op.add(relax.op.add(ih_gates, hh_gates), bias_hh) | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||
| gates = self.block_builder.emit(relax.op.add(ih_gates, hh_gates)) | ||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+342
to
+355
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The logic for adding biases is quite verbose with multiple
Suggested change
|
||||||||||||||||||||||||||||||||||||||||
| # Split gates: (batch_size, 4 * hidden_size) -> 4 x (batch_size, hidden_size) | ||||||||||||||||||||||||||||||||||||||||
| gate_size = hidden_size | ||||||||||||||||||||||||||||||||||||||||
| i_gate = self.block_builder.emit( | ||||||||||||||||||||||||||||||||||||||||
| relax.op.strided_slice(gates, axes=[1], begin=[0], end=[gate_size]) | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| f_gate = self.block_builder.emit( | ||||||||||||||||||||||||||||||||||||||||
| relax.op.strided_slice(gates, axes=[1], begin=[gate_size], end=[2 * gate_size]) | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| g_gate = self.block_builder.emit( | ||||||||||||||||||||||||||||||||||||||||
| relax.op.strided_slice(gates, axes=[1], begin=[2 * gate_size], end=[3 * gate_size]) | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| o_gate = self.block_builder.emit( | ||||||||||||||||||||||||||||||||||||||||
| relax.op.strided_slice(gates, axes=[1], begin=[3 * gate_size], end=[4 * gate_size]) | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+358
to
+369
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The four gates (input, forget, cell, output) are split from the concatenated gates tensor using four separate gate_tuple = self.block_builder.emit(relax.op.split(gates, 4, axis=1))
i_gate = self.block_builder.emit(relax.TupleGetItem(gate_tuple, 0))
f_gate = self.block_builder.emit(relax.TupleGetItem(gate_tuple, 1))
g_gate = self.block_builder.emit(relax.TupleGetItem(gate_tuple, 2))
o_gate = self.block_builder.emit(relax.TupleGetItem(gate_tuple, 3)) |
||||||||||||||||||||||||||||||||||||||||
| # Apply activations | ||||||||||||||||||||||||||||||||||||||||
| i_t = self.block_builder.emit(relax.op.sigmoid(i_gate)) | ||||||||||||||||||||||||||||||||||||||||
| f_t = self.block_builder.emit(relax.op.sigmoid(f_gate)) | ||||||||||||||||||||||||||||||||||||||||
| g_t = self.block_builder.emit(relax.op.tanh(g_gate)) | ||||||||||||||||||||||||||||||||||||||||
| o_t = self.block_builder.emit(relax.op.sigmoid(o_gate)) | ||||||||||||||||||||||||||||||||||||||||
| # Update cell state: c_t = f_t * c_{t-1} + i_t * g_t | ||||||||||||||||||||||||||||||||||||||||
| c_t = self.block_builder.emit( | ||||||||||||||||||||||||||||||||||||||||
| relax.op.add(relax.op.multiply(f_t, c_prev), relax.op.multiply(i_t, g_t)) | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| # Update hidden state: h_t = o_t * tanh(c_t) | ||||||||||||||||||||||||||||||||||||||||
| h_t = self.block_builder.emit(relax.op.multiply(o_t, relax.op.tanh(c_t))) | ||||||||||||||||||||||||||||||||||||||||
| # Store output | ||||||||||||||||||||||||||||||||||||||||
| outputs.append(h_t) | ||||||||||||||||||||||||||||||||||||||||
| # Update for next iteration | ||||||||||||||||||||||||||||||||||||||||
| h_prev = h_t | ||||||||||||||||||||||||||||||||||||||||
| c_prev = c_t | ||||||||||||||||||||||||||||||||||||||||
| # Stack outputs: (seq_len, batch_size, hidden_size) | ||||||||||||||||||||||||||||||||||||||||
| output = self.block_builder.emit(relax.op.stack(outputs, axis=0)) | ||||||||||||||||||||||||||||||||||||||||
| # Reshape back to batch_first if needed | ||||||||||||||||||||||||||||||||||||||||
| if batch_first: | ||||||||||||||||||||||||||||||||||||||||
| # (seq_len, batch_size, hidden_size) -> (batch_size, seq_len, hidden_size) | ||||||||||||||||||||||||||||||||||||||||
| output = self.block_builder.emit(relax.op.permute_dims(output, axes=[1, 0, 2])) | ||||||||||||||||||||||||||||||||||||||||
| return output | ||||||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| ########## Manipulation ########## | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def _narrow(self, node: fx.Node) -> relax.Var: | ||||||||||||||||||||||||||||||||||||||||
|
|
@@ -491,6 +651,7 @@ def create_convert_map( | |||||||||||||||||||||||||||||||||||||||
| "instance_norm.default": self._instance_norm, | ||||||||||||||||||||||||||||||||||||||||
| "layer_norm.default": self._layer_norm, | ||||||||||||||||||||||||||||||||||||||||
| "linear.default": self._linear, | ||||||||||||||||||||||||||||||||||||||||
| "lstm.input": self._lstm, | ||||||||||||||||||||||||||||||||||||||||
| "max_pool1d.default": self._max_pool1d, | ||||||||||||||||||||||||||||||||||||||||
| "max_pool2d.default": self._max_pool2d, | ||||||||||||||||||||||||||||||||||||||||
| "max_pool3d.default": self._max_pool3d, | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,7 @@ | |
| import operator | ||
| import pytest | ||
| import torch | ||
| import numpy as np | ||
| from torch import nn | ||
| from torch.nn import Module | ||
| from torch.export import export | ||
|
|
@@ -5940,6 +5941,76 @@ def main( | |
| verify_model(MatrixMultiply(), example_args, {}, Expected) | ||
|
|
||
|
|
||
| def test_lstm(): | ||
| class BasicLSTM(nn.Module): | ||
| def __init__(self): | ||
| super().__init__() | ||
| self.lstm = nn.LSTM( | ||
| input_size=4, | ||
| hidden_size=8, | ||
| num_layers=1, | ||
| batch_first=True, | ||
| bidirectional=False, | ||
| ) | ||
|
|
||
| def forward(self, x): | ||
| y, _ = self.lstm(x) | ||
| return y | ||
|
|
||
| torch.manual_seed(42) | ||
| x = torch.randn(2, 3, 4, dtype=torch.float32) | ||
| model = BasicLSTM() | ||
| with torch.no_grad(): | ||
| pytorch_output = model(x) | ||
| exported_program = export(model, args=(x,)) | ||
| mod = from_exported_program(exported_program) | ||
| target = tvm.target.Target("llvm") | ||
| ex = relax.build(mod, target) | ||
| vm = relax.VirtualMachine(ex, tvm.cpu()) | ||
| x_tvm = tvm.runtime.tensor(x.numpy()) | ||
| tvm_output = vm["main"](x_tvm) | ||
| if hasattr(tvm_output, "numpy"): | ||
| tvm_output_np = tvm_output.numpy() | ||
| else: | ||
| tvm_output_np = tvm_output[0].numpy() | ||
| assert ( | ||
| pytorch_output.shape == tvm_output_np.shape | ||
| ), f"Shape mismatch: PyTorch {pytorch_output.shape} vs TVM {tvm_output_np.shape}" | ||
| np.testing.assert_allclose(pytorch_output.numpy(), tvm_output_np, rtol=1e-4, atol=1e-5) | ||
|
|
||
| class SeqFirstLSTM(nn.Module): | ||
| def __init__(self): | ||
| super().__init__() | ||
| self.lstm = nn.LSTM( | ||
| input_size=3, | ||
| hidden_size=6, | ||
| num_layers=1, | ||
| batch_first=False, | ||
| bidirectional=False, | ||
| ) | ||
|
|
||
| def forward(self, x): | ||
| y, _ = self.lstm(x) | ||
| return y | ||
|
|
||
| torch.manual_seed(43) | ||
| x2 = torch.randn(4, 2, 3, dtype=torch.float32) | ||
| model2 = SeqFirstLSTM() | ||
| with torch.no_grad(): | ||
| pytorch_output2 = model2(x2) | ||
| exported_program2 = export(model2, args=(x2,)) | ||
| mod2 = from_exported_program(exported_program2) | ||
| ex2 = relax.build(mod2, target) | ||
| vm2 = relax.VirtualMachine(ex2, tvm.cpu()) | ||
| x2_tvm = tvm.runtime.tensor(x2.numpy()) | ||
| tvm_output2 = vm2["main"](x2_tvm) | ||
| if hasattr(tvm_output2, "numpy"): | ||
| tvm_output2_np = tvm_output2.numpy() | ||
| else: | ||
| tvm_output2_np = tvm_output2[0].numpy() | ||
| assert pytorch_output2.shape == tvm_output2_np.shape | ||
| np.testing.assert_allclose(pytorch_output2.numpy(), tvm_output2_np, rtol=1e-4, atol=1e-5) | ||
|
Comment on lines
+5944
to
+6012
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| tvm.testing.main() | ||
| 1 | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The new logic for handling integer indexing on tensors is incorrect. It returns the tensor itself, which breaks the semantics of tensor indexing. For a tensor
x,x[0]should return the first slice along axis 0, which has a reduced rank. The current implementation returnsxunmodified. Since_getitemis a general-purpose function, this change can cause incorrect behavior for other operators that rely on it.This seems to be a workaround for an incomplete
_lstmimplementation. The correct fix should be in the_lstmoperator implementation to return a proper tuple output, and this logic should be removed from_getitem.