|
1 | 1 | import torch
|
2 | 2 | import torch_tensorrt
|
| 3 | +from parameterized import parameterized |
3 | 4 | from torch.testing._internal.common_utils import TestCase, run_tests
|
| 5 | +from parameterized import parameterized |
4 | 6 |
|
5 | 7 | from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
|
6 | 8 |
|
@@ -868,6 +870,99 @@ def forward(self, x, src, dim, index):
|
868 | 870 | f"Select_scatter TRT outputs don't match with the original model.",
|
869 | 871 | )
|
870 | 872 |
|
| 873 | + empty_ops = [ |
| 874 | + ( |
| 875 | + "empty_stride_one_dimension_firstcase", |
| 876 | + [5, 5], |
| 877 | + [1, 2], |
| 878 | + None, |
| 879 | + ), |
| 880 | + ( |
| 881 | + "empty_stride_two_dimension_secondcase", |
| 882 | + [5, 5], |
| 883 | + [2, 2], |
| 884 | + None, |
| 885 | + ), |
| 886 | + ( |
| 887 | + "empty_three_dimension", |
| 888 | + [8, 8, 8], |
| 889 | + [1, 2, 3], |
| 890 | + torch.int32, |
| 891 | + ), |
| 892 | + ] |
| 893 | + |
| 894 | + @parameterized.expand( |
| 895 | + [(empty_op[0], empty_op[1], empty_op[2], empty_op[3]) for empty_op in empty_ops] |
| 896 | + ) |
| 897 | + def test_empty_stride(self, _, shape_or_input, stride, data_type): |
| 898 | + class TestModule(torch.nn.Module): |
| 899 | + def __init__(self): |
| 900 | + super().__init__() |
| 901 | + |
| 902 | + def forward(self, input): |
| 903 | + # The add operation is added otherwise it returns an empty graph post lowering passes |
| 904 | + add_tensor = torch.ops.aten.add(input[0], input[0]) |
| 905 | + shape_or_input[0] = input[0].shape[0] |
| 906 | + empty_strided = torch.ops.aten.empty_strided.default( |
| 907 | + shape_or_input, stride, dtype=data_type |
| 908 | + ) |
| 909 | + add_tensor = empty_strided.cuda() + add_tensor |
| 910 | + return add_tensor |
| 911 | + |
| 912 | + # Operations expected to be included in the traced graph after decompositions |
| 913 | + unexpected_ops = { |
| 914 | + torch.ops.aten.empty_strided.default, |
| 915 | + torch.ops.aten.empty_permuted.default, |
| 916 | + } |
| 917 | + expected_ops = {torch.ops.aten.add.Tensor} |
| 918 | + |
| 919 | + input = [torch.randint(1, 3, shape_or_input, dtype=torch.int32).cuda()] |
| 920 | + inputs = [input] |
| 921 | + |
| 922 | + fx_graph = torch.fx.symbolic_trace(TestModule()) |
| 923 | + |
| 924 | + unexpected_ops_seen, expected_ops_unseen = lower_graph_testing( |
| 925 | + fx_graph, |
| 926 | + inputs, |
| 927 | + expected_ops=expected_ops, |
| 928 | + unexpected_ops=unexpected_ops, |
| 929 | + min_block_size=2, |
| 930 | + ) |
| 931 | + |
| 932 | + torch._dynamo.reset() |
| 933 | + |
| 934 | + self.assertEqual( |
| 935 | + len(unexpected_ops_seen), |
| 936 | + 0, |
| 937 | + f"The following unexpected ops were encountered: {unexpected_ops_seen}", |
| 938 | + ) |
| 939 | + |
| 940 | + self.assertEqual( |
| 941 | + len(expected_ops_unseen), |
| 942 | + 0, |
| 943 | + f"The following expected ops were not encountered: {expected_ops_unseen}", |
| 944 | + ) |
| 945 | + |
| 946 | + torch._dynamo.reset() |
| 947 | + |
| 948 | + # Validate that the results between Torch and Torch-TRT are similar |
| 949 | + optimized_model = torch_tensorrt.compile( |
| 950 | + fx_graph, |
| 951 | + "torch_compile", |
| 952 | + inputs, |
| 953 | + min_block_size=1, |
| 954 | + truncate_double=True, |
| 955 | + pass_through_build_failures=True, |
| 956 | + ) |
| 957 | + optimized_model_results = optimized_model(*inputs).detach().cpu() |
| 958 | + torch_model_results = fx_graph(*inputs).detach().cpu() |
| 959 | + |
| 960 | + self.assertEqual( |
| 961 | + optimized_model_results.shape, |
| 962 | + torch_model_results.shape, |
| 963 | + f"The optimized model results shape and torch model results shape should be equal in empty_stride", |
| 964 | + ) |
| 965 | + |
871 | 966 |
|
872 | 967 | if __name__ == "__main__":
|
873 | 968 | run_tests()
|
0 commit comments