
Commit 95f97e8

[Relax] CUDA graph rewrite treating StringImm as static (#16691)
The RewriteCUDAGraph pass did not consider StringImm a static expression, causing it to miss some CUDA graph rewrite opportunities. This PR fixes the issue.
1 parent 254e90a commit 95f97e8
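For context, a minimal sketch of the pattern this fix unlocks, mirroring the new test added below (the "dummy_func" extern and the module are illustrative): a R.call_packed call whose non-tensor arguments are immediates such as R.str(...). Before this change, the StringImm produced by R.str(...) made IsStatic return false, so the planner skipped capturing the call.

import tvm
from tvm import relax
from tvm.script import ir as I, relax as R


@I.ir_module
class Example:
    @R.function
    def main():
        # Static allocation plus a packed call whose extra arguments are
        # immediates (a dtype and a string); nothing here varies at runtime.
        storage = R.memory.alloc_storage(R.shape([8]), 0, "global", "float32")
        tensor = R.memory.alloc_tensor(storage, 0, R.shape([8]), "float32")
        _ = R.call_packed("dummy_func", tensor, R.dtype("float32"), R.str("string"))
        return R.tuple()


# With the fix, the pass plans this call for CUDA graph capture instead of
# bailing out on the string argument.
mod = relax.transform.RewriteCUDAGraph()(Example)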

2 files changed: 57 additions, 3 deletions

src/relax/transform/rewrite_cuda_graph.cc

Lines changed: 2 additions & 1 deletion

@@ -348,7 +348,8 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
   }

   bool IsStatic(const Expr& expr, std::vector<const VarNode*>* vars_collector = nullptr) {
-    if (expr->IsInstance<ConstantNode>() || expr->IsInstance<DataTypeImmNode>()) {
+    if (expr->IsInstance<ConstantNode>() || expr->IsInstance<DataTypeImmNode>() ||
+        expr->IsInstance<StringImmNode>()) {
       return true;
     }
     if (const auto* prim_value = expr.as<PrimValueNode>()) {
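After this change, IsStatic accepts four kinds of leaf expressions. A quick sketch of how each is constructed from Python (constructor names assumed from tvm.relax's Python API; treat the mapping as illustrative):

from tvm import relax

# Leaf expressions the planner now treats as static:
relax.const(1.0, "float32")   # ConstantNode
relax.DataTypeImm("float32")  # DataTypeImmNode
relax.StringImm("global")     # StringImmNode (newly accepted by this commit)
relax.PrimValue(0)            # PrimValueNode (handled by the branch shown in the diff)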

tests/python/relax/test_transform_rewrite_cuda_graph.py

Lines changed: 55 additions & 2 deletions

@@ -18,9 +18,11 @@
 import pytest

 import tvm
-from tvm import relax
-from tvm.script import tir as T, relax as R, ir as I
 import tvm.testing
+from tvm import relax
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tir as T


 class BaseCompare(tvm.testing.CompareBeforeAfter):
@@ -704,5 +706,56 @@ def main():
     tvm.ir.assert_structural_equal(Before, AfterWhenDisabled)


+def test_static_args():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main():
+            storage0 = R.memory.alloc_storage(R.shape([8]), 0, "global", "float32")
+            alloc0 = R.memory.alloc_tensor(storage0, 0, R.shape([8]), "float32")
+            _ = R.call_packed("dummy_func", alloc0, R.dtype("float32"), R.str("string"))
+            return R.tuple()
+
+    @I.ir_module
+    class Expected:
+        @R.function(private=True)
+        def cuda_graph_alloc() -> R.Tuple(R.Object):
+            R.func_attr({"relax.force_pure": True})
+            storage0: R.Object = R.memory.alloc_storage(
+                R.shape([8]), R.prim_value(0), R.str("global"), R.dtype("float32")
+            )
+            gv: R.Tuple(R.Object) = (storage0,)
+            return gv
+
+        @R.function(private=True)
+        def cuda_graph_capture(alloc0: R.Tensor((8,), dtype="float32")) -> R.Tuple:
+            R.func_attr({"relax.force_pure": True})
+            _: R.Object = R.call_packed("dummy_func", alloc0, R.dtype("float32"), R.str("string"))
+            gv: R.Tuple = R.tuple()
+            return gv
+
+        @R.function
+        def main() -> R.Tuple:
+            cls = Expected
+            gv: R.Tuple(R.Object) = R.call_builtin_with_ctx(
+                "vm.builtin.cuda_graph.get_cached_alloc",
+                (cls.cuda_graph_alloc, R.prim_value(0)),
+                sinfo_args=(R.Tuple(R.Object),),
+            )
+            storage0: R.Object = gv[0]
+            alloc0: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor(
+                storage0, R.prim_value(0), R.shape([8]), R.dtype("float32")
+            )
+            gv1: R.Tuple = R.call_builtin_with_ctx(
+                "vm.builtin.cuda_graph.run_or_capture",
+                (cls.cuda_graph_capture, (alloc0,), R.prim_value(0)),
+                sinfo_args=(R.Tuple,),
+            )
+            return R.tuple()
+
+    mod = relax.transform.RewriteCUDAGraph()(Before)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
