Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 207 additions & 7 deletions src/relax/transform/lambda_lift.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,195 @@
namespace tvm {
namespace relax {

namespace {

/* \brief Collect names of functions to be lifted out */
class LambdaNameCollector : ExprVisitor {
public:
static std::unordered_map<const FunctionNode*, String> Collect(const IRModule& mod) {
LambdaNameCollector visitor;

for (const auto& [gvar, base_func] : mod->functions) {
visitor.previous_global_vars_.insert(gvar->name_hint);
}

for (const auto& [gvar, base_func] : mod->functions) {
if (auto func = base_func.as<Function>()) {
visitor.name_stack_.push_back(gvar->name_hint);
visitor(func.value());
visitor.name_stack_.pop_back();
}
}

return visitor.Finalize();
}

private:
void VisitBinding_(const VarBindingNode* binding, const FunctionNode* func) override {
if (auto opt = func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
String public_name = opt.value();

// If a kGlobalSymbol exists, we must use the name exactly as it
// appears, with no modifications. Because these errors would
// be raised from deep within an optimization pipeline, but
// depends on small annotation changes from a user's initial
// model definition, they are intentionally verbose to
// (hopefully) provide sufficient context to a user encountering
// the error.
CHECK(!previous_global_vars_.count(public_name))
<< "Function " << name_stack_.front() << " contains a lambda with kGlobalSymbol (\""
<< tvm::attr::kGlobalSymbol << "\" attribute of \"" << public_name << "\". "
<< "However, the module already contains a GlobalVar with this name. "
<< "If present, the kGlobalSymbol attribute must match the name of the GlobalVar, "
<< "and GlobalVar names must be unique across an IRModule. "
<< "Lifting the " << public_name << " function out of " << name_stack_.front()
<< " would require violating one of these two conditions.";

auto it = new_public_names_.find(public_name);
CHECK(it == new_public_names_.end())
<< "Function " << name_stack_.front() << " contains a lambda with kGlobalSymbol (\""
<< tvm::attr::kGlobalSymbol << "\" attribute of \"" << public_name << "\". "
<< "However, the function " << it->second.front()
<< " also contains a lambda with the same value for kGlobalSymbol. "
<< "If present, the kGlobalSymbol attribute must match the name of the GlobalVar, "
<< "and GlobalVar names must be unique across an IRModule. "
<< "Lifting the " << public_name << " function out of both " << name_stack_.front()
<< " and " << it->second.front()
<< " would require violating one of these two conditions.";

new_public_names_.insert({public_name, name_stack_});
lifted_with_global_symbol_.insert({func, public_name});
}

name_stack_.push_back(binding->var->name_hint());
lambda_location_.insert({func, name_stack_});
ExprVisitor::VisitBinding_(binding, func);
name_stack_.pop_back();
}

// De-duplication of collected names
std::unordered_map<const FunctionNode*, String> Finalize() const {
// The functions which still must be assigned a name
std::unordered_map<const FunctionNode*, Array<String>> remaining_to_name = lambda_location_;

// Collecting the functions that now have a name.
std::unordered_map<const FunctionNode*, String> lifted_names;

// A lookup for names that are unavailable for use.
std::unordered_set<String> unavailable_names = previous_global_vars_;

// A helper function to generate de-duplicated names. The
// `proposed_name_generation_func` should be a function with
// signature:
//
// Optional<String> func(const FunctionNode*, const Array<String>&)
//
// The first argument will be the lambda function being lifted.
// The second argument will be the nested location where that
// lambda function was found. The function should return the
// proposed name for the lifted lambda function. The proposed
// name will be accepted if it does not conflict with any previous
// names, and is unique for all lambda functions being lifted.
//
// This helper function is used to apply several different schemes
// to generate the name of the lifted lambda function. The
// overall goal is to provide names that are unique (required by
// IRModule), deterministic (required for unit testing), and
// human-readable.
auto attempt_name_generation = [&](const auto& proposed_name_generation_func) {
if (remaining_to_name.empty()) {
return;
}

std::unordered_map<String, const FunctionNode*> new_names;
for (const auto& [func, location] : remaining_to_name) {
if (Optional<String> opt_proposed_name = proposed_name_generation_func(func, location)) {
auto proposed_name = opt_proposed_name.value();

if (unavailable_names.count(proposed_name)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the body of the if statement is empty, probably we can merge it with else if/ else

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I went back and forth on it. It could be merged with the else if/else, but that would require repeating (and re-evaluating) the condition in the other two branches. It could be pulled out as if(!unavailable_names.count(proposed_name)) to wrap around the other two cases, but this utility is already getting a bit deeply nested for readability.

// The name is already used, either from a GlobalVar, or
// from a previous round of attempted names.
} else if (auto it = new_names.find(proposed_name); it != new_names.end()) {
// The name is not unique within the current attempt. Mark
// the function as nullptr to previous any use of this name
it->second = nullptr;
} else {
// The name is unique so far. Track it for use.
new_names.insert({proposed_name, func});
}
}
}

for (const auto& [name, func] : new_names) {
if (func) {
lifted_names.insert({func, name});
remaining_to_name.erase(func);
}
}
};

// 1. Start with any publicly explosed names from kGlobalSymbol
attempt_name_generation([&](const FunctionNode* func, const auto&) -> Optional<String> {
if (auto it = lifted_with_global_symbol_.find(func); it != lifted_with_global_symbol_.end()) {
return it->second;
} else {
return NullOpt;
}
});

// 2. Try concatenating the name of the relax variable with the
// name of the function that contains it.
attempt_name_generation([&](const FunctionNode*, const auto& location) -> String {
std::stringstream stream;
stream << location.front() << "_" << location.back();
return stream.str();
});

// 3. Try concatenating the entire path together. Don't include
// paths of length 2, as they would already be attempted earlier.
attempt_name_generation([&](const FunctionNode*, const auto& location) -> Optional<String> {
if (location.size() == 2) return NullOpt;

std::stringstream stream;
bool is_first = true;
for (const auto& loc : location) {
if (is_first) {
is_first = false;
} else {
stream << "_";
}
stream << loc;
}
return String(stream.str());
});

// 4. Fallback. Count the number of times a relax variable with
// that name was used.
std::unordered_map<String, int> usage_count;
attempt_name_generation([&](const FunctionNode*, const auto& location) -> String {
std::stringstream stream;
stream << location.front() << "_" << location.back();
int usage = usage_count[stream.str()]++;
stream << "_" << usage;

return stream.str();
});

ICHECK(remaining_to_name.empty())
<< "Fallback failed to make unique names for all lifted lambda functions";

return lifted_names;
}

Array<String> name_stack_;
std::unordered_set<String> previous_global_vars_;
std::unordered_map<String, Array<String>> new_public_names_;
std::unordered_map<const FunctionNode*, String> lifted_with_global_symbol_;
std::unordered_map<const FunctionNode*, Array<String>> lambda_location_;
};

} // namespace

/* The goal of this class is to lift out any nested functions into top-level
* functions.
*
Expand All @@ -42,17 +231,19 @@ namespace relax {
*/
class LambdaLifter : public ExprMutator {
public:
explicit LambdaLifter(const IRModule& module) : ExprMutator(module) { mod_ = module; }
explicit LambdaLifter(const IRModule& module)
: ExprMutator(module), mod_(module), lifted_names_(LambdaNameCollector::Collect(module)) {}

using ExprMutator::VisitExpr_;

void VisitBinding_(const VarBindingNode* binding) final {
bool is_lambda = false;
if (binding->value->IsInstance<FunctionNode>()) {
is_lambda = true;
bool is_lambda = binding->value->IsInstance<FunctionNode>();
if (is_lambda) {
recur_vars_.push_back(binding->var);
}

Expr new_value = this->VisitExpr(binding->value);

if (new_value->struct_info_.defined() &&
!new_value->struct_info_.same_as(binding->var->struct_info_)) {
binding->var->struct_info_ = GetStructInfo(new_value);
Expand Down Expand Up @@ -136,8 +327,15 @@ class LambdaLifter : public ExprMutator {
Expr VisitExpr_(const FunctionNode* func_node) final {
auto func = GetRef<Function>(func_node);

// TODO(@yongwww): consider appending inner func name into the lifted func name
String lift_func_name = "lifted_func_" + std::to_string(lift_func_num_++);
String lift_func_name = [&]() {
auto it = lifted_names_.find(func_node);
ICHECK(it != lifted_names_.end())
<< "InternalError: "
<< "Found lambda function during mutation step, "
<< "but it wasn't found during the earlier name-generation step.";
return it->second;
}();

auto global = GlobalVar(lift_func_name);
Array<Var> free_vars = FreeVars(func);
Array<Var> captured_vars;
Expand Down Expand Up @@ -309,7 +507,9 @@ class LambdaLifter : public ExprMutator {
std::unordered_map<Var, Expr, ObjectPtrHash, ObjectPtrEqual> lambda_map_;
Array<Var> recur_vars_;
IRModule mod_;
size_t lift_func_num_ = 0;

std::unordered_map<const FunctionNode*, String> lifted_names_;

/*! \brief Cache ops that would be used later to reduce lookup overhead. */
const Op& make_closure_op_ = Op::Get("relax.make_closure");
const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure");
Expand Down
Loading