Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 29 additions & 10 deletions src/relax/transform/lift_transform_params.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,10 @@ struct BaseCollectInfo {
Function func(params, body, GetStructInfo(tuple_var));
func = WithAttr(func, attr::kNumInput, Integer(0));
func = CopyWithNewVars(func);
func = BundleModelParams(func);
func = Downcast<Function>(CanonicalizeBindings(func));
func = Downcast<Function>(RemoveAllUnused(func));

return func;
}
};
Expand Down Expand Up @@ -725,11 +728,12 @@ std::vector<std::pair<GlobalVar, Function>> GetTargetFunctions(
target_functions.push_back({gvar.value(), func.value()});
}
} else {
// Get all the functions that have the `num_input` attribute.
// Get all the functions that have the `num_input` attribute, and
// are not already the result of `LiftTransformParams`.
for (const auto& [gvar, func] : mod->functions) {
if (func->IsInstance<FunctionNode>()) {
auto opt_num_input = func->GetAttr<Integer>(attr::kNumInput);
if (opt_num_input) {
if (opt_num_input && !ends_with(gvar->name_hint, "transform_params")) {
target_functions.emplace_back(gvar, Downcast<Function>(func));
}
}
Expand All @@ -748,7 +752,6 @@ namespace transform {

Pass PartitionTransformParams(Variant<Bool, Array<String>> shared_transform) {
auto pass_func = [=](IRModule mod, PassContext pc) {
IRModule updates;
std::optional<GlobalCollectInfo> global_collect_info;

CHECK(shared_transform.defined()) << "shared_transform is not defined";
Expand All @@ -772,24 +775,41 @@ Pass PartitionTransformParams(Variant<Bool, Array<String>> shared_transform) {
local_collect_info[gvar] = info;
}

IRModule updated_runtime_functions;

for (const auto& [gvar, info] : local_collect_info) {
auto new_runtime_func = info.MakeRuntimeFunction();
updates->Add(gvar, new_runtime_func);
updated_runtime_functions->Add(gvar, new_runtime_func);
}

Map<String, Function> lifted_transform_functions;
if (global_collect_info.has_value()) {
auto global_transform = global_collect_info.value().MakeCompileTimeFunc();
updates->Add(GlobalVar("transform_params"), global_transform);
lifted_transform_functions.Set("transform_params", global_transform);
} else {
for (const auto& [gvar, info] : local_collect_info) {
// transform_params is emitted for each function if global lifting is not enabled
updates->Add(GlobalVar(gvar->name_hint + "_transform_params"),
info.MakeCompileTimeFunction());
lifted_transform_functions.Set(gvar->name_hint + "_transform_params",
info.MakeCompileTimeFunction());
}
}

if (updates->functions.size()) {
mod.CopyOnWrite()->Update(updates);
if (updated_runtime_functions->functions.size() || lifted_transform_functions.size()) {
auto write_ptr = mod.CopyOnWrite();
write_ptr->Update(updated_runtime_functions);

for (auto [name, transform] : lifted_transform_functions) {
if (auto opt = write_ptr->global_var_map_.Get(name)) {
auto old_gvar = opt.value();
auto old_transform = Downcast<Function>(write_ptr->Lookup(old_gvar));
write_ptr->Remove(old_gvar);

transform = ComposeFunctions(old_transform, transform);
}
GlobalVar new_gvar(name);
UpdateStructInfo(new_gvar, GetStructInfo(transform));
write_ptr->Add(new_gvar, transform);
}
}

return mod;
Expand Down Expand Up @@ -817,7 +837,6 @@ Pass LiftTransformParams(Variant<Bool, Array<String>> shared_transform) {
std::string func_name = gvar->name_hint;
if (ends_with(func_name, "transform_params")) {
func = WithAttr(func, tvm::attr::kGlobalSymbol, gvar->name_hint);
func = BundleModelParams(func);
if (pc->GetConfig<Bool>(kLiftTransformConsumeParams).value_or(Bool(false))) {
func = Downcast<Function>(ConsumeBundledParams()(func));
}
Expand Down
51 changes: 51 additions & 0 deletions src/relax/transform/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

#include "utils.h"

#include <tvm/relax/analysis.h>

namespace tvm {
namespace relax {

Expand All @@ -41,5 +43,54 @@ bool IsNestedTensor(const StructInfo& sinfo) {

bool IsNestedTensor(const Expr& expr) { return IsNestedTensor(GetStructInfo(expr)); }

Function ComposeFunctions(Function func_a, Function func_b) {
Array<Binding> bindings;

Var func_a_output("func_a_output", func_a->ret_struct_info);

bindings.push_back(VarBinding(func_a_output, func_a->body));

auto func_a_outputs = [&]() -> Array<Expr> {
if (auto func_a_output_tuple = func_a->ret_struct_info.as<TupleStructInfoNode>()) {
Array<Expr> outputs;
for (size_t i = 0; i < func_a_output_tuple->fields.size(); i++) {
outputs.push_back(TupleGetItem(func_a_output, i));
}
return outputs;
} else {
return {func_a_output};
}
}();

if (func_b->params.size() == 1 && func_b->params[0]->struct_info_.as<TupleStructInfoNode>()) {
// Special case where the output of the first function is a tuple
// that should be provided as-is to the second function, and
// should not be unpacked into individual elements.
auto param = func_b->params[0];
bindings.push_back(MatchCast(param, func_a_output, GetStructInfo(param)));
} else {
CHECK_EQ(func_a_outputs.size(), func_b->params.size())
<< "ValueError: "
<< "Cannot compose functions together. "
<< "First function produces " << func_a_outputs.size() << " values, "
<< "but second function expects " << func_b->params.size() << " parameters as input";
for (size_t i = 0; i < func_a_outputs.size(); i++) {
auto param = func_b->params[i];
bindings.push_back(MatchCast(param, func_a_outputs[i], GetStructInfo(param)));
}
}

auto new_body = SeqExpr({BindingBlock(bindings)}, func_b->body);

auto new_function = Function(func_a->params, new_body, func_b->ret_struct_info,
func_a->is_pure && func_b->is_pure, func_a->attrs);

new_function = CopyWithNewVars(new_function);
new_function = Downcast<Function>(CanonicalizeBindings(new_function));
new_function = Downcast<Function>(RemoveAllUnused(new_function));

return new_function;
}

} // namespace relax
} // namespace tvm
14 changes: 14 additions & 0 deletions src/relax/transform/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,20 @@ Expr CanonicalizeBindings(Expr expr);
*/
Function BundleModelParams(const Function& func, Optional<String> param_tuple_name = NullOpt);

/*! \brief Compose two functions
*
* Given two functions `func_a` and `func_b`, produce `func_c` such
* that `func_c(x)` is equivalent to `func_b(func_a(x))`.
*
* If the output if `func_a` is not usable as the input of `func_b`,
* an error will be raised.
*
* \param func_a The first function to be composed.
* \param func_b The second function to be composed.
* \return The composed function
*/
TVM_DLL Function ComposeFunctions(Function func_a, Function func_b);

} // namespace relax
} // namespace tvm

Expand Down
Loading