Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/tvm/ir/json_compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def _initialize_virtual_device(item, _):
"relay.RefRead": _initialize_virtual_device,
"relay.RefWrite": _initialize_virtual_device,
"relay.Match": _initialize_virtual_device,
"relay.Constant": _initialize_virtual_device,
}

return create_updater(node_map, "0.8", "0.9")
Expand Down
7 changes: 3 additions & 4 deletions src/relay/backend/contrib/cmsisnn/extract_constants.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ class ExtractConstantsMutator : public MixedModeMutator {
auto new_body = VisitExpr(func->body);
functions_.pop_back();
if (function_to_constants_[func].size()) {
func = Function(FreeVars(new_body), new_body, func->ret_type, FreeTypeVars(new_body, mod_),
func->attrs);
func = WithFields(func, FreeVars(new_body), new_body, func->ret_type,
FreeTypeVars(new_body, mod_), func->attrs);
}
return std::move(func);
}
Expand Down Expand Up @@ -159,8 +159,7 @@ IRModule ExtractConstants(const IRModule& mod) {
auto new_main_body = extract_constants.VisitExpr(main_func->body);
if (!new_main_body.same_as(main_func->body)) {
auto main_var = mod->GetGlobalVar("main");
auto new_main_func = Function(main_func->params, new_main_body, main_func->ret_type,
main_func->type_params, main_func->attrs);
Function new_main_func = WithFields(main_func, main_func->params, new_main_body);
mod->Update(main_var, new_main_func);
}
return mod;
Expand Down
9 changes: 2 additions & 7 deletions src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,8 @@ class RelayToTIRVisitor : public MixedModeMutator {

IRModule Mutate() {
GlobalVar main_global_var = ir_module_->GetGlobalVar("main");
BaseFunc main = ir_module_->Lookup(main_global_var);
Function main_func = GetRef<Function>(main.as<FunctionNode>());

// Copy everything across and mutate the body
Function mutated_main =
Function(main_func->params, VisitExpr(main_func->body), main_func->ret_type,
main_func->type_params, main_func->attrs, main_func->span);
Function main = Downcast<Function>(ir_module_->Lookup(main_global_var));
Function mutated_main = WithFields(main, main->params, VisitExpr(main->body));

ir_module_->Update(main_global_var, mutated_main);

Expand Down
8 changes: 2 additions & 6 deletions src/relay/backend/contrib/ethosu/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,8 @@ class RelayToTIRMutator : public MixedModeMutator {

IRModule operator()() {
GlobalVar main_global_var = ir_module_->GetGlobalVar("main");
Function main_func = Downcast<Function>(ir_module_->Lookup(main_global_var));

// Copy everything across and mutate the body
Function mutated_main =
Function(main_func->params, VisitExpr(main_func->body), main_func->ret_type,
main_func->type_params, main_func->attrs, main_func->span);
Function main = Downcast<Function>(ir_module_->Lookup(main_global_var));
Function mutated_main = WithFields(main, main->params, VisitExpr(main->body));

ir_module_->Update(main_global_var, mutated_main);
ir_module_ = WithAttr(ir_module_, "device_contexts", device_contexts_);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,8 @@ class ConvertAddToSubtract : public MixedModeMutator {

IRModule Mutate() {
GlobalVar main_global_var = ir_module_->GetGlobalVar("main");
BaseFunc main = ir_module_->Lookup(main_global_var);
Function main_func = GetRef<Function>(main.as<FunctionNode>());

// Copy everything across and mutate the body
Function mutated_main =
Function(main_func->params, VisitExpr(main_func->body), main_func->ret_type,
main_func->type_params, main_func->attrs, main_func->span);
Function main = GetRef<Function>(ir_module_->Lookup(main_global_var).as<FunctionNode>());
Function mutated_main = WithFields(main, main->params, VisitExpr(main->body));

ir_module_->Update(main_global_var, mutated_main);

Expand Down
12 changes: 10 additions & 2 deletions src/relay/backend/te_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class TECompilerImpl : public TECompilerNode {
}

IRModule GetLoweredFunctions() {
VLOG(1) << "GetLoweredFunctions";
IRModule mod;
// Extract lowered functions from the cache
for (const auto& it : cache_) {
Expand Down Expand Up @@ -164,8 +165,15 @@ class TECompilerImpl : public TECompilerNode {
for (const auto& kv2 : kv1.second->cached_func->funcs->functions) {
if (const auto* function_node = kv2.second.as<FunctionNode>()) {
// Abandon the existing function annotations.
Function function(function_node->params, function_node->body, function_node->ret_type,
function_node->type_params, /*attrs=*/{}, function_node->span);

// Unfortuantely, Optional<DictAttrs>() is indistinguishable from
// NullValue<DictAttrs>(), and DictAttrs() is nullptr, so to erase the attributes, we
// need pass in DictAttrs<Map<String, ObjectRef>()), which is a DictAttrs containing no
// attributes.
Function function =
WithFields(GetRef<Function>(function_node), function_node->params,
function_node->body, function_node->ret_type, function_node->type_params,
/* erase attributes */ DictAttrs(Map<String, ObjectRef>()));
// Mark function as 'extern' using the "ExternalSymbol" attribute.
function = WithAttr(std::move(function), attr::kExternalSymbol, kv2.first->name_hint);
module->Add(kv2.first, function);
Expand Down
6 changes: 2 additions & 4 deletions src/relay/backend/vm/lambda_lift.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,7 @@ class LambdaLifter : public transform::DeviceAwareExprMutator {

if (function_nesting() == 1) {
// We don't need to lift global functions.
return Function(func_node->params, VisitExpr(func_node->body), func_node->ret_type,
func_node->type_params, func_node->attrs, func_node->span);
return WithFields(GetRef<Function>(func_node), func_node->params, VisitExpr(func_node->body));
}

auto name = GenerateName(func);
Expand Down Expand Up @@ -188,8 +187,7 @@ class LambdaLifter : public transform::DeviceAwareExprMutator {
// construct the "closure" function with fully annotated arguments, no longer relying
// on type inference.
size_t before_arity = body->params.size();
auto rebound_body = Function(func->params, Bind(body->body, rebinding_map), func->ret_type,
func->type_params, func->attrs, func->span);
auto rebound_body = WithFields(func, func->params, Bind(body->body, rebinding_map));
size_t after_arity = rebound_body->params.size();
CHECK_EQ(before_arity, after_arity);
lifted_func =
Expand Down
1 change: 1 addition & 0 deletions src/relay/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ using namespace tvm::runtime;
Constant::Constant(runtime::NDArray data, Span span) {
ObjectPtr<ConstantNode> n = make_object<ConstantNode>();
n->data = std::move(data);
n->virtual_device_ = VirtualDevice::FullyUnconstrained();
n->span = std::move(span);
data_ = std::move(n);
}
Expand Down
2 changes: 1 addition & 1 deletion src/relay/quantize/annotate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ Pass QuantizeAnnotate() {
for (const auto& x : FreeVars(func)) {
new_params.push_back(x);
}
return Function(new_params, func->body, func->ret_type, func->type_params, func->attrs);
return WithFields(func, new_params);
};
return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {});
}
Expand Down
9 changes: 7 additions & 2 deletions src/relay/quantize/calibrate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,13 @@ class StatsCollector : private ExprMutator {
const FunctionNode* func = new_e.as<FunctionNode>();
ICHECK(func) << "Input shoule be Function";
Expr new_body = Tuple(std::move(profile_data_));
return Function(FreeVars(new_body), new_body, NullValue<Type>(), func->type_params,
func->attrs);
Function ret_func = WithFields(GetRef<Function>(func), FreeVars(new_body), new_body);

// We are changing the function's ret_type to an empty type. Unfortunately, Optional<Type>() is
// indistinguishable from NullValue<Type>(), so we can't express "update to nullptr" in
// WithFields.
ret_func.CopyOnWrite()->ret_type = NullValue<Type>();
return ret_func;
}

private:
Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/annotate_target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ class AnnotateTargetRewriter : public ExprRewriter {
func = Downcast<Function>(post);
new_body = InsertCompilerEndAndPropogateTarget(func->body);
}
return Function(func->params, new_body, func->ret_type, func->type_params, func->attrs);
return WithFields(func, func->params, new_body);
}

Expr Rewrite_(const LetNode* op, const Expr& post) override {
Expand Down
4 changes: 2 additions & 2 deletions src/relay/transforms/convert_sparse_conv2d.cc
Original file line number Diff line number Diff line change
Expand Up @@ -292,12 +292,12 @@ Pass Conv2dToSparse(const Array<ObjectRef>& weight_name, const Array<Array<PrimE
auto f0 =
Downcast<Function>(Conv2dToSparse(f, weight_name, weight_shape, layout, kernel_size));
Array<Var> sparse_params = FreeVars(f0);
auto f1 = Function(sparse_params, f0->body, f0->ret_type, f0->type_params, f0->attrs);
auto f1 = WithFields(f0, sparse_params);
Array<Var> params = FreeVars(f1);
for (const auto& var : sparse_params) {
params.push_back(var);
}
return Function(params, f1->body, f1->ret_type, f1->type_params, f1->attrs);
return WithFields(f1, params);
};
return CreateFunctionPass(pass_func, 4, "Conv2dToSparse", {"DeadCodeElimination"});
}
Expand Down
4 changes: 2 additions & 2 deletions src/relay/transforms/convert_sparse_dense.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,12 @@ Pass DenseToSparse(const Array<ObjectRef>& weight_name,
// Remove FreeVar warnings
auto f0 = Downcast<Function>(DenseToSparse(f, weight_name, weight_shape));
Array<Var> sparse_params = FreeVars(f0);
auto f1 = Function(sparse_params, f0->body, f0->ret_type, f0->type_params, f0->attrs);
auto f1 = WithFields(f0, sparse_params);
Array<Var> params = FreeVars(f1);
for (const auto& var : sparse_params) {
params.push_back(var);
}
return Function(params, f1->body, f1->ret_type, f1->type_params, f1->attrs);
return WithFields(f1, params);
};
return CreateFunctionPass(pass_func, 4, "DenseToSparse", {"DeadCodeElimination"});
}
Expand Down
9 changes: 5 additions & 4 deletions src/relay/transforms/de_duplicate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,17 @@ Expr DeDup(const Expr& e) {

Type VisitType(const Type& t) final { return t.defined() ? TypeMutator::VisitType(t) : t; }

Expr VisitExpr_(const FunctionNode* op) final {
Expr VisitExpr_(const FunctionNode* func_node) final {
tvm::Array<TypeVar> type_params;
for (const TypeVar& type_param : op->type_params) {
for (const TypeVar& type_param : func_node->type_params) {
type_params.push_back(Fresh(type_param));
}
tvm::Array<Var> params;
for (const Var& param : op->params) {
for (const Var& param : func_node->params) {
params.push_back(Fresh(param));
}
return Function(params, VisitExpr(op->body), VisitType(op->ret_type), type_params, op->attrs);
return WithFields(GetRef<Function>(func_node), params, VisitExpr(func_node->body),
VisitType(func_node->ret_type), type_params);
}

Pattern VisitPattern(const Pattern& p) final { return PatternFunctor::VisitPattern(p); }
Expand Down
7 changes: 4 additions & 3 deletions src/relay/transforms/defunctionalization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ class DefuncMutator : public ExprMutator {

auto apply_gv = GetApplyFunction(ft);
auto body = this->VisitExpr(Bind(fn->body, free_var_bind_map));
AddApplyCase(apply_gv, ft, c, Function(fn->params, body, fn->ret_type, fn->type_params),
AddApplyCase(apply_gv, ft, c, WithFields(GetRef<Function>(fn), fn->params, body),
pattern_vars);

return Call(c, call_args);
Expand Down Expand Up @@ -380,7 +380,7 @@ class DefuncMutator : public ExprMutator {
map.Set(f->type_params[i], type_args[i]);
}
// copy with typevars removed
auto copy = TypeSubst(Function(f->params, f->body, f->ret_type, {}), map);
auto copy = TypeSubst(WithFields(f, {}, {}, {}, /* erase type params */ Array<TypeVar>()), map);
return Downcast<Function>(copy);
}

Expand Down Expand Up @@ -410,7 +410,8 @@ class DefuncMutator : public ExprMutator {
}

auto bind = Downcast<Function>(Bind(f, var_bind_map));
return Function(params, this->VisitExpr(bind->body), bind->ret_type, {});
return WithFields(bind, params, this->VisitExpr(bind->body), bind->ret_type,
/* erase type params */ Array<TypeVar>());
}
};

Expand Down
3 changes: 1 addition & 2 deletions src/relay/transforms/eta_expand.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,7 @@ class EtaExpander : public ExprMutator {
params.push_back(var);
args.push_back(var);
}

return Function(args, Call(gvar, params), func->ret_type, func->type_params);
return WithFields(func, args, Call(gvar, params));
} else {
return std::move(gvar);
}
Expand Down
5 changes: 3 additions & 2 deletions src/relay/transforms/first_order_gradient.cc
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,9 @@ Pass FirstOrderGradient() {
});
return Pair(res.forward, grad_tuple);
});
ad_mod->Update(pr.first,
Function(func->params, body, GradRetType(GetRef<Function>(func)), {}));
ad_mod->Update(pr.first, WithFields(GetRef<Function>(func), func->params, body,
GradRetType(GetRef<Function>(func)),
/* erase type params */ Array<TypeVar>()));
}

return ad_mod;
Expand Down
19 changes: 10 additions & 9 deletions src/relay/transforms/higher_order_gradient.cc
Original file line number Diff line number Diff line change
Expand Up @@ -341,28 +341,28 @@ struct ReverseAD : ExprMutator {
GlobalVar gv(op->name_hint + "_grad");
(*ad_gvars)[orig_gv] = gv;
Function orig_f = Downcast<Function>(DeDup(mod.value()->Lookup(orig_gv)));
std::vector<Var> params;
Array<Var> params;
for (const auto& p : orig_f->params) {
params.push_back(Downcast<Var>(VisitExpr(p)));
}
params.push_back(bp);
Expr body = VisitExpr(orig_f->body);
Function f(params, body, VisitType(orig_f->ret_type), orig_f->type_params, orig_f->attrs);
Function f = WithFields(orig_f, params, VisitExpr(orig_f->body), VisitType(orig_f->ret_type));
std::cout << "gv " << op->name_hint << ": " << AsText(f, false) << std::endl;
mod.value()->Add(gv, f);
}
return ad_gvars->at(orig_gv);
}

Expr VisitExpr_(const FunctionNode* op) final {
std::vector<Var> params;
for (const auto& var : op->params) {
Expr VisitExpr_(const FunctionNode* func_node) final {
Array<Var> params;
for (const auto& var : func_node->params) {
params.push_back(Downcast<Var>(VisitExpr(var)));
}
auto new_bp = Var("bp", bpt);
params.push_back(new_bp);
return Function(params, ReverseAD(mod, new_bp, ad_vars, ad_gvars)(op->body),
VisitType(op->ret_type), op->type_params, op->attrs);
return WithFields(GetRef<Function>(func_node), params,
ReverseAD(mod, new_bp, ad_vars, ad_gvars)(func_node->body),
VisitType(func_node->ret_type));
}

Type VisitType(const Type& t) final { return t.defined() ? ReverseType(t) : t; }
Expand Down Expand Up @@ -456,7 +456,8 @@ Expr Gradient(const Expr& re, const Optional<IRModule>& mod) {
};
return Pair(get_final_result(c, f->body->checked_type()), Tuple(ret));
});
auto ret = Function(f->params, body, GradRetType(GetRef<Function>(f)), {});
Function ret = WithFields(GetRef<Function>(f), f->params, body, GradRetType(GetRef<Function>(f)),
/* erase type params */ Array<TypeVar>());
CheckFeature(ret, FeatureSet::All() - fGraph);
return std::move(ret);
}
Expand Down
5 changes: 3 additions & 2 deletions src/relay/transforms/inline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,7 @@ class Inliner : ExprMutator {
}

Function Inline(const Function& func) {
return Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params,
func->attrs);
return WithFields(func, func->params, VisitExpr(func->body));
}

private:
Expand Down Expand Up @@ -131,6 +130,8 @@ class Inliner : ExprMutator {
const auto* fn = base_func.as<FunctionNode>();
ICHECK(fn) << "Expected to work on a Relay function.";

// There is an inconsistency here, the function itself gets shallow-copied but the body is not
// shallow-copied.
auto func = Function(fn->params, fn->body, fn->ret_type, fn->type_params, fn->attrs);
// Inline the function body to the caller if this function uses default
// compiler, i.e. no external codegen is needed.
Expand Down
24 changes: 12 additions & 12 deletions src/relay/transforms/partial_eval.cc
Original file line number Diff line number Diff line change
Expand Up @@ -827,18 +827,18 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
Expr VisitFuncDynamic(const Function& func, const Func& f, const Expr& self) {
return store_.Extend<Expr>([&]() {
store_.Invalidate();
return Function(func->params, LetList::With([&](LetList* ll) {
std::vector<PStatic> pv;
for (const auto& v : func->params) {
pv.push_back(NoStatic(v));
}
tvm::Array<Type> type_args;
for (const auto& tp : func->type_params) {
type_args.push_back(tp);
}
return f(HasStatic(MkSFunc(f), self), pv, Attrs(), type_args, ll)->dynamic;
}),
func->ret_type, func->type_params, func->attrs);
return WithFields(
func, func->params, LetList::With([&](LetList* ll) {
std::vector<PStatic> pv;
for (const auto& v : func->params) {
pv.push_back(NoStatic(v));
}
tvm::Array<Type> type_args;
for (const auto& tp : func->type_params) {
type_args.push_back(tp);
}
return f(HasStatic(MkSFunc(f), self), pv, Attrs(), type_args, ll)->dynamic;
}));
});
}

Expand Down
Loading