|
18 | 18 | import pytest |
19 | 19 |
|
20 | 20 | import tvm |
21 | | -from tvm import relax |
22 | | -from tvm.script import tir as T, relax as R, ir as I |
23 | 21 | import tvm.testing |
| 22 | +from tvm import relax |
| 23 | +from tvm.script import ir as I |
| 24 | +from tvm.script import relax as R |
| 25 | +from tvm.script import tir as T |
24 | 26 |
|
25 | 27 |
|
26 | 28 | class BaseCompare(tvm.testing.CompareBeforeAfter): |
@@ -704,5 +706,56 @@ def main(): |
704 | 706 | tvm.ir.assert_structural_equal(Before, AfterWhenDisabled) |
705 | 707 |
|
706 | 708 |
|
| 709 | +def test_static_args(): |
| 710 | + @I.ir_module |
| 711 | + class Before: |
| 712 | + @R.function |
| 713 | + def main(): |
| 714 | + storage0 = R.memory.alloc_storage(R.shape([8]), 0, "global", "float32") |
| 715 | + alloc0 = R.memory.alloc_tensor(storage0, 0, R.shape([8]), "float32") |
| 716 | + _ = R.call_packed("dummy_func", alloc0, R.dtype("float32"), R.str("string")) |
| 717 | + return R.tuple() |
| 718 | + |
| 719 | + @I.ir_module |
| 720 | + class Expected: |
| 721 | + @R.function(private=True) |
| 722 | + def cuda_graph_alloc() -> R.Tuple(R.Object): |
| 723 | + R.func_attr({"relax.force_pure": True}) |
| 724 | + storage0: R.Object = R.memory.alloc_storage( |
| 725 | + R.shape([8]), R.prim_value(0), R.str("global"), R.dtype("float32") |
| 726 | + ) |
| 727 | + gv: R.Tuple(R.Object) = (storage0,) |
| 728 | + return gv |
| 729 | + |
| 730 | + @R.function(private=True) |
| 731 | + def cuda_graph_capture(alloc0: R.Tensor((8,), dtype="float32")) -> R.Tuple: |
| 732 | + R.func_attr({"relax.force_pure": True}) |
| 733 | + _: R.Object = R.call_packed("dummy_func", alloc0, R.dtype("float32"), R.str("string")) |
| 734 | + gv: R.Tuple = R.tuple() |
| 735 | + return gv |
| 736 | + |
| 737 | + @R.function |
| 738 | + def main() -> R.Tuple: |
| 739 | + cls = Expected |
| 740 | + gv: R.Tuple(R.Object) = R.call_builtin_with_ctx( |
| 741 | + "vm.builtin.cuda_graph.get_cached_alloc", |
| 742 | + (cls.cuda_graph_alloc, R.prim_value(0)), |
| 743 | + sinfo_args=(R.Tuple(R.Object),), |
| 744 | + ) |
| 745 | + storage0: R.Object = gv[0] |
| 746 | + alloc0: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor( |
| 747 | + storage0, R.prim_value(0), R.shape([8]), R.dtype("float32") |
| 748 | + ) |
| 749 | + gv1: R.Tuple = R.call_builtin_with_ctx( |
| 750 | + "vm.builtin.cuda_graph.run_or_capture", |
| 751 | + (cls.cuda_graph_capture, (alloc0,), R.prim_value(0)), |
| 752 | + sinfo_args=(R.Tuple,), |
| 753 | + ) |
| 754 | + return R.tuple() |
| 755 | + |
| 756 | + mod = relax.transform.RewriteCUDAGraph()(Before) |
| 757 | + tvm.ir.assert_structural_equal(mod, Expected) |
| 758 | + |
| 759 | + |
707 | 760 | if __name__ == "__main__": |
708 | 761 | tvm.testing.main() |
0 commit comments