From a280055e94ac0d6412ba1b688617cbe15792c745 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Tue, 23 Jul 2024 10:40:39 -0700 Subject: [PATCH 1/3] add dynamic support for embedding_bag/broadcast_in_dim/index_select --- .../dynamo/conversion/test_unsqueeze_aten.py | 42 ++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/tests/py/dynamo/conversion/test_unsqueeze_aten.py b/tests/py/dynamo/conversion/test_unsqueeze_aten.py index 87375afdec..fbd5042a0b 100644 --- a/tests/py/dynamo/conversion/test_unsqueeze_aten.py +++ b/tests/py/dynamo/conversion/test_unsqueeze_aten.py @@ -4,7 +4,7 @@ from parameterized import parameterized from torch.testing._internal.common_utils import run_tests from torch_tensorrt import Input -from torch_tensorrt.dynamo.conversion import UnsupportedOperatorException +from torch_tensorrt.dynamo.conversion import UnsupportedOperatorException, impl from .harness import DispatchTestCase @@ -89,6 +89,46 @@ def forward(self, x): inputs, ) + +class TestBroadcastInDim(DispatchTestCase): + def test_broadcast_in_dim_supported_1( + self, + ): + class Unsqueeze(nn.Module): + def forward(self, x): + return torch.ops.prims.broadcast_in_dim.default(x, [3, 4, 1], [0, 1]) + + inputs = [torch.arange(0, 12).reshape(3, 4)] + self.run_test( + Unsqueeze(), + inputs, + ) + + # TODO: need help from Dheeraj to figure out why this test is failed + def test_broadcast_in_dim_with_dynamic_shape( + self, + ): + class BroadcastInDim(nn.Module): + def forward(self, x): + dims = [0, 1] + shape = [x.shape[d] for d in dims] + shape.append(1) + return torch.ops.prims.broadcast_in_dim.default(x, shape, dims) + + input_specs = [ + Input( + dtype=torch.float32, + min_shape=(2, 3), + opt_shape=(3, 4), + max_shape=(4, 5), + ), + ] + self.run_test_with_dynamic_shape( + BroadcastInDim(), + input_specs, + use_dynamo_tracer=True, + ) + def test_broadcast_in_dim_supported_singleton( self, ): From 83a61fcd44fdad6a73a042cbd53a5fc093f2a700 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Wed, 24 Jul 2024 10:27:13 -0700 Subject: [PATCH 2/3] add testcase --- .../dynamo/conversion/aten_ops_converters.py | 12 ++- .../conversion/test_embedding_bag_aten.py | 82 ++++++++++++++++++ .../conversion/test_index_select_aten.py | 84 ++++++++++++++++++- .../dynamo/conversion/test_unsqueeze_aten.py | 42 +--------- 4 files changed, 175 insertions(+), 45 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 3da1b09fba..10bfd852af 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -283,10 +283,14 @@ def embedding_bag_validator(node: Node) -> bool: @dynamo_tensorrt_converter( - torch.ops.aten.embedding_bag.default, capability_validator=embedding_bag_validator + torch.ops.aten.embedding_bag.default, + capability_validator=embedding_bag_validator, + supports_dynamic_shapes=True, ) @dynamo_tensorrt_converter( - torch.ops.aten._embedding_bag.default, capability_validator=embedding_bag_validator + torch.ops.aten._embedding_bag.default, + capability_validator=embedding_bag_validator, + supports_dynamic_shapes=True, ) @enforce_tensor_types( { @@ -3370,7 +3374,9 @@ def aten_ops_roll( ) -@dynamo_tensorrt_converter(torch.ops.aten.index_select.default) +@dynamo_tensorrt_converter( + torch.ops.aten.index_select.default, supports_dynamic_shapes=True +) @enforce_tensor_types( { 0: (TRTTensor,), diff --git a/tests/py/dynamo/conversion/test_embedding_bag_aten.py b/tests/py/dynamo/conversion/test_embedding_bag_aten.py index 9664e1be58..6737b6a34e 100644 --- a/tests/py/dynamo/conversion/test_embedding_bag_aten.py +++ b/tests/py/dynamo/conversion/test_embedding_bag_aten.py @@ -1,6 +1,8 @@ import torch +import torch_tensorrt from parameterized import param, parameterized from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input from .harness import DispatchTestCase @@ -408,6 +410,86 @@ def forward(self, weight, indices, offsets): propagate_shapes=True, ) + @parameterized.expand( + [ + param( + # 1d_indices_mode_0_with_per_sample_weights + weights=torch.randn((5, 2), dtype=torch.float32), + dynamic_shapes={ + "weights": {0: torch.export.Dim("dyn_dim", min=2, max=6)}, + "indices": {}, + "offsets": {}, + }, + # weights_max_dim=6, + indices=torch.tensor([1, 2, 4], dtype=torch.int32), + offsets=torch.tensor([0, 2, 3], dtype=torch.int32), + mode=0, + per_sample_weights=torch.randn((3,), dtype=torch.float32), + ), + param( + # 1d_indices_mode_1_without_per_sample_weights + weights=torch.randn((5, 2), dtype=torch.float32), + dynamic_shapes={ + "weights": { + 0: torch.export.Dim("dyn_dim", min=2, max=8), + 1: torch.export.Dim("dyn_dim_1", min=1, max=3), + }, + "indices": {}, + "offsets": {}, + }, + indices=torch.tensor([1, 2, 4, 2, 3, 4], dtype=torch.int32), + offsets=torch.tensor([0, 2, 4], dtype=torch.int32), + mode=1, + per_sample_weights=None, + ), + ] + ) + def test_embedding_bag_with_weights_dynamic_shape( + self, weights, dynamic_shapes, indices, offsets, mode, per_sample_weights + ): + class EmbeddingBag(torch.nn.Module): + def forward(self, weights, indices, offsets, per_sample_weights=None): + return torch.ops.aten._embedding_bag.default( + weight=weights, + indices=indices, + offsets=offsets, + per_sample_weights=per_sample_weights, + scale_grad_by_freq=False, + mode=mode, + sparse=False, + include_last_offset=False, + padding_idx=-1, + ) + + if per_sample_weights is None: + inputs = (weights, indices, offsets) + else: + inputs = (weights, indices, offsets, per_sample_weights) + mod = EmbeddingBag() + + if per_sample_weights is not None: + dynamic_shapes["per_sample_weights"] = {} + fx_mod = torch.export.export(mod, inputs, dynamic_shapes=dynamic_shapes) + trt_mod = torch_tensorrt.dynamo.compile( + fx_mod, inputs=inputs, enable_precisions=torch.float32, min_block_size=1 + ) + + with torch.no_grad(): + cuda_inputs = [] + for i in inputs: + cuda_inputs.append(i.cuda()) + ref_outputs = mod(*cuda_inputs) + outputs = trt_mod(*cuda_inputs) + for out, ref in zip(outputs, ref_outputs): + torch.testing.assert_close( + out, + ref, + rtol=0.001, + atol=0.001, + equal_nan=True, + check_dtype=True, + ) + if __name__ == "__main__": run_tests() diff --git a/tests/py/dynamo/conversion/test_index_select_aten.py b/tests/py/dynamo/conversion/test_index_select_aten.py index 83eaedb944..8094fbecbd 100644 --- a/tests/py/dynamo/conversion/test_index_select_aten.py +++ b/tests/py/dynamo/conversion/test_index_select_aten.py @@ -1,7 +1,9 @@ import torch import torch.nn as nn -from parameterized import parameterized +import torch_tensorrt +from parameterized import param, parameterized from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input from .harness import DispatchTestCase @@ -36,6 +38,86 @@ def forward(self, source_tensor, indices_tensor): input, ) + @parameterized.expand( + [ + param( + # 1d_source_tensor + source_tensor=torch.randn((3,), dtype=torch.float32), + dynamic_shapes={ + "source_tensor": {0: torch.export.Dim("dyn_dim", min=3, max=6)}, + "indice_tensor": {}, + }, + dim=0, + indice_tensor=torch.tensor( + [ + 1, + ], + dtype=torch.int32, + ), + ), + param( + # 2d_source_tensor + source_tensor=torch.randn((3, 3), dtype=torch.float32), + dynamic_shapes={ + "source_tensor": { + 0: torch.export.Dim("dyn_dim1", min=3, max=6), + 1: torch.export.Dim("dyn_dim2", min=2, max=7), + }, + "indice_tensor": {}, + }, + dim=-1, + indice_tensor=torch.tensor([0, 2], dtype=torch.int32), + ), + param( + # 3d_source_tensor + source_tensor=torch.randn((3, 4, 2), dtype=torch.float32), + dynamic_shapes={ + "source_tensor": { + 0: torch.export.Dim("dyn_dim1", min=3, max=6), + 1: torch.export.Dim("dyn_dim2", min=2, max=7), + }, + "indice_tensor": {}, + }, + dim=-2, + indice_tensor=torch.tensor([0, 0, 2], dtype=torch.int32), + ), + ] + ) + def test_index_select_dynamic_shape( + self, source_tensor, dynamic_shapes, dim, indice_tensor + ): + class IndexSelect(torch.nn.Module): + def forward(self, source_tensor, indice_tensor): + return torch.ops.aten.index_select.default( + source_tensor, + dim, + indice_tensor, + ) + + inputs = (source_tensor, indice_tensor) + mod = IndexSelect() + + fx_mod = torch.export.export(mod, inputs, dynamic_shapes=dynamic_shapes) + trt_mod = torch_tensorrt.dynamo.compile( + fx_mod, inputs=inputs, enable_precisions=torch.float32, min_block_size=1 + ) + + with torch.no_grad(): + cuda_inputs = [] + for i in inputs: + cuda_inputs.append(i.cuda()) + ref_outputs = mod(*cuda_inputs) + outputs = trt_mod(*cuda_inputs) + for out, ref in zip(outputs, ref_outputs): + torch.testing.assert_close( + out, + ref, + rtol=0.001, + atol=0.001, + equal_nan=True, + check_dtype=True, + ) + if __name__ == "__main__": run_tests() diff --git a/tests/py/dynamo/conversion/test_unsqueeze_aten.py b/tests/py/dynamo/conversion/test_unsqueeze_aten.py index fbd5042a0b..87375afdec 100644 --- a/tests/py/dynamo/conversion/test_unsqueeze_aten.py +++ b/tests/py/dynamo/conversion/test_unsqueeze_aten.py @@ -4,7 +4,7 @@ from parameterized import parameterized from torch.testing._internal.common_utils import run_tests from torch_tensorrt import Input -from torch_tensorrt.dynamo.conversion import UnsupportedOperatorException, impl +from torch_tensorrt.dynamo.conversion import UnsupportedOperatorException from .harness import DispatchTestCase @@ -89,46 +89,6 @@ def forward(self, x): inputs, ) - -class TestBroadcastInDim(DispatchTestCase): - def test_broadcast_in_dim_supported_1( - self, - ): - class Unsqueeze(nn.Module): - def forward(self, x): - return torch.ops.prims.broadcast_in_dim.default(x, [3, 4, 1], [0, 1]) - - inputs = [torch.arange(0, 12).reshape(3, 4)] - self.run_test( - Unsqueeze(), - inputs, - ) - - # TODO: need help from Dheeraj to figure out why this test is failed - def test_broadcast_in_dim_with_dynamic_shape( - self, - ): - class BroadcastInDim(nn.Module): - def forward(self, x): - dims = [0, 1] - shape = [x.shape[d] for d in dims] - shape.append(1) - return torch.ops.prims.broadcast_in_dim.default(x, shape, dims) - - input_specs = [ - Input( - dtype=torch.float32, - min_shape=(2, 3), - opt_shape=(3, 4), - max_shape=(4, 5), - ), - ] - self.run_test_with_dynamic_shape( - BroadcastInDim(), - input_specs, - use_dynamo_tracer=True, - ) - def test_broadcast_in_dim_supported_singleton( self, ): From 336391c34591dcfbead951c1d3949cb9c84a727e Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Wed, 24 Jul 2024 17:43:40 -0700 Subject: [PATCH 3/3] resolve comments --- .../conversion/test_embedding_bag_aten.py | 21 +++++++++++++++++-- .../conversion/test_index_select_aten.py | 14 +++++++++++-- 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/tests/py/dynamo/conversion/test_embedding_bag_aten.py b/tests/py/dynamo/conversion/test_embedding_bag_aten.py index 6737b6a34e..0b06e10871 100644 --- a/tests/py/dynamo/conversion/test_embedding_bag_aten.py +++ b/tests/py/dynamo/conversion/test_embedding_bag_aten.py @@ -414,13 +414,15 @@ def forward(self, weight, indices, offsets): [ param( # 1d_indices_mode_0_with_per_sample_weights + # weights is for compile weights=torch.randn((5, 2), dtype=torch.float32), + # weights_1 is for inference + weights_1=torch.randn((6, 2), dtype=torch.float32), dynamic_shapes={ "weights": {0: torch.export.Dim("dyn_dim", min=2, max=6)}, "indices": {}, "offsets": {}, }, - # weights_max_dim=6, indices=torch.tensor([1, 2, 4], dtype=torch.int32), offsets=torch.tensor([0, 2, 3], dtype=torch.int32), mode=0, @@ -428,7 +430,10 @@ def forward(self, weight, indices, offsets): ), param( # 1d_indices_mode_1_without_per_sample_weights + # weights is for compile weights=torch.randn((5, 2), dtype=torch.float32), + # weights_1 is for inference + weights_1=torch.randn((6, 3), dtype=torch.float32), dynamic_shapes={ "weights": { 0: torch.export.Dim("dyn_dim", min=2, max=8), @@ -445,7 +450,14 @@ def forward(self, weight, indices, offsets): ] ) def test_embedding_bag_with_weights_dynamic_shape( - self, weights, dynamic_shapes, indices, offsets, mode, per_sample_weights + self, + weights, + weights_1, + dynamic_shapes, + indices, + offsets, + mode, + per_sample_weights, ): class EmbeddingBag(torch.nn.Module): def forward(self, weights, indices, offsets, per_sample_weights=None): @@ -473,6 +485,11 @@ def forward(self, weights, indices, offsets, per_sample_weights=None): trt_mod = torch_tensorrt.dynamo.compile( fx_mod, inputs=inputs, enable_precisions=torch.float32, min_block_size=1 ) + # use the inputs with different shape to inference: + if per_sample_weights is None: + inputs = (weights_1, indices, offsets) + else: + inputs = (weights_1, indices, offsets, per_sample_weights) with torch.no_grad(): cuda_inputs = [] diff --git a/tests/py/dynamo/conversion/test_index_select_aten.py b/tests/py/dynamo/conversion/test_index_select_aten.py index 8094fbecbd..c5feca013e 100644 --- a/tests/py/dynamo/conversion/test_index_select_aten.py +++ b/tests/py/dynamo/conversion/test_index_select_aten.py @@ -42,7 +42,10 @@ def forward(self, source_tensor, indices_tensor): [ param( # 1d_source_tensor + # source_tensor is for compile source_tensor=torch.randn((3,), dtype=torch.float32), + # source_tensor_1 is for inference + source_tensor_1=torch.randn((5,), dtype=torch.float32), dynamic_shapes={ "source_tensor": {0: torch.export.Dim("dyn_dim", min=3, max=6)}, "indice_tensor": {}, @@ -57,7 +60,10 @@ def forward(self, source_tensor, indices_tensor): ), param( # 2d_source_tensor + # source_tensor is for compile source_tensor=torch.randn((3, 3), dtype=torch.float32), + # source_tensor_1 is for inference + source_tensor_1=torch.randn((4, 6), dtype=torch.float32), dynamic_shapes={ "source_tensor": { 0: torch.export.Dim("dyn_dim1", min=3, max=6), @@ -70,7 +76,10 @@ def forward(self, source_tensor, indices_tensor): ), param( # 3d_source_tensor + # source_tensor is for compile source_tensor=torch.randn((3, 4, 2), dtype=torch.float32), + # source_tensor_1 is for inference + source_tensor_1=torch.randn((6, 7, 2), dtype=torch.float32), dynamic_shapes={ "source_tensor": { 0: torch.export.Dim("dyn_dim1", min=3, max=6), @@ -84,7 +93,7 @@ def forward(self, source_tensor, indices_tensor): ] ) def test_index_select_dynamic_shape( - self, source_tensor, dynamic_shapes, dim, indice_tensor + self, source_tensor, source_tensor_1, dynamic_shapes, dim, indice_tensor ): class IndexSelect(torch.nn.Module): def forward(self, source_tensor, indice_tensor): @@ -101,7 +110,8 @@ def forward(self, source_tensor, indice_tensor): trt_mod = torch_tensorrt.dynamo.compile( fx_mod, inputs=inputs, enable_precisions=torch.float32, min_block_size=1 ) - + # use different shape of inputs for inference: + inputs = (source_tensor_1, indice_tensor) with torch.no_grad(): cuda_inputs = [] for i in inputs: