diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h
index 74e773abe7e7..e48c1856f9fe 100644
--- a/include/tvm/relax/utils.h
+++ b/include/tvm/relax/utils.h
@@ -50,6 +50,13 @@ namespace relax {
 TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds,
                   const tvm::Map<tir::Var, PrimExpr>& symbolic_var_map = {});
 
+/*!
+ * \brief Bind the symbolic variables to a StructInfo. This is a helper function usually called by
+ * other passes to facilitate optimizations.
+ */
+TVM_DLL StructInfo Bind(const StructInfo& sinfo,
+                        const tvm::Map<tir::Var, PrimExpr>& symbolic_var_map);
+
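For context, this new overload substitutes symbolic tir::Vars that appear inside a StructInfo rather than inside an expression tree. A minimal sketch of the intended use, with an illustrative variable m bound to the value 16 (neither is taken from this patch):

    #include <tvm/relax/struct_info.h>
    #include <tvm/relax/utils.h>
    #include <tvm/tir/var.h>

    using namespace tvm;
    using namespace tvm::relax;

    StructInfo BindExample() {
      // A tensor whose single dimension is the symbolic variable m: R.Tensor((m,), "float32")
      tir::Var m("m", DataType::Int(64));
      StructInfo sinfo = TensorStructInfo(ShapeExpr({m}), DataType::Float(32));
      // Substituting m -> 16 yields the fully static R.Tensor((16,), "float32").
      Map<tir::Var, PrimExpr> remap;
      remap.Set(m, IntImm(DataType::Int(64), 16));
      return Bind(sinfo, remap);
    }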
 /*!
  * \brief Infer a binding map for symbolic variables
  *
diff --git a/src/relax/transform/rewrite_cuda_graph.cc b/src/relax/transform/rewrite_cuda_graph.cc
index b67a638dd6af..25b229ebce57 100644
--- a/src/relax/transform/rewrite_cuda_graph.cc
+++ b/src/relax/transform/rewrite_cuda_graph.cc
@@ -53,6 +53,8 @@
 #include <...>
 #include <...>
 #include <...>
+#include <tvm/relax/utils.h>
+#include <tvm/tir/stmt_functor.h>
 
 #include "../../support/arena.h"
 #include "../../support/ordered_set.h"
@@ -82,6 +84,8 @@ struct LiftedFunctionRewritePlan {
   std::vector<const VarNode*> outputs;
   // The corresponding binding vars in the original function of the inputs of the lifted function
   std::vector<const VarNode*> inputs;
+  // The TIR vars in the original function that are propagated to the lifted function
+  Optional<ShapeExpr> propogated_tir_vars = NullOpt;
 };
 
 /*! \brief Builder of the lifted function for cuda graph capturing or allocations */
@@ -98,6 +102,11 @@ class FuncBuilder : public ExprMutator {
    * \param var The variable to mark as input
    */
   void MarkInput(const VarNode* var) { inputs_.push_back(var); }
+  /*!
+   * \brief Mark a TIR variable as the ShapeExpr input of the new function.
+   * \param var The variable to mark as input
+   */
+  void MarkShapeExprInput(const tir::VarNode* var) { shape_expr_inputs_.push_back(var); }
   /*!
    * \brief Mark a variable as the output of the new function. The variable must be the LHS of an
    * existing binding in the new function.
@@ -111,12 +120,27 @@ class FuncBuilder : public ExprMutator {
   /*! \brief Build the new function */
   Function Build() {
     Array<Var> params;
+    Optional<Var> shape_expr = NullOpt;
+    if (shape_expr_inputs_.size()) {
+      Array<PrimExpr> tir_vars;
+      for (const auto* var : shape_expr_inputs_) {
+        auto new_var = GetRef<tir::Var>(var).copy_with_suffix("");
+        tir_var_remap_.Set(GetRef<tir::Var>(var), new_var);
+        tir_vars.push_back(new_var);
+      }
+      shape_expr = Var("shape_expr", ShapeStructInfo(tir_vars));
+    }
     // Set up the parameters
     for (const auto* input : inputs_) {
-      auto new_var = Var(input->name_hint(), Downcast<Optional<StructInfo>>(input->struct_info_));
+      auto new_var = Var(
+          input->name_hint(),
+          VisitExprDepStructInfoField(Downcast<Optional<StructInfo>>(input->struct_info_).value()));
       var_remap_[input->vid] = new_var;
       params.push_back(new_var);
     }
+    if (shape_expr) {
+      params.push_back(shape_expr.value());
+    }
     // Emit the function body
     builder_->BeginBindingBlock();
     for (const auto* binding : bindings_) {
@@ -137,9 +161,13 @@
     return func;
   }
 
+  PrimExpr VisitPrimExpr(const PrimExpr& expr) { return tir::Substitute(expr, tir_var_remap_); }
+
   support::OrderedSet<const VarNode*> inputs_;
   support::OrderedSet<const VarNode*> outputs_;
+  support::OrderedSet<const tir::VarNode*> shape_expr_inputs_;
   std::vector<const VarBindingNode*> bindings_;
+  Map<tir::Var, tir::Var> tir_var_remap_;
 };
 
 /*!
@@ -159,6 +187,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
           static_vars_.insert(func->params[i].get());
         }
       }
+      CollectSymbolicVarHints(func);
       VisitExpr(func);
     }
   }
@@ -174,6 +203,13 @@
     for (const auto* binding : region->bindings_) {
       plan.lifted_bindings.insert(binding->var.get());
     }
+    if (region->shape_expr_inputs_.size()) {
+      Array<PrimExpr> tir_vars;
+      for (const auto* var : region->shape_expr_inputs_) {
+        tir_vars.push_back(GetRef<tir::Var>(var));
+      }
+      plan.propogated_tir_vars = ShapeExpr(tir_vars);
+    }
     plan.inputs.assign(region->inputs_.begin(), region->inputs_.end());
     plan.outputs.assign(region->outputs_.begin(), region->outputs_.end());
     return plan;
@@ -189,6 +225,18 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
     return plans;
   }
 
+  /*!
+   * \brief Collect the name hints of the symbolic variables that are allowed to be captured.
+   */
+  void CollectSymbolicVarHints(const Function& func) {
+    capture_symbolic_vars_.clear();
+    if (auto symbolic_vars =
+            func->attrs.GetAttr<Array<String>>("relax.rewrite_cuda_graph.capture_symbolic_vars")) {
+      for (const auto& var : symbolic_vars.value()) {
+        capture_symbolic_vars_.insert(var);
+      }
+    }
+  }
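CollectSymbolicVarHints reads an opt-in allow-list from the function's attributes; only symbolic variables whose name hints appear there may be captured. A sketch of attaching that attribute from C++ before running the pass (WithAttr is the standard TVM helper; the name "m" is illustrative):

    #include <tvm/ir/attrs.h>
    #include <tvm/relax/expr.h>

    // Opt the symbolic variable "m" into CUDA graph capture for this function.
    tvm::relax::Function AllowSymbolicCapture(tvm::relax::Function func) {
      return tvm::WithAttr(std::move(func), "relax.rewrite_cuda_graph.capture_symbolic_vars",
                           tvm::Array<tvm::String>{"m"});
    }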
 /*!
  *\brief Start a new static region. This method should be called when encountering a
  * CUDA kernel launch (calls to PrimFunc or ExternFunc) that only depends on static parameters.
@@ -239,8 +287,9 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
     // Check whether the call can be lifted to the capture function. It requires all the arguments
     // to be static and the call to be a kernel launch or a pure operation (e.g. memory view).
     std::vector<const VarNode*> args;
+    std::vector<const tir::VarNode*> tir_vars;
     bool is_all_static = [&]() {
-      if (!IsStatic(call->args, &args)) {
+      if (!IsStatic(call->args, &args, &tir_vars)) {
         return false;
       }
       if (call_gv != nullptr && !call_prim_func) {
@@ -276,7 +325,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
         StartRegion();
       }
       AddStaticBinding(binding, /*is_alloc_storage=*/false);
-      MarkAsFuncInput(args);
+      MarkAsFuncInput(args, tir_vars);
     } else {
       EndRegion();
     }
@@ -284,7 +333,8 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
     MarkAsFuncOutput(args);
   }
 
-  void MarkAsFuncInput(const std::vector<const VarNode*>& vars) {
+  void MarkAsFuncInput(const std::vector<const VarNode*>& vars,
+                       const std::vector<const tir::VarNode*>& tir_vars = {}) {
     if (current_.capture_builder == nullptr) {
       return;
     }
@@ -294,6 +344,9 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
         current_.capture_builder->MarkInput(var);
       }
     }
+    for (const tir::VarNode* tir_var : tir_vars) {
+      current_.capture_builder->MarkShapeExprInput(tir_var);
+    }
   }
 
   void MarkAsFuncOutput(const std::vector<const VarNode*>& vars) {
@@ -321,9 +374,10 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
 
   void VisitBinding_(const VarBindingNode* binding, const TupleNode* tuple) final {
     std::vector<const VarNode*> args;
-    if (IsStatic(tuple->fields, &args)) {
+    std::vector<const tir::VarNode*> tir_vars;
+    if (IsStatic(tuple->fields, &args, &tir_vars)) {
       AddStaticBinding(binding, false);
-      MarkAsFuncInput(args);
+      MarkAsFuncInput(args, tir_vars);
     } else {
       EndRegion();
     }
@@ -343,48 +397,83 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
   }
 
   bool IsStatic(const PrimExpr& expr,
-                [[maybe_unused]] std::vector<const VarNode*>* vars_collector = nullptr) {
-    return expr->IsInstance<IntImmNode>() || expr->IsInstance<FloatImmNode>();
+                [[maybe_unused]] std::vector<const VarNode*>* vars_collector = nullptr,
+                std::vector<const tir::VarNode*>* tir_vars_collector = nullptr) {
+    bool is_static = true;
+    tir::PostOrderVisit(expr, [&](const ObjectRef& e) {
+      if (auto var = e.as<tir::VarNode>()) {
+        if (!capture_symbolic_vars_.count(var->name_hint)) {
+          is_static = false;
+          return;
+        }
+        if (tir_vars_collector != nullptr) {
+          tir_vars_collector->push_back(var);
+        }
+      }
+    });
+    return is_static;
   }
 
-  bool IsStatic(const Expr& expr, std::vector<const VarNode*>* vars_collector = nullptr) {
+  bool IsStatic(const Expr& expr, std::vector<const VarNode*>* vars_collector = nullptr,
+                std::vector<const tir::VarNode*>* tir_vars_collector = nullptr) {
     if (expr->IsInstance<ConstantNode>() || expr->IsInstance<DataTypeImmNode>() ||
-        expr->IsInstance<StringImmNode>()) {
+        expr->IsInstance<StringImmNode>() || expr->IsInstance<ExternFuncNode>()) {
       return true;
     }
     if (const auto* prim_value = expr.as<PrimValueNode>()) {
-      return IsStatic(prim_value->value, vars_collector);
+      return IsStatic(prim_value->value, vars_collector, tir_vars_collector);
     }
     if (const auto* var = expr.as<VarNode>()) {
       if (vars_collector != nullptr) {
         vars_collector->push_back(var);
      }
-      return static_vars_.count(var);
+      // Recursively check the struct info to collect the symbolic TIR vars
+      return static_vars_.count(var) && IsStatic(Downcast<StructInfo>(var->struct_info_.value()),
+                                                 vars_collector, tir_vars_collector);
     }
     if (const auto* shape = expr.as<ShapeExprNode>()) {
-      return IsStatic(shape->values, vars_collector);
+      return IsStatic(shape->values, vars_collector, tir_vars_collector);
     }
     if (const auto* tuple = expr.as<TupleNode>()) {
-      return IsStatic(tuple->fields, vars_collector);
+      return IsStatic(tuple->fields, vars_collector, tir_vars_collector);
     }
     return false;
   }
 
   template <typename T>
-  bool IsStatic(const Array<T>& exprs, std::vector<const VarNode*>* vars_collector = nullptr) {
+  bool IsStatic(const Array<T>& exprs, std::vector<const VarNode*>* vars_collector = nullptr,
+                std::vector<const tir::VarNode*>* tir_vars_collector = nullptr) {
     bool result = true;
     for (const auto& expr : exprs) {
       // If vars_collector is provided, we will collect all the vars in the exprs and we should
       // not perform short-circuiting.
-      result &= IsStatic(expr, vars_collector);
-      if (!vars_collector && !result) {
+      result &= IsStatic(expr, vars_collector, tir_vars_collector);
+      if (vars_collector == nullptr && tir_vars_collector == nullptr && !result) {
        return false;
       }
     }
     return result;
   }
 
+  bool IsStatic(const StructInfo& sinfo, std::vector<const VarNode*>* vars_collector = nullptr,
+                std::vector<const tir::VarNode*>* tir_vars_collector = nullptr) {
+    if (const auto* tensor_sinfo = sinfo.as<TensorStructInfoNode>()) {
+      if (auto shape = tensor_sinfo->GetShape()) {
+        return IsStatic(shape.value(), vars_collector, tir_vars_collector);
+      }
+    } else if (const auto* shape_sinfo = sinfo.as<ShapeStructInfoNode>()) {
+      if (shape_sinfo->values) {
+        return IsStatic(shape_sinfo->values.value(), vars_collector, tir_vars_collector);
+      }
+    } else if (const auto* tuple_sinfo = sinfo.as<TupleStructInfoNode>()) {
+      return IsStatic(tuple_sinfo->fields, vars_collector, tir_vars_collector);
+    } else if (sinfo.as<ObjectStructInfoNode>() || sinfo.as<PrimStructInfoNode>()) {
+      return true;
+    }
+    return false;
+  }
+
  private:
   bool IsStaticAllocStorage(const VarBindingNode* binding) {
     // Check if the allocation has constant shape
@@ -431,6 +520,8 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
   Scope current_;
   // Variables whose buffer address is fixed
   std::unordered_set<const VarNode*> static_vars_;
+  // The name hints of the variables that are allowed to be symbolic
+  std::unordered_set<String> capture_symbolic_vars_;
   // Binding to the FuncBuilder if the binding is lifted. This is used to update the inputs/outputs
   // of the lifted function when its binding is used outside.
   std::unordered_map<const VarBindingNode*, FuncBuilder*> binding_to_region_;
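The reworked IsStatic(PrimExpr) above hinges on tir::PostOrderVisit: an expression counts as static only if every tir::Var it mentions is allow-listed, and visited vars are reported to the collector. The same traversal idiom in isolation, as a self-contained sketch independent of the pass:

    #include <tvm/ir/expr.h>
    #include <tvm/tir/stmt_functor.h>
    #include <tvm/tir/var.h>

    #include <vector>

    // Collect every tir::Var mentioned anywhere in a PrimExpr, in post order.
    std::vector<const tvm::tir::VarNode*> CollectTIRVars(const tvm::PrimExpr& expr) {
      std::vector<const tvm::tir::VarNode*> vars;
      tvm::tir::PostOrderVisit(expr, [&](const tvm::ObjectRef& node) {
        if (const auto* var = node.as<tvm::tir::VarNode>()) {
          vars.push_back(var);
        }
      });
      return vars;
    }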
@@ -475,6 +566,8 @@ class CUDAGraphRewriter : public ExprMutator {
     auto gv_func =
         builder_->AddFunction(plan.func, plan.is_alloc ? "cuda_graph_alloc" : "cuda_graph_capture");
     if (plan.is_alloc) {
+      // Storage allocation should be fully static and shouldn't depend on any symbolic variables.
+      ICHECK(!plan.propogated_tir_vars.defined());
       ICHECK(plan.inputs.empty());
       launch_subgraph =
           Call(call_builtin_with_ctx_op,
@@ -482,15 +575,39 @@
                Tuple({gv_func, PrimValue(IntImm(DataType::Int(64), index_alloc_++))})},
               Attrs(), {plan.func->ret_struct_info});
     } else {
+      StructInfo call_sinfo = plan.func->ret_struct_info;
+      // Arguments of the lifted function
       Array<Expr> args;
       for (const auto& arg : plan.inputs) {
         args.push_back(VisitExpr_(arg));
       }
-      launch_subgraph = Call(
-          call_builtin_with_ctx_op,
-          {builtin_run_or_capture,
-           Tuple({gv_func, Tuple(args), PrimValue(IntImm(DataType::Int(64), index_capture_++))})},
-          Attrs(), {plan.func->ret_struct_info});
+      if (plan.propogated_tir_vars.defined()) {
+        ShapeExpr propogated_tir_vars = plan.propogated_tir_vars.value();
+        args.push_back(propogated_tir_vars);
+        // The ret_struct_info of the lifted function can contain symbolic variables. We need to
+        // bind the symbolic parameters to the actual values.
+        const auto& shape_expr = plan.func->params.back();
+        auto symbolic_params =
+            Downcast<ShapeStructInfo>(shape_expr->struct_info_.value())->values.value();
+        Map<tir::Var, PrimExpr> tir_var_remap;
+        ICHECK_EQ(symbolic_params.size(), propogated_tir_vars->values.size());
+        for (int i = 0; i < static_cast<int>(symbolic_params.size()); ++i) {
+          tir_var_remap.Set(Downcast<tir::Var>(symbolic_params[i]), propogated_tir_vars->values[i]);
+        }
+        call_sinfo = Bind(call_sinfo, tir_var_remap);
+      }
+      // Arguments of builtin_run_or_capture
+      Array<Expr> tuple_arg_fields{gv_func, Tuple(args),
+                                   PrimValue(IntImm(DataType::Int(64), index_capture_++))};
+      if (plan.propogated_tir_vars.defined()) {
+        // The shape expr is explicitly passed twice: once as the last argument of the lifted
+        // function, and once as the last argument of builtin_run_or_capture, where it serves as
+        // the cache key. Passing it twice simplifies the handling during the capture phase.
+        tuple_arg_fields.push_back(plan.propogated_tir_vars.value());
+      }
+      launch_subgraph =
+          Call(call_builtin_with_ctx_op, {builtin_run_or_capture, Tuple(tuple_arg_fields)}, Attrs(),
+               {call_sinfo});
     }
     Expr ret_value = builder_->Emit(launch_subgraph);
     for (int i = 0; i < static_cast<int>(plan.outputs.size()); ++i) {
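With the planner and rewriter changes in place, the pass is driven exactly as before; functions without the attribute keep the fully static behavior. A sketch of applying it from C++ (assuming RewriteCUDAGraph() is exposed in tvm/relax/transform.h, as other Relax passes are):

    #include <tvm/ir/module.h>
    #include <tvm/relax/transform.h>

    // Lift static regions (and allow-listed symbolic ones) into cuda_graph_alloc /
    // cuda_graph_capture functions, rewriting call sites to the VM builtins.
    tvm::IRModule ApplyRewrite(tvm::IRModule mod) {
      return tvm::relax::transform::RewriteCUDAGraph()(mod);
    }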
diff --git a/src/relax/utils.cc b/src/relax/utils.cc
index efb2d0220481..cffbafdee2fc 100644
--- a/src/relax/utils.cc
+++ b/src/relax/utils.cc
@@ -144,6 +144,10 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds,
   return ExprBinder(binds, symbolic_var_map).VisitExpr(expr);
 }
 
+StructInfo Bind(const StructInfo& sinfo, const tvm::Map<tir::Var, PrimExpr>& symbolic_var_map) {
+  return ExprBinder({}, symbolic_var_map).VisitExprDepStructInfoField(sinfo);
+}
+
 tvm::Map<tir::Var, PrimExpr> InferSymbolicVarMap(
     const tvm::Map<relax::Var, relax::Expr>& relax_var_remap, arith::Analyzer* analyzer) {
   tvm::Map<tir::Var, PrimExpr> tir_var_remap;
diff --git a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
index f6eef9ca259d..02b6da7dab8d 100644
--- a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
+++ b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
@@ -26,11 +26,45 @@
 #include <...>
 #include <...>
 
+#include "../../../support/utils.h"
 #include "../../cuda/cuda_common.h"
 
 namespace tvm {
 namespace runtime {
 namespace relax_vm {
 
+struct CUDAGraphCaptureKey {
+  // The unique index of the capture function within the module
+  int64_t index;
+  // The symbolic variables the capture function depends on. When the capture function is run
+  // with different symbolic variable values, the CUDA graph will be re-captured as a different
+  // version, identified by this shape tuple. This is default-constructed as an empty tuple.
+  ShapeTuple shape_expr;
+
+  CUDAGraphCaptureKey(int64_t index, const Optional<ShapeTuple>& shape_expr) : index(index) {
+    if (shape_expr) {
+      this->shape_expr = shape_expr.value();
+    }
+  }
+};
+
+struct CUDAGraphCaptureKeyHash {
+  size_t operator()(const CUDAGraphCaptureKey& key) const {
+    std::hash<int64_t> hash_fn;
+    size_t hash = hash_fn(key.index);
+    for (const auto& shape : key.shape_expr) {
+      support::HashCombine(hash, hash_fn(shape));
+    }
+    return hash;
+  }
+};
+
+struct CUDAGraphCaptureKeyEqual {
+  bool operator()(const CUDAGraphCaptureKey& lhs, const CUDAGraphCaptureKey& rhs) const {
+    return lhs.index == rhs.index && std::equal(lhs.shape_expr.begin(), lhs.shape_expr.end(),
+                                                rhs.shape_expr.begin(), rhs.shape_expr.end());
+  }
+};
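These three structs are the standard recipe for an unordered_map with a compound key: the key bundles the function index with the symbolic shape values, with hand-written hash and equality functors. The same scheme in a self-contained, std-only sketch (std::vector<int64_t> stands in for ShapeTuple, and the mixing step follows the usual boost-style combine that support::HashCombine is assumed to perform):

    #include <cstdint>
    #include <functional>
    #include <unordered_map>
    #include <vector>

    struct CaptureKey {
      int64_t index;               // which capture function
      std::vector<int64_t> shape;  // values of the captured symbolic variables
      bool operator==(const CaptureKey& other) const {
        return index == other.index && shape == other.shape;
      }
    };

    struct CaptureKeyHash {
      size_t operator()(const CaptureKey& key) const {
        size_t seed = std::hash<int64_t>()(key.index);
        for (int64_t v : key.shape) {
          // Boost-style hash mixing, analogous to support::HashCombine.
          seed ^= std::hash<int64_t>()(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
        }
        return seed;
      }
    };

    // One captured CUDA graph per (function index, shape) combination:
    // cache[{0, {16}}] and cache[{0, {32}}] are distinct entries.
    using GraphCache = std::unordered_map<CaptureKey, int, CaptureKeyHash>;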
 
 /*! \brief The cache states of a CUDA graph. */
 class CUDAGraphCache : public Object {
  public:
@@ -62,8 +96,9 @@
    * \return The return value of the capture function.
    */
   ObjectRef RunOrCapture(VirtualMachine* vm, const ObjectRef& capture_func, ObjectRef args,
-                         int64_t entry_index) {
-    if (auto it = capture_cache_.find(entry_index); it != capture_cache_.end()) {
+                         int64_t entry_index, Optional<ShapeTuple> shape_expr) {
+    CUDAGraphCaptureKey entry_key{entry_index, shape_expr};
+    if (auto it = capture_cache_.find(entry_key); it != capture_cache_.end()) {
       // Launch CUDA graph
       const auto& [states, exec] = it->second;
       CUDA_CALL(cudaGraphLaunch(exec, CUDAThreadEntry::ThreadLocal()->stream));
@@ -103,8 +138,8 @@ class CUDAGraphCache : public Object {
     CUDA_CALL(cudaStreamEndCapture(CUDAThreadEntry::ThreadLocal()->stream, &graph));
     std::swap(capture_stream, CUDAThreadEntry::ThreadLocal()->stream);
 
-    capture_cache_[entry_index] = entry;
-    CUDA_CALL(cudaGraphInstantiate(&capture_cache_[entry_index].exec, graph, NULL, NULL, 0));
+    capture_cache_[entry_key] = entry;
+    CUDA_CALL(cudaGraphInstantiate(&capture_cache_[entry_key].exec, graph, NULL, NULL, 0));
     CUDA_CALL(cudaStreamDestroy(capture_stream));
     CUDA_CALL(cudaGraphDestroy(graph));
     return entry.states;
@@ -134,7 +169,9 @@ class CUDAGraphCache : public Object {
    * \brief The cache of captured cuda graphs. The key is a unique index for the capture function.
    * The value is the result of the capture.
    */
-  std::unordered_map<int64_t, CUDAGraphCapturedState> capture_cache_;
+  std::unordered_map<CUDAGraphCaptureKey, CUDAGraphCapturedState, CUDAGraphCaptureKeyHash,
+                     CUDAGraphCaptureKeyEqual>
+      capture_cache_;
   /*!
    * \brief The cache of allocations. The key is a unique index for the allocation function.
    * The value is the cached allocations, which is a tuple of storages.
@@ -143,11 +180,18 @@
 };
 
 TVM_REGISTER_GLOBAL("vm.builtin.cuda_graph.run_or_capture")
-    .set_body_typed([](TVMArgValue vm_ptr, ObjectRef capture_func, ObjectRef func_args,
-                       int64_t entry_index) {
-      VirtualMachine* vm = VirtualMachine::GetContextPtr(vm_ptr);
+    .set_body([](TVMArgs args, TVMRetValue* rv) {
+      ICHECK(args.size() == 5 || args.size() == 4);
+      VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]);
+      ObjectRef capture_func = args[1];
+      ObjectRef func_args = args[2];
+      int64_t entry_index = args[3];
+      Optional<ShapeTuple> shape_expr = NullOpt;
+      if (args.size() == 5) {
+        shape_expr = args[4].AsObjectRef<ShapeTuple>();
+      }
       CUDAGraphCache* cache = CUDAGraphCache::Get();
-      return cache->RunOrCapture(vm, capture_func, func_args, entry_index);
+      *rv = cache->RunOrCapture(vm, capture_func, func_args, entry_index, shape_expr);
     });
 
 TVM_REGISTER_GLOBAL("vm.builtin.cuda_graph.get_cached_alloc")
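The builtin now accepts either four arguments (the legacy fully static form) or five (with the ShapeTuple cache key appended), which is why it switches from set_body_typed to the variadic set_body. A sketch of resolving the packed function and the two calling conventions; actually invoking it requires a live VM context, so the calls are only indicated in comments:

    #include <tvm/runtime/logging.h>
    #include <tvm/runtime/registry.h>

    void LookupRunOrCapture() {
      const tvm::runtime::PackedFunc* run_or_capture =
          tvm::runtime::Registry::Get("vm.builtin.cuda_graph.run_or_capture");
      ICHECK(run_or_capture != nullptr);
      // Fully static form:
      //   (*run_or_capture)(vm_ptr, capture_func, func_args, entry_index);
      // Dynamic form, with the shape tuple as the extra cache key:
      //   (*run_or_capture)(vm_ptr, capture_func, func_args, entry_index, shape_tuple);
    }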
diff --git a/tests/python/relax/test_transform_rewrite_cuda_graph.py b/tests/python/relax/test_transform_rewrite_cuda_graph.py
index 91b3fce2640a..43b26f110fa2 100644
--- a/tests/python/relax/test_transform_rewrite_cuda_graph.py
+++ b/tests/python/relax/test_transform_rewrite_cuda_graph.py
@@ -757,5 +757,123 @@ def main() -> R.Tuple:
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_dynamic_capture():
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def add_one(x_handle: T.handle, y_handle: T.handle):
+            m = T.int64()
+            x = T.match_buffer(x_handle, (m,), "float32")
+            y = T.match_buffer(y_handle, (m,), "float32")
+            for i in range(m):
+                with T.block("add"):
+                    vi = T.axis.remap("S", [i])
+                    y[vi] = x[vi] + T.float32(1)
+
+        @R.function
+        def main(x: R.Tensor(("m",), "float32")) -> R.Tensor(("m",), "float32"):
+            R.func_attr(
+                {"relax.rewrite_cuda_graph.capture_symbolic_vars": ["m"], "relax.force_pure": True}
+            )
+            m = T.int64()
+            storage: R.Object = R.memory.alloc_storage(
+                R.shape([16]), 0, "global", "float32"
+            )  # assume m is upper-bounded
+            alloc1: R.Tensor((m,), "float32") = R.memory.alloc_tensor(
+                storage, 0, R.shape([m]), "float32"
+            )
+            _ = Before.add_one(x, alloc1)
+            storage1: R.Object = R.memory.alloc_storage(R.shape([16]), 0, "global", "float32")
+            alloc2: R.Tensor((m,), "float32") = R.memory.alloc_tensor(
+                storage1, 0, R.shape([m]), "float32"
+            )
+            _ = Before.add_one(alloc1, alloc2)
+            alloc3: R.Tensor((m,), "float32") = R.builtin.alloc_tensor(
+                R.shape([m]), "float32", 0, "global"
+            )
+            _ = Before.add_one(alloc2, alloc3)
+            return alloc3
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def add_one(x_handle: T.handle, y_handle: T.handle):
+            m = T.int64()
+            x = T.match_buffer(x_handle, (m,))
+            y = T.match_buffer(y_handle, (m,))
+            # with T.block("root"):
+            for i in range(m):
+                with T.block("add"):
+                    vi = T.axis.spatial(m, i)
+                    T.reads(x[vi])
+                    T.writes(y[vi])
+                    y[vi] = x[vi] + T.float32(1)
+
+        @R.function(private=True)
+        def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object):
+            R.func_attr({"relax.force_pure": True})
+            storage: R.Object = R.memory.alloc_storage(
+                R.shape([16]), R.prim_value(0), R.str("global"), R.dtype("float32")
+            )
+            storage1: R.Object = R.memory.alloc_storage(
+                R.shape([16]), R.prim_value(0), R.str("global"), R.dtype("float32")
+            )
+            gv: R.Tuple(R.Object, R.Object) = storage, storage1
+            return gv
+
+        @R.function(private=True)
+        def cuda_graph_capture(
+            alloc1: R.Tensor(("m",), dtype="float32"),
+            alloc2: R.Tensor(("m",), dtype="float32"),
+            shape_expr: R.Shape(["m"]),
+        ):
+            m = T.int64()
+            R.func_attr({"relax.force_pure": True})
+            cls = Expected
+            cls.add_one(alloc1, alloc2)
+            gv = R.tuple()
+            return R.tuple()
+
+        @R.function
+        def main(x: R.Tensor(("m",), dtype="float32")) -> R.Tensor(("m",), dtype="float32"):
+            m = T.int64()
+            R.func_attr(
+                {"relax.force_pure": True, "relax.rewrite_cuda_graph.capture_symbolic_vars": ["m"]}
+            )
+            cls = Expected
+            gv: R.Tuple(R.Object, R.Object) = R.call_builtin_with_ctx(
+                "vm.builtin.cuda_graph.get_cached_alloc",
+                (cls.cuda_graph_alloc, R.prim_value(0)),
+                sinfo_args=(R.Tuple(R.Object, R.Object),),
+            )
+            storage: R.Object = gv[0]
+            alloc1: R.Tensor((m,), dtype="float32") = R.memory.alloc_tensor(
+                storage, R.prim_value(0), R.shape([m]), R.dtype("float32")
+            )
+            cls.add_one(x, alloc1)
+            storage1: R.Object = gv[1]
+            alloc2: R.Tensor((m,), dtype="float32") = R.memory.alloc_tensor(
+                storage1, R.prim_value(0), R.shape([m]), R.dtype("float32")
+            )
+            R.call_builtin_with_ctx(
+                "vm.builtin.cuda_graph.run_or_capture",
+                (
+                    cls.cuda_graph_capture,
+                    (alloc1, alloc2, R.shape([m])),
+                    R.prim_value(0),
+                    R.shape([m]),
+                ),
+                sinfo_args=(R.Tuple,),
+            )
+            alloc3: R.Tensor((m,), dtype="float32") = R.builtin.alloc_tensor(
+                R.shape([m]), R.dtype("float32"), R.prim_value(0), R.str("global")
+            )
+            cls.add_one(alloc2, alloc3)
+            return alloc3
+
+    mod = relax.transform.RewriteCUDAGraph()(Before)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()