7 changes: 7 additions & 0 deletions include/tvm/relax/expr.h

@@ -1003,6 +1003,13 @@ constexpr const char* kWorkspaceSize = "WorkspaceSize";
 /*! \brief Override checking purity for this function and treat as pure
  * (is_pure must be set to true) */
 constexpr const char* kForcePure = "relax.force_pure";
+
+/*!
+ * \brief The number of inputs of a function.
+ * If a function has the num_input attribute, the last func->params.size() - num_inputs
+ * arguments are assumed to be weights that are fixed across invocations.
+ */
+constexpr const char* kNumInput = "num_input";
 }  // namespace attr
 
 /*! \brief The extern function, which can represent packed function. */
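For orientation, here is a minimal sketch of how a pass can recover the input/weight split from this attribute. The helper name is ours, not part of this PR; the default (treat every parameter as a real input when the attribute is absent) mirrors how the passes below behave.

#include <cstddef>
#include <tvm/relax/expr.h>

using namespace tvm;
using namespace tvm::relax;

// Returns how many leading params are real inputs; the rest are weights.
size_t GetNumInput(const Function& func) {
  size_t num_input = func->params.size();  // default: no trailing weights
  if (auto opt = func->attrs.GetAttr<Integer>(attr::kNumInput)) {
    num_input = static_cast<size_t>(opt.value()->value);
  }
  return num_input;  // params[num_input..] stay fixed across invocations
}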
62 changes: 42 additions & 20 deletions src/relax/backend/vm/vm_shape_lower.cc

@@ -198,7 +198,7 @@ class PrimExprSlotCollector : public ExprVisitor, public StructInfoVisitor {
  */
 class VMShapeLowerMutator
     : public ExprMutator,
-      public StructInfoFunctor<void(const StructInfo&, Expr, bool, const String&,
+      public StructInfoFunctor<void(const StructInfo&, Expr, bool, bool, const String&,
                                std::vector<MatchShapeTodoItem>*)> {
  public:
  static IRModule Lower(IRModule mod, bool emit_err_ctx) {
@@ -241,12 +241,19 @@ class VMShapeLowerMutator
     builder_->BeginBindingBlock();
     this->builder_->EmitNormalized(shape_heap_binding);
     std::vector<MatchShapeTodoItem> match_todos;
+    size_t num_input = func->params.size();
+    if (auto opt_num_input = func->attrs.GetAttr<Integer>(attr::kNumInput)) {
+      // If the function has the attribute 'num_input', do shape checking only for the real inputs
+      // and skip weights.
+      num_input = static_cast<size_t>(opt_num_input.value()->value);
+    }
     for (size_t i = 0; i < func->params.size(); ++i) {
       StructInfo sinfo = GetStructInfo(func->params[i]);
       std::ostringstream err_ctx;
       err_ctx << "ErrorContext(fn=" << gvar->name_hint << ", loc=param[" << i
               << "], param=" << func->params[i]->name_hint() << ", annotation=" << sinfo << ") ";
-      this->CheckMatchCast(sinfo, func->params[i], true, err_ctx.str(), &match_todos);
+      this->CheckMatchCast(sinfo, func->params[i], true, i >= num_input, err_ctx.str(),
+                           &match_todos);
     }
     // insert heap generation logic.
     match_todos = this->RunMatch(match_todos, false);
@@ -269,7 +276,7 @@ class VMShapeLowerMutator
             << ", loc=return, annotation=" << func->ret_struct_info << ") ";
     std::vector<MatchShapeTodoItem> match_todos;
     // NOTE: the return value's shape computation must already be defined.
-    this->CheckMatchCast(func->ret_struct_info, body_seq->body, false, err_ctx.str(),
+    this->CheckMatchCast(func->ret_struct_info, body_seq->body, false, false, err_ctx.str(),
                          &match_todos);
     // NOTE: the return value's shape computation must already be defined.
     this->RunMatch(match_todos, true);
@@ -377,7 +384,7 @@ class VMShapeLowerMutator
     std::ostringstream err_ctx;
     err_ctx << "ErrorContext(match_cast, struct_info=" << binding->struct_info << ") ";
     // always_check=false
-    this->CheckMatchCast(binding->struct_info, value, false, err_ctx.str(), &match_todos);
+    this->CheckMatchCast(binding->struct_info, value, false, false, err_ctx.str(), &match_todos);
 
     match_todos = this->RunMatch(match_todos, false);
     this->EmitOutstandingPrimExprCompute();
@@ -556,37 +563,42 @@ class VMShapeLowerMutator
    * \param always_check Whether we insert runtime check even if we can prove
    *        that value's struct info already satisfies the condition.
    *        This option is necessary for argument checking per our calling convention.
+   *
+   * \param dynamic_only Whether we only check values with dynamic shapes.
    * \param err_ctx Extra error context to bring more informative error reporting.
    * \param match_todos List of match shape todo items collected when recursively
    *        visit the match cast.
    */
   void CheckMatchCast(const StructInfo& struct_info, Expr value, bool always_check,
-                      const String& err_ctx, std::vector<MatchShapeTodoItem>* match_todos) {
-    return this->VisitStructInfo(struct_info, value, always_check, err_ctx, match_todos);
+                      bool dynamic_only, const String& err_ctx,
+                      std::vector<MatchShapeTodoItem>* match_todos) {
+    return this->VisitStructInfo(struct_info, value, always_check, dynamic_only, err_ctx,
+                                 match_todos);
   }
 
   void VisitStructInfo(const StructInfo& struct_info, Expr value, bool always_check,
-                       const String& err_ctx, std::vector<MatchShapeTodoItem>* match_todos) final {
+                       bool dynamic_only, const String& err_ctx,
+                       std::vector<MatchShapeTodoItem>* match_todos) final {
     // short-cut, if the struct info already satisfies the
     // constraint during match cast, we can skip matching
     if (!always_check && IsBaseOf(struct_info, GetStructInfo(value))) return;
-    return StructInfoFunctor::VisitStructInfo(struct_info, value, always_check, err_ctx,
-                                              match_todos);
+    return StructInfoFunctor::VisitStructInfo(struct_info, value, always_check, dynamic_only,
+                                              err_ctx, match_todos);
   }
 
   void VisitStructInfo_(const ObjectStructInfoNode* op, Expr value, bool always_check,
-                        const String& err_ctx, std::vector<MatchShapeTodoItem>* match_todos) final {
-  }
+                        bool dynamic_only, const String& err_ctx,
+                        std::vector<MatchShapeTodoItem>* match_todos) final {}
 
   void VisitStructInfo_(const PrimStructInfoNode* op, Expr value, bool always_check,
-                        const String& err_ctx, std::vector<MatchShapeTodoItem>* match_todos) final {
+                        bool dynamic_only, const String& err_ctx,
+                        std::vector<MatchShapeTodoItem>* match_todos) final {
     // TODO(relax-team) add PrimValue checks later.
     LOG(FATAL) << "MatchCast of PrimValue is not yet supported";
   }
 
   void VisitStructInfo_(const ShapeStructInfoNode* op, Expr value, bool always_check,
-                        const String& err_ctx, std::vector<MatchShapeTodoItem>* match_todos) final {
+                        bool dynamic_only, const String& err_ctx,
+                        std::vector<MatchShapeTodoItem>* match_todos) final {
     // emit runtime check of shape
     if (always_check || !IsBaseOf(ShapeStructInfo(op->ndim), GetStructInfo(value))) {
       // check_shape_info(value, ndim, err_ctx)
@@ -605,8 +617,16 @@
   }
 
   void VisitStructInfo_(const TensorStructInfoNode* op, Expr value, bool always_check,
-                        const String& err_ctx, std::vector<MatchShapeTodoItem>* match_todos) final {
+                        bool dynamic_only, const String& err_ctx,
+                        std::vector<MatchShapeTodoItem>* match_todos) final {
     // emit runtime check of shape
+    auto* shape_expr = op->shape.as<ShapeExprNode>();
+    if (dynamic_only &&
+        std::all_of(shape_expr->values.begin(), shape_expr->values.end(),
+                    [](const PrimExpr& e) { return e->IsInstance<IntImmNode>(); })) {
+      // if we only check dynamic shapes, and the shape is static, we can skip.
+      return;
+    }
     if (always_check || !IsBaseOf(TensorStructInfo(op->dtype, op->ndim), GetStructInfo(value))) {
       // check_tensor_info(value, ndim, dtype, err_ctx)
       Call call(builtin_check_tensor_info_,
@@ -615,7 +635,7 @@
       builder_->Emit(call, "_");
     }
 
-    if (auto* shape_expr = op->shape.as<ShapeExprNode>()) {
+    if (shape_expr != nullptr) {
       MatchShapeTodoItem item;
       item.input = value;
       item.pattern = shape_expr->values;
@@ -648,7 +668,8 @@
   }
 
   void VisitStructInfo_(const TupleStructInfoNode* op, Expr value, bool always_check,
-                        const String& err_ctx, std::vector<MatchShapeTodoItem>* match_todos) final {
+                        bool dynamic_only, const String& err_ctx,
+                        std::vector<MatchShapeTodoItem>* match_todos) final {
     auto* value_tinfo = GetStructInfoAs<TupleStructInfoNode>(value);
     if (value_tinfo) {
       CHECK_EQ(value_tinfo->fields.size(), op->fields.size())
@@ -664,13 +685,14 @@
     }
     // recursively visit each sub-field and run matching
     for (size_t i = 0; i < op->fields.size(); ++i) {
-      this->VisitStructInfo(op->fields[i], MakeTupleGetItem(value, i), always_check, err_ctx,
-                            match_todos);
+      this->VisitStructInfo(op->fields[i], MakeTupleGetItem(value, i), always_check, dynamic_only,
+                            err_ctx, match_todos);
     }
   }
 
   void VisitStructInfo_(const FuncStructInfoNode* op, Expr value, bool always_check,
-                        const String& err_ctx, std::vector<MatchShapeTodoItem>* match_todos) final {
+                        bool dynamic_only, const String& err_ctx,
+                        std::vector<MatchShapeTodoItem>* match_todos) final {
     // we only check function is callable.
     if (!always_check && MatchStructInfo<FuncStructInfo>(value)) return;
     // check_func_info(value, err_ctx)
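To make the new dynamic_only short-circuit concrete: the skip fires exactly when every dimension of the annotated shape is a compile-time constant. A standalone sketch of that predicate follows (the helper name is ours, not the PR's):

#include <algorithm>
#include <tvm/ir/expr.h>
#include <tvm/runtime/container/array.h>

using namespace tvm;

// (1, 128, 128) is fully static: every PrimExpr is an IntImm, so a weight's
// shape check can be skipped. (n, 128) is dynamic, so the check must stay.
bool IsFullyStatic(const Array<PrimExpr>& shape) {
  return std::all_of(shape.begin(), shape.end(),
                     [](const PrimExpr& e) { return e->IsInstance<IntImmNode>(); });
}

Since weights are visited with dynamic_only=true (the i >= num_input argument above) while real inputs keep always_check=true, fully static weight shapes no longer pay a per-call check_tensor_info.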
6 changes: 2 additions & 4 deletions src/relax/transform/bundle_model_params.cc

@@ -33,15 +33,13 @@
 namespace tvm {
 namespace relax {
 
-static const auto kAttrNumInput = "num_input";
-
 class ModelParamBundler : public ExprMutator {
  public:
  ModelParamBundler() {}
 
  Expr VisitExpr_(const FunctionNode* op) override {
    Function func = GetRef<Function>(op);
-    auto opt_num_input = func->attrs.GetAttr<Integer>(kAttrNumInput);
+    auto opt_num_input = func->attrs.GetAttr<Integer>(attr::kNumInput);
    if (!opt_num_input) return func;
    auto signed_num_input = opt_num_input.value()->value;
@@ -68,7 +66,7 @@ class ModelParamBundler : public ExprMutator {
       var_to_expr_.Set(func->params[i], TupleGetItem(var_param_tuple, i - num_input));
     }
 
-    func = WithoutAttr(func, kAttrNumInput);
+    func = WithoutAttr(func, attr::kNumInput);
     func.CopyOnWrite()->params = params;
 
     return ExprMutator::VisitExpr_(func.get());
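As a toy illustration of the bundling effect (names invented; the real pass rewrites Relax functions, not strings): with num_input = 1, a function over (x, w0, w1) becomes a function over (x, model_params), and weight i is read from tuple slot i - num_input, matching the TupleGetItem remap above.

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

int main() {
  const std::vector<std::string> params = {"x", "w0", "w1"};
  const std::size_t num_input = 1;  // value of the "num_input" attribute

  // Weights keep their order but move into one packed tuple parameter.
  for (std::size_t i = num_input; i < params.size(); ++i) {
    std::cout << params[i] << " -> model_params[" << (i - num_input) << "]\n";
  }
  // prints: w0 -> model_params[0]
  //         w1 -> model_params[1]
}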
3 changes: 1 addition & 2 deletions src/relax/transform/lift_transform_params.cc

@@ -234,7 +234,7 @@ class TransformParamsLifter : ExprMutator {
  private:
  Expr VisitExpr_(const FunctionNode* op) override {
    auto func = GetRef<Function>(op);
-    Optional<Integer> opt_num_input = func->attrs.GetAttr<Integer>(attr_num_input_);
+    Optional<Integer> opt_num_input = func->attrs.GetAttr<Integer>(attr::kNumInput);
    if (!opt_num_input) {
      return func;
    }
@@ -300,7 +300,6 @@
     return VisitExpr_(static_cast<const VarNode*>(var));
   }
 
-  const char* attr_num_input_ = "num_input";
   // Remap the original parameters to TupleGetItem from the packed tuple of transformed parameters.
   std::unordered_map<Var, Expr, ObjectPtrHash, ObjectPtrEqual> param_remap_;
   // The plan of lifting the transform params
3 changes: 1 addition & 2 deletions src/relax/transform/rewrite_cuda_graph.cc

@@ -153,9 +153,8 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
     if (pair.second->IsInstance<FunctionNode>()) {
       // If a function has the num_input attribute, the last func->params.size() - num_inputs
       // inputs are assumed to be fixed and thus they can be captured into a cuda graph.
-      static const char* attr_num_input = "num_input";
       const auto& func = Downcast<Function>(pair.second);
-      if (auto num_input = func->attrs.GetAttr<Integer>(attr_num_input)) {
+      if (auto num_input = func->attrs.GetAttr<Integer>(attr::kNumInput)) {
        for (size_t i = num_input.value().IntValue(); i < func->params.size(); ++i) {
          static_vars_.insert(func->params[i].get());
        }
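Finally, a hedged sketch of how a frontend might mark a module's entry function so the planner above treats its trailing weights as fixed and thus graph-capturable (the function and variable names are illustrative, not from this PR):

#include <tvm/ir/function.h>
#include <tvm/ir/module.h>
#include <tvm/relax/expr.h>

using namespace tvm;
using namespace tvm::relax;

// Mark `main` so all params after the first are treated as fixed weights.
IRModule MarkMainWeights(IRModule mod) {
  GlobalVar gv = mod->GetGlobalVar("main");
  auto func = Downcast<Function>(mod->Lookup(gv));
  mod->Update(gv, WithAttr(std::move(func), attr::kNumInput, Integer(1)));
  return mod;
}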