[Relax][PyTorch] Support lstm op for ExportedProgram importer #18346
Conversation
Summary of Changes: Hello @tlopex, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the PyTorch ExportedProgram importer by adding robust support for the lstm operator.
Code Review
This pull request adds support for the lstm.input operator from PyTorch's ExportedProgram format. The implementation includes a new _lstm method in the ExportedProgramImporter and corresponding tests. While the core LSTM logic is present, there are several areas for improvement regarding correctness, robustness, and code quality.
My review highlights a critical bug in the _getitem implementation, several high-severity issues in the _lstm implementation related to incomplete functionality and lack of robustness, and some medium-severity suggestions for code refactoring to improve maintainability and efficiency. Addressing these points will make the LSTM support more robust and reliable.
```python
if not isinstance(node.args[1], (list, tuple)):
    indices = [node.args[1]]
else:
    indices = node.args[1]
```
The new logic for handling integer indexing on tensors is incorrect. It returns the tensor itself, which breaks the semantics of tensor indexing. For a tensor x, x[0] should return the first slice along axis 0, which has a reduced rank. The current implementation returns x unmodified. Since _getitem is a general-purpose function, this change can cause incorrect behavior for other operators that rely on it.
This seems to be a workaround for an incomplete _lstm implementation. The correct fix should be in the _lstm operator implementation to return a proper tuple output, and this logic should be removed from _getitem.
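For illustration, a rank-reducing integer index could be lowered with a 0-d `take`, as in this minimal sketch (assuming a Relax tensor expression `x` and an integer in `node.args[1]`; names and placement are illustrative, not part of the suggested fix, which is to remove this logic from _getitem entirely):

```python
# numpy-style x[0]: take with a 0-d index removes the indexed axis,
# so the result has its rank reduced by one.
index = relax.const(node.args[1], "int64")
return self.block_builder.emit(relax.op.take(x, index, axis=0))
```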
```python
else:
    # Fallback to a default hidden size
    hidden_size = 16
```
The fallback logic for when LSTM parameters are not provided is problematic. It defaults to hidden_size = 16. This can lead to silent correctness issues and hard-to-debug errors. It would be better to raise a ValueError if the parameters are not available to determine hidden_size, as a valid LSTM layer must have weights.
Suggested change:
```diff
-else:
-    # Fallback to a default hidden size
-    hidden_size = 16
+else:
+    raise ValueError("Cannot determine hidden_size. LSTM params (weights) are required.")
```
```python
else:
    # Fallback: create zero weights
    weight_ih = self.block_builder.emit(
        relax.op.zeros(relax.ShapeExpr((4 * hidden_size, input_size)), dtype)
    )
    weight_hh = self.block_builder.emit(
        relax.op.zeros(relax.ShapeExpr((4 * hidden_size, hidden_size)), dtype)
    )
    bias_ih = None
    bias_hh = None
```
Creating zero-tensors for weights as a fallback is problematic. This can lead to silent correctness issues where the model compiles but produces incorrect (zero) outputs. It's better to raise an error if weights are not provided, as they are essential for a functional LSTM layer.
Suggested change:
```diff
-else:
-    # Fallback: create zero weights
-    weight_ih = self.block_builder.emit(
-        relax.op.zeros(relax.ShapeExpr((4 * hidden_size, input_size)), dtype)
-    )
-    weight_hh = self.block_builder.emit(
-        relax.op.zeros(relax.ShapeExpr((4 * hidden_size, hidden_size)), dtype)
-    )
-    bias_ih = None
-    bias_hh = None
+else:
+    raise ValueError("LSTM params (weights) are required.")
```
```python
if batch_first:
    # (seq_len, batch_size, hidden_size) -> (batch_size, seq_len, hidden_size)
    output = self.block_builder.emit(relax.op.permute_dims(output, axes=[1, 0, 2]))
return output
```
The _lstm implementation is incomplete. It only returns the output sequence but not the final hidden and cell states, which are part of the standard torch.nn.LSTM output (output, (h_n, c_n)). This will lead to incorrect behavior for models that use these states. The function should be updated to return a tuple containing the output sequence and the final hidden/cell states to fully match the PyTorch operator's behavior.
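For reference, a minimal sketch of a tuple-shaped return, assuming the step loop above leaves the final states in `h_t` and `c_t` (hypothetical names; shapes follow torch.nn.LSTM with num_layers=1, unidirectional):

```python
# Hedged sketch: aten lstm yields (output, h_n, c_n), where h_n/c_n carry a
# leading (num_layers * num_directions) axis — here 1, added via expand_dims.
if batch_first:
    output = self.block_builder.emit(relax.op.permute_dims(output, axes=[1, 0, 2]))
h_n = self.block_builder.emit(relax.op.expand_dims(h_t, axis=0))
c_n = self.block_builder.emit(relax.op.expand_dims(c_t, axis=0))
return self.block_builder.emit(relax.Tuple([output, h_n, c_n]))
```

With a proper tuple return in place, _getitem can resolve `y, _ = self.lstm(x)` through the usual tuple path instead of the tensor-identity workaround.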
```python
if bias_ih is not None and bias_hh is not None:
    gates = self.block_builder.emit(
        relax.op.add(relax.op.add(relax.op.add(ih_gates, bias_ih), hh_gates), bias_hh)
    )
elif bias_ih is not None:
    gates = self.block_builder.emit(
        relax.op.add(relax.op.add(ih_gates, bias_ih), hh_gates)
    )
elif bias_hh is not None:
    gates = self.block_builder.emit(
        relax.op.add(relax.op.add(ih_gates, hh_gates), bias_hh)
    )
else:
    gates = self.block_builder.emit(relax.op.add(ih_gates, hh_gates))
```
The logic for adding biases is quite verbose with multiple if/elif/else branches. This can be simplified for better readability and maintainability. You can calculate the total gates first, and then conditionally add the biases.
Suggested change:
```diff
-if bias_ih is not None and bias_hh is not None:
-    gates = self.block_builder.emit(
-        relax.op.add(relax.op.add(relax.op.add(ih_gates, bias_ih), hh_gates), bias_hh)
-    )
-elif bias_ih is not None:
-    gates = self.block_builder.emit(
-        relax.op.add(relax.op.add(ih_gates, bias_ih), hh_gates)
-    )
-elif bias_hh is not None:
-    gates = self.block_builder.emit(
-        relax.op.add(relax.op.add(ih_gates, hh_gates), bias_hh)
-    )
-else:
-    gates = self.block_builder.emit(relax.op.add(ih_gates, hh_gates))
+gates = self.block_builder.emit(relax.op.add(ih_gates, hh_gates))
+if bias_ih is not None:
+    gates = self.block_builder.emit(relax.op.add(gates, bias_ih))
+if bias_hh is not None:
+    gates = self.block_builder.emit(relax.op.add(gates, bias_hh))
```
```python
i_gate = self.block_builder.emit(
    relax.op.strided_slice(gates, axes=[1], begin=[0], end=[gate_size])
)
f_gate = self.block_builder.emit(
    relax.op.strided_slice(gates, axes=[1], begin=[gate_size], end=[2 * gate_size])
)
g_gate = self.block_builder.emit(
    relax.op.strided_slice(gates, axes=[1], begin=[2 * gate_size], end=[3 * gate_size])
)
o_gate = self.block_builder.emit(
    relax.op.strided_slice(gates, axes=[1], begin=[3 * gate_size], end=[4 * gate_size])
)
```
The four gates (input, forget, cell, output) are split from the concatenated gates tensor using four separate strided_slice operations. This can be done more efficiently and concisely using a single relax.op.split operation, which would also improve readability.
```python
gate_tuple = self.block_builder.emit(relax.op.split(gates, 4, axis=1))
i_gate = self.block_builder.emit(relax.TupleGetItem(gate_tuple, 0))
f_gate = self.block_builder.emit(relax.TupleGetItem(gate_tuple, 1))
g_gate = self.block_builder.emit(relax.TupleGetItem(gate_tuple, 2))
o_gate = self.block_builder.emit(relax.TupleGetItem(gate_tuple, 3))
```

```python
def test_lstm():
    class BasicLSTM(nn.Module):
        def __init__(self):
            super().__init__()
            self.lstm = nn.LSTM(
                input_size=4,
                hidden_size=8,
                num_layers=1,
                batch_first=True,
                bidirectional=False,
            )

        def forward(self, x):
            y, _ = self.lstm(x)
            return y

    torch.manual_seed(42)
    x = torch.randn(2, 3, 4, dtype=torch.float32)
    model = BasicLSTM()
    with torch.no_grad():
        pytorch_output = model(x)
    exported_program = export(model, args=(x,))
    mod = from_exported_program(exported_program)
    target = tvm.target.Target("llvm")
    ex = relax.build(mod, target)
    vm = relax.VirtualMachine(ex, tvm.cpu())
    x_tvm = tvm.runtime.tensor(x.numpy())
    tvm_output = vm["main"](x_tvm)
    if hasattr(tvm_output, "numpy"):
        tvm_output_np = tvm_output.numpy()
    else:
        tvm_output_np = tvm_output[0].numpy()
    assert (
        pytorch_output.shape == tvm_output_np.shape
    ), f"Shape mismatch: PyTorch {pytorch_output.shape} vs TVM {tvm_output_np.shape}"
    np.testing.assert_allclose(pytorch_output.numpy(), tvm_output_np, rtol=1e-4, atol=1e-5)

    class SeqFirstLSTM(nn.Module):
        def __init__(self):
            super().__init__()
            self.lstm = nn.LSTM(
                input_size=3,
                hidden_size=6,
                num_layers=1,
                batch_first=False,
                bidirectional=False,
            )

        def forward(self, x):
            y, _ = self.lstm(x)
            return y

    torch.manual_seed(43)
    x2 = torch.randn(4, 2, 3, dtype=torch.float32)
    model2 = SeqFirstLSTM()
    with torch.no_grad():
        pytorch_output2 = model2(x2)
    exported_program2 = export(model2, args=(x2,))
    mod2 = from_exported_program(exported_program2)
    ex2 = relax.build(mod2, target)
    vm2 = relax.VirtualMachine(ex2, tvm.cpu())
    x2_tvm = tvm.runtime.tensor(x2.numpy())
    tvm_output2 = vm2["main"](x2_tvm)
    if hasattr(tvm_output2, "numpy"):
        tvm_output2_np = tvm_output2.numpy()
    else:
        tvm_output2_np = tvm_output2[0].numpy()
    assert pytorch_output2.shape == tvm_output2_np.shape
    np.testing.assert_allclose(pytorch_output2.numpy(), tvm_output2_np, rtol=1e-4, atol=1e-5)
```
The test_lstm function contains a significant amount of duplicated code for testing the batch_first=True and batch_first=False cases. This can be refactored into a helper function to improve readability and maintainability. The helper function could take the model and input tensor as arguments and perform the verification logic.
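One possible shape for that helper, sketched directly from the test code above (the name `_verify_lstm` is illustrative, and the test module's existing imports of torch, tvm, relax, np, export, and from_exported_program are assumed):

```python
def _verify_lstm(model, x, rtol=1e-4, atol=1e-5):
    # Run the PyTorch reference model.
    with torch.no_grad():
        pytorch_output = model(x)
    # Export, import into Relax, build, and run on the Relax VM.
    exported_program = export(model, args=(x,))
    mod = from_exported_program(exported_program)
    ex = relax.build(mod, tvm.target.Target("llvm"))
    vm = relax.VirtualMachine(ex, tvm.cpu())
    tvm_output = vm["main"](tvm.runtime.tensor(x.numpy()))
    # Unwrap a tuple result if the importer returns one.
    tvm_output_np = tvm_output.numpy() if hasattr(tvm_output, "numpy") else tvm_output[0].numpy()
    assert pytorch_output.shape == tvm_output_np.shape
    np.testing.assert_allclose(pytorch_output.numpy(), tvm_output_np, rtol=rtol, atol=atol)

# Usage, replacing the duplicated blocks:
#   torch.manual_seed(42); _verify_lstm(BasicLSTM(), torch.randn(2, 3, 4))
#   torch.manual_seed(43); _verify_lstm(SeqFirstLSTM(), torch.randn(4, 2, 3))
```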
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This PR supports lstm.input for the ExportedProgram importer. This links to issue #18340.