|
1 | 1 | import torch
|
2 | 2 | import torch_tensorrt
|
| 3 | +from parameterized import parameterized |
3 | 4 | from torch.testing._internal.common_utils import TestCase, run_tests
|
4 | 5 |
|
5 | 6 | from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
|
@@ -484,6 +485,99 @@ def forward(self, x):
|
484 | 485 | f"The optimized model results shape and torch model results shape should be equal in empty_like",
|
485 | 486 | )
|
486 | 487 |
|
| 488 | + empty_ops = [ |
| 489 | + ( |
| 490 | + "empty_stride_one_dimension_firstcase", |
| 491 | + [5, 5], |
| 492 | + [1, 2], |
| 493 | + None, |
| 494 | + ), |
| 495 | + ( |
| 496 | + "empty_stride_two_dimension_secondcase", |
| 497 | + [5, 5], |
| 498 | + [2, 2], |
| 499 | + None, |
| 500 | + ), |
| 501 | + ( |
| 502 | + "empty_three_dimension", |
| 503 | + [8, 8, 8], |
| 504 | + [1, 2, 3], |
| 505 | + torch.int32, |
| 506 | + ), |
| 507 | + ] |
| 508 | + |
| 509 | + @parameterized.expand( |
| 510 | + [(empty_op[0], empty_op[1], empty_op[2], empty_op[3]) for empty_op in empty_ops] |
| 511 | + ) |
| 512 | + def test_empty_stride(self, _, shape_or_input, stride, data_type): |
| 513 | + class TestModule(torch.nn.Module): |
| 514 | + def __init__(self): |
| 515 | + super().__init__() |
| 516 | + |
| 517 | + def forward(self, input): |
| 518 | + # The add operation is added otherwise it returns an empty graph post lowering passes |
| 519 | + add_tensor = torch.ops.aten.add(input[0], input[0]) |
| 520 | + shape_or_input[0] = input[0].shape[0] |
| 521 | + empty_strided = torch.ops.aten.empty_strided.default( |
| 522 | + shape_or_input, stride, dtype=data_type |
| 523 | + ) |
| 524 | + add_tensor = empty_strided.cuda() + add_tensor |
| 525 | + return add_tensor |
| 526 | + |
| 527 | + # Operations expected to be included in the traced graph after decompositions |
| 528 | + unexpected_ops = { |
| 529 | + torch.ops.aten.empty_strided.default, |
| 530 | + torch.ops.aten.empty_permuted.default, |
| 531 | + } |
| 532 | + expected_ops = {torch.ops.aten.add.Tensor} |
| 533 | + |
| 534 | + input = [torch.randint(1, 3, shape_or_input, dtype=torch.int32).cuda()] |
| 535 | + inputs = [input] |
| 536 | + |
| 537 | + fx_graph = torch.fx.symbolic_trace(TestModule()) |
| 538 | + |
| 539 | + unexpected_ops_seen, expected_ops_unseen = lower_graph_testing( |
| 540 | + fx_graph, |
| 541 | + inputs, |
| 542 | + expected_ops=expected_ops, |
| 543 | + unexpected_ops=unexpected_ops, |
| 544 | + min_block_size=2, |
| 545 | + ) |
| 546 | + |
| 547 | + torch._dynamo.reset() |
| 548 | + |
| 549 | + self.assertEqual( |
| 550 | + len(unexpected_ops_seen), |
| 551 | + 0, |
| 552 | + f"The following unexpected ops were encountered: {unexpected_ops_seen}", |
| 553 | + ) |
| 554 | + |
| 555 | + self.assertEqual( |
| 556 | + len(expected_ops_unseen), |
| 557 | + 0, |
| 558 | + f"The following expected ops were not encountered: {expected_ops_unseen}", |
| 559 | + ) |
| 560 | + |
| 561 | + torch._dynamo.reset() |
| 562 | + |
| 563 | + # Validate that the results between Torch and Torch-TRT are similar |
| 564 | + optimized_model = torch_tensorrt.compile( |
| 565 | + fx_graph, |
| 566 | + "torch_compile", |
| 567 | + inputs, |
| 568 | + min_block_size=1, |
| 569 | + truncate_double=True, |
| 570 | + pass_through_build_failures=True, |
| 571 | + ) |
| 572 | + optimized_model_results = optimized_model(*inputs).detach().cpu() |
| 573 | + torch_model_results = fx_graph(*inputs).detach().cpu() |
| 574 | + |
| 575 | + self.assertEqual( |
| 576 | + optimized_model_results.shape, |
| 577 | + torch_model_results.shape, |
| 578 | + f"The optimized model results shape and torch model results shape should be equal in empty_stride", |
| 579 | + ) |
| 580 | + |
487 | 581 |
|
488 | 582 | if __name__ == "__main__":
|
489 | 583 | run_tests()
|
0 commit comments