Skip to content

Commit aeae9b3

Browse files
committed
adding test cases and correcting empty__stride decomposition
1 parent b4913b6 commit aeae9b3

File tree

2 files changed

+95
-1
lines changed

2 files changed

+95
-1
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def select_scatter_decomposition(
232232
def empty_strided_decomposition(*args, **kwargs) -> torch.Tensor:
233233
empty_size = args[0]
234234
empty_stride = args[1]
235-
return torch.as_strided(torch.empty(empty_size), empty_stride)
235+
return torch.as_strided(torch.empty(empty_size), empty_size, empty_stride)
236236

237237

238238
def get_decompositions(

tests/py/dynamo/lowering/test_decompositions.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import torch
22
import torch_tensorrt
3+
from parameterized import parameterized
34
from torch.testing._internal.common_utils import TestCase, run_tests
45

56
from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
@@ -868,6 +869,99 @@ def forward(self, x, src, dim, index):
868869
f"Select_scatter TRT outputs don't match with the original model.",
869870
)
870871

872+
empty_ops = [
873+
(
874+
"empty_stride_one_dimension_firstcase",
875+
[5, 5],
876+
[1, 2],
877+
None,
878+
),
879+
(
880+
"empty_stride_two_dimension_secondcase",
881+
[5, 5],
882+
[2, 2],
883+
None,
884+
),
885+
(
886+
"empty_three_dimension",
887+
[8, 8, 8],
888+
[1, 2, 3],
889+
torch.int32,
890+
),
891+
]
892+
893+
@parameterized.expand(
894+
[(empty_op[0], empty_op[1], empty_op[2], empty_op[3]) for empty_op in empty_ops]
895+
)
896+
def test_empty_stride(self, _, shape_or_input, stride, data_type):
897+
class TestModule(torch.nn.Module):
898+
def __init__(self):
899+
super().__init__()
900+
901+
def forward(self, input):
902+
# The add operation is added otherwise it returns an empty graph post lowering passes
903+
add_tensor = torch.ops.aten.add(input[0], input[0])
904+
shape_or_input[0] = input[0].shape[0]
905+
empty_strided = torch.ops.aten.empty_strided.default(
906+
shape_or_input, stride, dtype=data_type
907+
)
908+
add_tensor = empty_strided.cuda() + add_tensor
909+
return add_tensor
910+
911+
# Operations expected to be included in the traced graph after decompositions
912+
unexpected_ops = {
913+
torch.ops.aten.empty_strided.default,
914+
torch.ops.aten.empty_permuted.default,
915+
}
916+
expected_ops = {torch.ops.aten.add.Tensor}
917+
918+
input = [torch.randint(1, 3, shape_or_input, dtype=torch.int32).cuda()]
919+
inputs = [input]
920+
921+
fx_graph = torch.fx.symbolic_trace(TestModule())
922+
923+
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
924+
fx_graph,
925+
inputs,
926+
expected_ops=expected_ops,
927+
unexpected_ops=unexpected_ops,
928+
min_block_size=2,
929+
)
930+
931+
torch._dynamo.reset()
932+
933+
self.assertEqual(
934+
len(unexpected_ops_seen),
935+
0,
936+
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
937+
)
938+
939+
self.assertEqual(
940+
len(expected_ops_unseen),
941+
0,
942+
f"The following expected ops were not encountered: {expected_ops_unseen}",
943+
)
944+
945+
torch._dynamo.reset()
946+
947+
# Validate that the results between Torch and Torch-TRT are similar
948+
optimized_model = torch_tensorrt.compile(
949+
fx_graph,
950+
"torch_compile",
951+
inputs,
952+
min_block_size=1,
953+
truncate_double=True,
954+
pass_through_build_failures=True,
955+
)
956+
optimized_model_results = optimized_model(*inputs).detach().cpu()
957+
torch_model_results = fx_graph(*inputs).detach().cpu()
958+
959+
self.assertEqual(
960+
optimized_model_results.shape,
961+
torch_model_results.shape,
962+
f"The optimized model results shape and torch model results shape should be equal in empty_stride",
963+
)
964+
871965

872966
if __name__ == "__main__":
873967
run_tests()

0 commit comments

Comments
 (0)