Skip to content

Commit 41f0324

Browse files
committed
adding test cases and correcting empty__stride decomposition
1 parent 50079bd commit 41f0324

File tree

2 files changed

+95
-1
lines changed

2 files changed

+95
-1
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def empty_permuted_decomposition(*args, **kwargs) -> torch.Tensor:
180180
def empty_strided_decomposition(*args, **kwargs) -> torch.Tensor:
181181
empty_size = args[0]
182182
empty_stride = args[1]
183-
return torch.as_strided(torch.empty(empty_size), empty_stride)
183+
return torch.as_strided(torch.empty(empty_size), empty_size, empty_stride)
184184

185185

186186
def get_decompositions(

tests/py/dynamo/lowering/test_decompositions.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import torch
22
import torch_tensorrt
3+
from parameterized import parameterized
34
from torch.testing._internal.common_utils import TestCase, run_tests
45

56
from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
@@ -484,6 +485,99 @@ def forward(self, x):
484485
f"The optimized model results shape and torch model results shape should be equal in empty_like",
485486
)
486487

488+
empty_ops = [
489+
(
490+
"empty_stride_one_dimension_firstcase",
491+
[5, 5],
492+
[1, 2],
493+
None,
494+
),
495+
(
496+
"empty_stride_two_dimension_secondcase",
497+
[5, 5],
498+
[2, 2],
499+
None,
500+
),
501+
(
502+
"empty_three_dimension",
503+
[8, 8, 8],
504+
[1, 2, 3],
505+
torch.int32,
506+
),
507+
]
508+
509+
@parameterized.expand(
510+
[(empty_op[0], empty_op[1], empty_op[2], empty_op[3]) for empty_op in empty_ops]
511+
)
512+
def test_empty_stride(self, _, shape_or_input, stride, data_type):
513+
class TestModule(torch.nn.Module):
514+
def __init__(self):
515+
super().__init__()
516+
517+
def forward(self, input):
518+
# The add operation is added otherwise it returns an empty graph post lowering passes
519+
add_tensor = torch.ops.aten.add(input[0], input[0])
520+
shape_or_input[0] = input[0].shape[0]
521+
empty_strided = torch.ops.aten.empty_strided.default(
522+
shape_or_input, stride, dtype=data_type
523+
)
524+
add_tensor = empty_strided.cuda() + add_tensor
525+
return add_tensor
526+
527+
# Operations expected to be included in the traced graph after decompositions
528+
unexpected_ops = {
529+
torch.ops.aten.empty_strided.default,
530+
torch.ops.aten.empty_permuted.default,
531+
}
532+
expected_ops = {torch.ops.aten.add.Tensor}
533+
534+
input = [torch.randint(1, 3, shape_or_input, dtype=torch.int32).cuda()]
535+
inputs = [input]
536+
537+
fx_graph = torch.fx.symbolic_trace(TestModule())
538+
539+
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
540+
fx_graph,
541+
inputs,
542+
expected_ops=expected_ops,
543+
unexpected_ops=unexpected_ops,
544+
min_block_size=2,
545+
)
546+
547+
torch._dynamo.reset()
548+
549+
self.assertEqual(
550+
len(unexpected_ops_seen),
551+
0,
552+
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
553+
)
554+
555+
self.assertEqual(
556+
len(expected_ops_unseen),
557+
0,
558+
f"The following expected ops were not encountered: {expected_ops_unseen}",
559+
)
560+
561+
torch._dynamo.reset()
562+
563+
# Validate that the results between Torch and Torch-TRT are similar
564+
optimized_model = torch_tensorrt.compile(
565+
fx_graph,
566+
"torch_compile",
567+
inputs,
568+
min_block_size=1,
569+
truncate_double=True,
570+
pass_through_build_failures=True,
571+
)
572+
optimized_model_results = optimized_model(*inputs).detach().cpu()
573+
torch_model_results = fx_graph(*inputs).detach().cpu()
574+
575+
self.assertEqual(
576+
optimized_model_results.shape,
577+
torch_model_results.shape,
578+
f"The optimized model results shape and torch model results shape should be equal in empty_stride",
579+
)
580+
487581

488582
if __name__ == "__main__":
489583
run_tests()

0 commit comments

Comments
 (0)