diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
index 44bc8b9445..ed0f1bb843 100644
--- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py
+++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -137,25 +137,23 @@ def broadcastable(
     "Check if two tensors are broadcastable according to torch rules"
     a_shape = tuple(a.shape)
     b_shape = tuple(b.shape)
+    # Check broadcastability from the trailing dimensions
    diff = len(a_shape) - len(b_shape)
-    if diff == 0:
+
+    # Tensors with the same rank and identical shapes are trivially broadcastable
+    if diff == 0 and all(a_shape[i] == b_shape[i] for i in range(len(a_shape))):
         return True
+
+    # Left-pad the shorter shape with ones so both shapes have equal rank
     if diff > 0:
-        max = len(a_shape)
-        min = len(b_shape)
-        greater_tensor = a_shape
-        lesser_tensor = b_shape
-    elif diff < 0:
-        max = len(b_shape)
-        min = len(a_shape)
-        greater_tensor = b_shape
-        lesser_tensor = a_shape
-    j = min - 1
-    for i in range(max - 1, diff - 1, -1):
-        if not (
-            greater_tensor[i] != lesser_tensor[j]
-            and (greater_tensor[i] == 1 or lesser_tensor[i] == 1)
-        ):
+        b_shape = (1,) * abs(diff) + b_shape
+    else:
+        a_shape = (1,) * abs(diff) + a_shape
+
+    # Each aligned dimension pair must satisfy one of the following:
+    # 1. the sizes are equal, or 2. one of the sizes is 1
+    for i in range(len(a_shape)):
+        if not (a_shape[i] == b_shape[i] or a_shape[i] == 1 or b_shape[i] == 1):
             return False
     return True
 
diff --git a/tests/py/dynamo/converters/test_where_aten.py b/tests/py/dynamo/converters/test_where_aten.py
index ddeb269ee9..0f6ae1818a 100644
--- a/tests/py/dynamo/converters/test_where_aten.py
+++ b/tests/py/dynamo/converters/test_where_aten.py
@@ -1,8 +1,8 @@
 import torch
 import torch.nn as nn
+from harness import DispatchTestCase
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from harness import DispatchTestCase
 
 
 class TestWhereConverter(DispatchTestCase):
@@ -28,6 +28,20 @@ def forward(self, condition, x, y):
             expected_ops={torch.ops.aten.where.self},
         )
 
+    def test_0D_input(self):
+        class Where(nn.Module):
+            def forward(self, condition, x, y):
+                return torch.where(condition, x, y)
+
+        inputX = torch.randn((5, 6, 7, 1, 3))
+        inputOther = torch.tensor(8.0, dtype=torch.float)
+        condition = inputX < 0
+        self.run_test(
+            Where(),
+            (condition, inputX, inputOther),
+            expected_ops={torch.ops.aten.where.self},
+        )
+
 
 if __name__ == "__main__":
     run_tests()
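
For reference, the rewritten `broadcastable` reduces to the standard torch
broadcasting rule: left-pad the shorter shape with ones, then require every
aligned dimension pair to either match or contain a 1. Below is a minimal
standalone sketch of that rule over plain shape tuples; the `check` helper is
hypothetical, written only for illustration, and is not part of torch_tensorrt.

def check(a_shape, b_shape):
    """Return True if two shapes are broadcastable under torch rules."""
    diff = len(a_shape) - len(b_shape)
    # Left-pad the shorter shape with ones so both shapes have equal rank
    if diff > 0:
        b_shape = (1,) * diff + b_shape
    else:
        a_shape = (1,) * -diff + a_shape
    # Every aligned dimension pair must match or contain a 1
    return all(x == y or x == 1 or y == 1 for x, y in zip(a_shape, b_shape))

# The 0-D case exercised by the new test: a scalar broadcasts against anything
assert check((5, 6, 7, 1, 3), ())
assert check((4, 1, 3), (2, 3))       # size-1 dims stretch to match
assert not check((4, 2, 3), (5, 3))   # 2 vs. 5 cannot broadcast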