Skip to content

Commit 1360961

Browse files
committed
adding test cases and correcting empty__stride decomposition
1 parent b4913b6 commit 1360961

File tree

2 files changed

+96
-1
lines changed

2 files changed

+96
-1
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def select_scatter_decomposition(
232232
def empty_strided_decomposition(*args, **kwargs) -> torch.Tensor:
233233
empty_size = args[0]
234234
empty_stride = args[1]
235-
return torch.as_strided(torch.empty(empty_size), empty_stride)
235+
return torch.as_strided(torch.empty(empty_size), empty_size, empty_stride)
236236

237237

238238
def get_decompositions(

tests/py/dynamo/lowering/test_decompositions.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import torch
22
import torch_tensorrt
3+
from parameterized import parameterized
34
from torch.testing._internal.common_utils import TestCase, run_tests
5+
from parameterized import parameterized
46

57
from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
68

@@ -868,6 +870,99 @@ def forward(self, x, src, dim, index):
868870
f"Select_scatter TRT outputs don't match with the original model.",
869871
)
870872

873+
empty_ops = [
874+
(
875+
"empty_stride_one_dimension_firstcase",
876+
[5, 5],
877+
[1, 2],
878+
None,
879+
),
880+
(
881+
"empty_stride_two_dimension_secondcase",
882+
[5, 5],
883+
[2, 2],
884+
None,
885+
),
886+
(
887+
"empty_three_dimension",
888+
[8, 8, 8],
889+
[1, 2, 3],
890+
torch.int32,
891+
),
892+
]
893+
894+
@parameterized.expand(
895+
[(empty_op[0], empty_op[1], empty_op[2], empty_op[3]) for empty_op in empty_ops]
896+
)
897+
def test_empty_stride(self, _, shape_or_input, stride, data_type):
898+
class TestModule(torch.nn.Module):
899+
def __init__(self):
900+
super().__init__()
901+
902+
def forward(self, input):
903+
# The add operation is added otherwise it returns an empty graph post lowering passes
904+
add_tensor = torch.ops.aten.add(input[0], input[0])
905+
shape_or_input[0] = input[0].shape[0]
906+
empty_strided = torch.ops.aten.empty_strided.default(
907+
shape_or_input, stride, dtype=data_type
908+
)
909+
add_tensor = empty_strided.cuda() + add_tensor
910+
return add_tensor
911+
912+
# Operations expected to be included in the traced graph after decompositions
913+
unexpected_ops = {
914+
torch.ops.aten.empty_strided.default,
915+
torch.ops.aten.empty_permuted.default,
916+
}
917+
expected_ops = {torch.ops.aten.add.Tensor}
918+
919+
input = [torch.randint(1, 3, shape_or_input, dtype=torch.int32).cuda()]
920+
inputs = [input]
921+
922+
fx_graph = torch.fx.symbolic_trace(TestModule())
923+
924+
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
925+
fx_graph,
926+
inputs,
927+
expected_ops=expected_ops,
928+
unexpected_ops=unexpected_ops,
929+
min_block_size=2,
930+
)
931+
932+
torch._dynamo.reset()
933+
934+
self.assertEqual(
935+
len(unexpected_ops_seen),
936+
0,
937+
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
938+
)
939+
940+
self.assertEqual(
941+
len(expected_ops_unseen),
942+
0,
943+
f"The following expected ops were not encountered: {expected_ops_unseen}",
944+
)
945+
946+
torch._dynamo.reset()
947+
948+
# Validate that the results between Torch and Torch-TRT are similar
949+
optimized_model = torch_tensorrt.compile(
950+
fx_graph,
951+
"torch_compile",
952+
inputs,
953+
min_block_size=1,
954+
truncate_double=True,
955+
pass_through_build_failures=True,
956+
)
957+
optimized_model_results = optimized_model(*inputs).detach().cpu()
958+
torch_model_results = fx_graph(*inputs).detach().cpu()
959+
960+
self.assertEqual(
961+
optimized_model_results.shape,
962+
torch_model_results.shape,
963+
f"The optimized model results shape and torch model results shape should be equal in empty_stride",
964+
)
965+
871966

872967
if __name__ == "__main__":
873968
run_tests()

0 commit comments

Comments
 (0)