7 changes: 7 additions & 0 deletions include/tvm/relax/utils.h
@@ -50,6 +50,13 @@ namespace relax {
TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds,
const tvm::Map<tir::Var, PrimExpr>& symbolic_var_map = {});

/*!
 * \brief Bind the symbolic variables in a StructInfo to the given PrimExprs. This is a helper
 * usually called by other passes to facilitate optimizations.
 */
TVM_DLL StructInfo Bind(const StructInfo& sinfo,
const tvm::Map<tir::Var, PrimExpr>& symbolic_var_map);

/*!
* \brief Infer a binding map for symbolic variables
*
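To make the new `Bind` overload concrete, here is a minimal usage sketch (not part of this diff; the variable `n` and the `(n, 128)` shape are illustrative). It substitutes a symbolic dimension inside a `TensorStructInfo` with a concrete extent:

```cpp
#include <tvm/relax/expr.h>
#include <tvm/relax/struct_info.h>
#include <tvm/relax/utils.h>
#include <tvm/tir/var.h>

using namespace tvm;
using namespace tvm::relax;

StructInfo BindExample() {
  // A tensor of shape (n, 128), where n is a symbolic TIR variable.
  tir::Var n("n", DataType::Int(64));
  TensorStructInfo sinfo(ShapeExpr({n, IntImm(DataType::Int(64), 128)}), DataType::Float(32));
  // Substitute n -> 16 everywhere it appears inside the StructInfo.
  Map<tir::Var, PrimExpr> symbolic_var_map;
  symbolic_var_map.Set(n, IntImm(DataType::Int(64), 16));
  return Bind(sinfo, symbolic_var_map);  // now describes a (16, 128) tensor
}
```

The overload reuses ExprBinder's struct-info visitor (see src/relax/utils.cc below), so symbolic variables nested in tuple or shape struct info are substituted uniformly.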
161 changes: 139 additions & 22 deletions src/relax/transform/rewrite_cuda_graph.cc
@@ -53,6 +53,8 @@
#include <tvm/relax/backend.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/tir/stmt_functor.h>

#include "../../support/arena.h"
#include "../../support/ordered_set.h"
@@ -82,6 +84,8 @@ struct LiftedFunctionRewritePlan {
std::vector<const VarNode*> outputs;
// The corresponding binding vars in the original function of the inputs of the lifted function
std::vector<const VarNode*> inputs;
// The tir vars in the original function that are propagated to the lifted function
Optional<ShapeExpr> propogated_tir_vars = NullOpt;
};

/*! \brief Builder of the lifted function for cuda graph capturing or allocations */
@@ -98,6 +102,11 @@ class FuncBuilder : public ExprMutator {
* \param var The variable to mark as input
*/
void MarkInput(const VarNode* var) { inputs_.push_back(var); }
/*!
* \brief Mark a TIR variable as the ShapeExpr input of the new function.
* \param var The variable to mark as input
*/
void MarkShapeExprInput(const tir::VarNode* var) { shape_expr_inputs_.push_back(var); }
/*!
* \brief Mark a variable as the output of the new function. The variable must be the LHS of an
* existing binding in the new function.
@@ -111,12 +120,27 @@
/*! \brief Build the new function */
Function Build() {
Array<Var> params;
Optional<Var> shape_expr = NullOpt;
if (shape_expr_inputs_.size()) {
Array<PrimExpr> tir_vars;
for (const auto* var : shape_expr_inputs_) {
auto new_var = GetRef<tir::Var>(var).copy_with_suffix("");
tir_var_remap_.Set(GetRef<tir::Var>(var), new_var);
tir_vars.push_back(new_var);
}
shape_expr = Var("shape_expr", ShapeStructInfo(tir_vars));
}
// Set up the parameters
for (const auto* input : inputs_) {
auto new_var = Var(input->name_hint(), Downcast<Optional<StructInfo>>(input->struct_info_));
auto new_var = Var(
input->name_hint(),
VisitExprDepStructInfoField(Downcast<Optional<StructInfo>>(input->struct_info_).value()));
var_remap_[input->vid] = new_var;
params.push_back(new_var);
}
if (shape_expr) {
params.push_back(shape_expr.value());
}
// Emit the function body
builder_->BeginBindingBlock();
for (const auto* binding : bindings_) {
@@ -137,9 +161,13 @@
return func;
}

PrimExpr VisitPrimExpr(const PrimExpr& expr) { return tir::Substitute(expr, tir_var_remap_); }

support::OrderedSet<const VarNode*> inputs_;
support::OrderedSet<const VarNode*> outputs_;
support::OrderedSet<const tir::VarNode*> shape_expr_inputs_;
std::vector<const VarBindingNode*> bindings_;
Map<tir::Var, PrimExpr> tir_var_remap_;
};

/*!
@@ -159,6 +187,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
static_vars_.insert(func->params[i].get());
}
}
CollectSymbolicVarHints(func);
VisitExpr(func);
}
}
@@ -174,6 +203,13 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
for (const auto* binding : region->bindings_) {
plan.lifted_bindings.insert(binding->var.get());
}
if (region->shape_expr_inputs_.size()) {
Array<PrimExpr> tir_vars;
for (const auto* var : region->shape_expr_inputs_) {
tir_vars.push_back(GetRef<PrimExpr>(var));
}
plan.propogated_tir_vars = ShapeExpr(tir_vars);
}
plan.inputs.assign(region->inputs_.begin(), region->inputs_.end());
plan.outputs.assign(region->outputs_.begin(), region->outputs_.end());
return plan;
@@ -189,6 +225,18 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
return plans;
}

/*!
* \brief Collect the name hints of the symbolic variables that are allowed to be captured.
*/
void CollectSymbolicVarHints(const Function& func) {
capture_symbolic_vars_.clear();
if (auto symbolic_vars =
func->attrs.GetAttr<Array<String>>("relax.rewrite_cuda_graph.capture_symbolic_vars")) {
for (const auto& var : symbolic_vars.value()) {
capture_symbolic_vars_.insert(var);
}
}
}
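As a usage sketch (not part of this diff; the helper name and the variable name "n" are hypothetical), a function opts in to symbolic-shape capture by carrying the attribute read above:

```cpp
#include <tvm/ir/function.h>
#include <tvm/relax/expr.h>

// Attach the attribute that CollectSymbolicVarHints reads, allowing the
// symbolic variable "n" to be captured by the CUDA graph rewrite.
tvm::relax::Function MarkCapturable(tvm::relax::Function func) {
  return tvm::WithAttr(func, "relax.rewrite_cuda_graph.capture_symbolic_vars",
                       tvm::Array<tvm::String>{"n"});
}
```

Only variables named in this attribute are treated as static by IsStatic below; any other symbolic shape still ends the capture region.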
/*!
* \brief Start a new static region. This method should be called when encountering a
* CUDA kernel launch (calls to PrimFunc or ExternFunc) that only depends on static parameters.
@@ -239,8 +287,9 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
// Check whether the call can be lifted to the capture function. It requires all the arguments
// to be static and the call to be a kernel launch or a pure operation (e.g. memory view).
std::vector<const VarNode*> args;
std::vector<const tir::VarNode*> tir_vars;
bool is_all_static = [&]() {
if (!IsStatic(call->args, &args)) {
if (!IsStatic(call->args, &args, &tir_vars)) {
return false;
}
if (call_gv != nullptr && !call_prim_func) {
@@ -276,15 +325,16 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
StartRegion();
}
AddStaticBinding(binding, /*is_alloc_storage=*/false);
MarkAsFuncInput(args);
MarkAsFuncInput(args, tir_vars);
} else {
EndRegion();
}

MarkAsFuncOutput(args);
}

void MarkAsFuncInput(const std::vector<const VarNode*>& vars) {
void MarkAsFuncInput(const std::vector<const VarNode*>& vars,
const std::vector<const tir::VarNode*>& tir_vars = {}) {
if (current_.capture_builder == nullptr) {
return;
}
@@ -294,6 +344,9 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
current_.capture_builder->MarkInput(var);
}
}
for (const tir::VarNode* tir_var : tir_vars) {
current_.capture_builder->MarkShapeExprInput(tir_var);
}
}

void MarkAsFuncOutput(const std::vector<const VarNode*>& vars) {
@@ -321,9 +374,10 @@ class CUDAGraphRewritePlanner : public ExprVisitor {

void VisitBinding_(const VarBindingNode* binding, const TupleNode* tuple) final {
std::vector<const VarNode*> args;
if (IsStatic(tuple->fields, &args)) {
std::vector<const tir::VarNode*> tir_vars;
if (IsStatic(tuple->fields, &args, &tir_vars)) {
AddStaticBinding(binding, false);
MarkAsFuncInput(args);
MarkAsFuncInput(args, tir_vars);
} else {
EndRegion();
}
@@ -343,48 +397,83 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
}

bool IsStatic(const PrimExpr& expr,
[[maybe_unused]] std::vector<const VarNode*>* vars_collector = nullptr) {
return expr->IsInstance<tir::IntImmNode>() || expr->IsInstance<tir::FloatImmNode>();
[[maybe_unused]] std::vector<const VarNode*>* vars_collector = nullptr,
std::vector<const tir::VarNode*>* tir_vars_collector = nullptr) {
bool is_static = true;
tir::PostOrderVisit(expr, [&](const ObjectRef& e) {
if (auto var = e.as<tir::VarNode>()) {
if (!capture_symbolic_vars_.count(var->name_hint)) {
is_static = false;
return;
}
if (tir_vars_collector != nullptr) {
tir_vars_collector->push_back(var);
}
}
});
return is_static;
}

bool IsStatic(const Expr& expr, std::vector<const VarNode*>* vars_collector = nullptr) {
bool IsStatic(const Expr& expr, std::vector<const VarNode*>* vars_collector = nullptr,
std::vector<const tir::VarNode*>* tir_vars_collector = nullptr) {
if (expr->IsInstance<ConstantNode>() || expr->IsInstance<DataTypeImmNode>() ||
expr->IsInstance<StringImmNode>()) {
expr->IsInstance<StringImmNode>() || expr->IsInstance<GlobalVarNode>()) {
return true;
}
if (const auto* prim_value = expr.as<PrimValueNode>()) {
return IsStatic(prim_value->value, vars_collector);
return IsStatic(prim_value->value, vars_collector, tir_vars_collector);
}
if (const auto* var = expr.as<VarNode>()) {
if (vars_collector != nullptr) {
vars_collector->push_back(var);
}
return static_vars_.count(var);
// recursively check the struct info to collect the symbolic TIR vars
return static_vars_.count(var) && IsStatic(Downcast<StructInfo>(var->struct_info_.value()),
vars_collector, tir_vars_collector);
}

if (const auto* shape = expr.as<ShapeExprNode>()) {
return IsStatic(shape->values, vars_collector);
return IsStatic(shape->values, vars_collector, tir_vars_collector);
}
if (const auto* tuple = expr.as<TupleNode>()) {
return IsStatic(tuple->fields, vars_collector);
return IsStatic(tuple->fields, vars_collector, tir_vars_collector);
}
return false;
}

template <typename T>
bool IsStatic(const Array<T>& exprs, std::vector<const VarNode*>* vars_collector = nullptr) {
bool IsStatic(const Array<T>& exprs, std::vector<const VarNode*>* vars_collector = nullptr,
std::vector<const tir::VarNode*>* tir_vars_collector = nullptr) {
bool result = true;
for (const auto& expr : exprs) {
// If either collector is provided, visit every expr so that all vars are collected;
// do not short-circuit on the first non-static expr.
result &= IsStatic(expr, vars_collector);
if (!vars_collector && !result) {
result &= IsStatic(expr, vars_collector, tir_vars_collector);
if (vars_collector == nullptr && tir_vars_collector == nullptr && !result) {
return false;
}
}
return result;
}

bool IsStatic(const StructInfo& sinfo, std::vector<const VarNode*>* vars_collector = nullptr,
std::vector<const tir::VarNode*>* tir_vars_collector = nullptr) {
@tqchen (Member) commented on Mar 29, 2024:

(refactor nit, no need in this PR) consider using tir::Var with reference instead to avoid var de-allocation during rewrite

if (const auto* tensor_sinfo = sinfo.as<TensorStructInfoNode>()) {
if (auto shape = tensor_sinfo->GetShape()) {
return IsStatic(shape.value(), vars_collector, tir_vars_collector);
}
} else if (const auto* shape_sinfo = sinfo.as<ShapeStructInfoNode>()) {
if (shape_sinfo->values) {
return IsStatic(shape_sinfo->values.value(), vars_collector, tir_vars_collector);
}
} else if (const auto* tuple_sinfo = sinfo.as<TupleStructInfoNode>()) {
return IsStatic(tuple_sinfo->fields, vars_collector, tir_vars_collector);
} else if (sinfo.as<ObjectStructInfoNode>() || sinfo.as<PrimStructInfoNode>()) {
return true;
}
return false;
}

private:
bool IsStaticAllocStorage(const VarBindingNode* binding) {
// Check if the allocation has constant shape
@@ -431,6 +520,8 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
Scope current_;
// Variables whose buffer address is fixed
std::unordered_set<const VarNode*> static_vars_;
// The name of the variables that are allowed to be symbolic
std::unordered_set<String> capture_symbolic_vars_;
// Binding to the FuncBuilder if the binding is lifted. This is used to update the inputs/outputs
// of the lifted function when its binding is used outside.
std::unordered_map<const VarNode*, FuncBuilder*> binding_to_region_;
@@ -475,22 +566,48 @@ class CUDAGraphRewriter : public ExprMutator {
auto gv_func =
builder_->AddFunction(plan.func, plan.is_alloc ? "cuda_graph_alloc" : "cuda_graph_capture");
if (plan.is_alloc) {
// Storage allocation should be fully static and shouldn't depend on any symbolic variables.
ICHECK(!plan.propogated_tir_vars.defined());
ICHECK(plan.inputs.empty());
launch_subgraph =
Call(call_builtin_with_ctx_op,
{builtin_get_cached_alloc,
Tuple({gv_func, PrimValue(IntImm(DataType::Int(64), index_alloc_++))})},
Attrs(), {plan.func->ret_struct_info});
} else {
StructInfo call_sinfo = plan.func->ret_struct_info;
// Arguments of the lifted function
Array<Expr> args;
for (const auto& arg : plan.inputs) {
args.push_back(VisitExpr_(arg));
}
launch_subgraph = Call(
call_builtin_with_ctx_op,
{builtin_run_or_capture,
Tuple({gv_func, Tuple(args), PrimValue(IntImm(DataType::Int(64), index_capture_++))})},
Attrs(), {plan.func->ret_struct_info});
if (plan.propogated_tir_vars.defined()) {
ShapeExpr propogated_tir_vars = plan.propogated_tir_vars.value();
args.push_back(propogated_tir_vars);
// The ret_struct_info of the lifted function can contain symbolic variables. We need to
// bind the symbolic parameters to the actual values.
const auto& shape_expr = plan.func->params.back();
auto symbolic_params =
Downcast<ShapeStructInfo>(shape_expr->struct_info_.value())->values.value();
Map<tir::Var, PrimExpr> tir_var_remap;
ICHECK_EQ(symbolic_params.size(), propogated_tir_vars->values.size());
for (int i = 0; i < static_cast<int>(symbolic_params.size()); ++i) {
tir_var_remap.Set(Downcast<tir::Var>(symbolic_params[i]), propogated_tir_vars->values[i]);
}
call_sinfo = Bind(call_sinfo, tir_var_remap);
}
// Arguments of builtin_run_or_capture
Array<Expr> tuple_arg_fields{gv_func, Tuple(args),
PrimValue(IntImm(DataType::Int(64), index_capture_++))};
if (plan.propogated_tir_vars.defined()) {
// The shape expr is passed twice: once as the last argument of the lifted function, and
// once as the last argument of builtin_run_or_capture, where it serves as the cache key.
// Passing it explicitly in both places simplifies handling during the capture phase.
tuple_arg_fields.push_back(plan.propogated_tir_vars.value());
}
launch_subgraph =
Call(call_builtin_with_ctx_op, {builtin_run_or_capture, Tuple(tuple_arg_fields)}, Attrs(),
{call_sinfo});
}
Expr ret_value = builder_->Emit(launch_subgraph);
for (int i = 0; i < static_cast<int>(plan.outputs.size()); ++i) {
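For context, a minimal sketch of running the pass over a module whose functions carry the attribute above (not part of this diff; it assumes RewriteCUDAGraph is declared in tvm/relax/transform.h):

```cpp
#include <tvm/ir/module.h>
#include <tvm/relax/transform.h>

// Rewrite eligible static regions, now including regions that depend on
// captured symbolic shapes, into cuda_graph_alloc / cuda_graph_capture calls.
tvm::IRModule ApplyCudaGraphRewrite(tvm::IRModule mod) {
  tvm::transform::Pass pass = tvm::relax::transform::RewriteCUDAGraph();
  return pass(mod);
}
```

When a lifted region uses a captured symbolic variable, the lifted function gains a trailing shape parameter (a Var with ShapeStructInfo), and the same ShapeExpr is also passed to builtin_run_or_capture as the cache key, as noted in the comments above.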
4 changes: 4 additions & 0 deletions src/relax/utils.cc
@@ -144,6 +144,10 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds,
return ExprBinder(binds, symbolic_var_map).VisitExpr(expr);
}

StructInfo Bind(const StructInfo& sinfo, const tvm::Map<tir::Var, PrimExpr>& symbolic_var_map) {
return ExprBinder({}, symbolic_var_map).VisitExprDepStructInfoField(sinfo);
}

tvm::Map<tir::Var, PrimExpr> InferSymbolicVarMap(
const tvm::Map<relax::Var, relax::Expr>& relax_var_remap, arith::Analyzer* analyzer) {
tvm::Map<tir::Var, PrimExpr> tir_var_remap;