diff --git a/src/relax/transform/lambda_lift.cc b/src/relax/transform/lambda_lift.cc
index e3ed24cd9ed7..c7caeab05596 100644
--- a/src/relax/transform/lambda_lift.cc
+++ b/src/relax/transform/lambda_lift.cc
@@ -34,6 +34,195 @@
 namespace tvm {
 namespace relax {
 
+namespace {
+
+/* \brief Collect names of functions to be lifted out */
+class LambdaNameCollector : ExprVisitor {
+ public:
+  static std::unordered_map<const FunctionNode*, String> Collect(const IRModule& mod) {
+    LambdaNameCollector visitor;
+
+    for (const auto& [gvar, base_func] : mod->functions) {
+      visitor.previous_global_vars_.insert(gvar->name_hint);
+    }
+
+    for (const auto& [gvar, base_func] : mod->functions) {
+      if (auto func = base_func.as<Function>()) {
+        visitor.name_stack_.push_back(gvar->name_hint);
+        visitor(func.value());
+        visitor.name_stack_.pop_back();
+      }
+    }
+
+    return visitor.Finalize();
+  }
+
+ private:
+  void VisitBinding_(const VarBindingNode* binding, const FunctionNode* func) override {
+    if (auto opt = func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
+      String public_name = opt.value();
+
+      // If a kGlobalSymbol exists, we must use the name exactly as it
+      // appears, with no modifications.  Because these errors would
+      // be raised from deep within an optimization pipeline, but
+      // depend on small annotation changes in a user's initial
+      // model definition, they are intentionally verbose to
+      // (hopefully) provide sufficient context to a user encountering
+      // the error.
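+      //
+      // For example (illustrative annotation only, not taken from the
+      // test suite): a lambda annotated with
+      // `R.func_attr({"global_symbol": "my_kernel"})` must be lifted
+      // as a GlobalVar named exactly "my_kernel"; the checks below
+      // enforce the two conditions that make this possible.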
+      CHECK(!previous_global_vars_.count(public_name))
+          << "Function " << name_stack_.front() << " contains a lambda with a kGlobalSymbol (\""
+          << tvm::attr::kGlobalSymbol << "\") attribute of \"" << public_name << "\". "
+          << "However, the module already contains a GlobalVar with this name. "
+          << "If present, the kGlobalSymbol attribute must match the name of the GlobalVar, "
+          << "and GlobalVar names must be unique across an IRModule. "
+          << "Lifting the " << public_name << " function out of " << name_stack_.front()
+          << " would require violating one of these two conditions.";
+
+      auto it = new_public_names_.find(public_name);
+      CHECK(it == new_public_names_.end())
+          << "Function " << name_stack_.front() << " contains a lambda with a kGlobalSymbol (\""
+          << tvm::attr::kGlobalSymbol << "\") attribute of \"" << public_name << "\". "
+          << "However, the function " << it->second.front()
+          << " also contains a lambda with the same value for kGlobalSymbol. "
+          << "If present, the kGlobalSymbol attribute must match the name of the GlobalVar, "
+          << "and GlobalVar names must be unique across an IRModule. "
+          << "Lifting the " << public_name << " function out of both " << name_stack_.front()
+          << " and " << it->second.front()
+          << " would require violating one of these two conditions.";
+
+      new_public_names_.insert({public_name, name_stack_});
+      lifted_with_global_symbol_.insert({func, public_name});
+    }
+
+    name_stack_.push_back(binding->var->name_hint());
+    lambda_location_.insert({func, name_stack_});
+    ExprVisitor::VisitBinding_(binding, func);
+    name_stack_.pop_back();
+  }
+
+  // De-duplication of collected names
+  std::unordered_map<const FunctionNode*, String> Finalize() const {
+    // The functions which still must be assigned a name
+    std::unordered_map<const FunctionNode*, Array<String>> remaining_to_name = lambda_location_;
+
+    // Collecting the functions that now have a name.
+    std::unordered_map<const FunctionNode*, String> lifted_names;
+
+    // A lookup for names that are unavailable for use.
+    std::unordered_set<String> unavailable_names = previous_global_vars_;
+
+    // A helper function to generate de-duplicated names.  The
+    // `proposed_name_generation_func` should be a function with
+    // signature:
+    //
+    //     Optional<String> func(const FunctionNode*, const Array<String>&)
+    //
+    // The first argument will be the lambda function being lifted.
+    // The second argument will be the nested location where that
+    // lambda function was found.  The function should return the
+    // proposed name for the lifted lambda function.  The proposed
+    // name will be accepted if it does not conflict with any previous
+    // names, and is unique for all lambda functions being lifted.
+    //
+    // This helper function is used to apply several different schemes
+    // to generate the name of the lifted lambda function.  The
+    // overall goal is to provide names that are unique (required by
+    // IRModule), deterministic (required for unit testing), and
+    // human-readable.
+    auto attempt_name_generation = [&](const auto& proposed_name_generation_func) {
+      if (remaining_to_name.empty()) {
+        return;
+      }
+
+      std::unordered_map<String, const FunctionNode*> new_names;
+      for (const auto& [func, location] : remaining_to_name) {
+        if (Optional<String> opt_proposed_name = proposed_name_generation_func(func, location)) {
+          auto proposed_name = opt_proposed_name.value();
+
+          if (unavailable_names.count(proposed_name)) {
+            // The name is already used, either from a GlobalVar, or
+            // from a previous round of attempted names.
+          } else if (auto it = new_names.find(proposed_name); it != new_names.end()) {
+            // The name is not unique within the current attempt.  Mark
+            // the function as nullptr to prevent any use of this name.
+            it->second = nullptr;
+          } else {
+            // The name is unique so far.  Track it for use.
+            new_names.insert({proposed_name, func});
+          }
+        }
+      }
+
+      for (const auto& [name, func] : new_names) {
+        if (func) {
+          lifted_names.insert({func, name});
+          remaining_to_name.erase(func);
+          // Reserve the accepted name so that later naming rounds
+          // cannot propose it again.
+          unavailable_names.insert(name);
+        }
+      }
+    };
+
+    // 1. Start with any publicly exposed names from kGlobalSymbol
+    attempt_name_generation([&](const FunctionNode* func, const auto&) -> Optional<String> {
+      if (auto it = lifted_with_global_symbol_.find(func); it != lifted_with_global_symbol_.end()) {
+        return it->second;
+      } else {
+        return NullOpt;
+      }
+    });
+
+    // 2. Try concatenating the name of the relax variable with the
+    // name of the function that contains it.
+    attempt_name_generation([&](const FunctionNode*, const auto& location) -> String {
+      std::stringstream stream;
+      stream << location.front() << "_" << location.back();
+      return stream.str();
+    });
+
+    // 3. Try concatenating the entire path together.  Don't include
+    // paths of length 2, as they would already be attempted earlier.
+    attempt_name_generation([&](const FunctionNode*, const auto& location) -> Optional<String> {
+      if (location.size() == 2) return NullOpt;
+
+      std::stringstream stream;
+      bool is_first = true;
+      for (const auto& loc : location) {
+        if (is_first) {
+          is_first = false;
+        } else {
+          stream << "_";
+        }
+        stream << loc;
+      }
+      return String(stream.str());
+    });
+
+    // 4. Fallback.  Count the number of times a relax variable with
+    // that name was used, and append the count to the name.
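+    //
+    // For example (hypothetical module, not taken from the test
+    // suite): if `main` binds two distinct lambda functions to
+    // variables that are both named `inner`, schemes (2) and (3)
+    // cannot produce unique names, and this fallback yields
+    // `main_inner_0` and `main_inner_1`.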
+    std::unordered_map<std::string, int> usage_count;
+    attempt_name_generation([&](const FunctionNode*, const auto& location) -> String {
+      std::stringstream stream;
+      stream << location.front() << "_" << location.back();
+      int usage = usage_count[stream.str()]++;
+      stream << "_" << usage;
+
+      return stream.str();
+    });
+
+    ICHECK(remaining_to_name.empty())
+        << "Fallback failed to make unique names for all lifted lambda functions";
+
+    return lifted_names;
+  }
+
+  Array<String> name_stack_;
+  std::unordered_set<String> previous_global_vars_;
+  std::unordered_map<String, Array<String>> new_public_names_;
+  std::unordered_map<const FunctionNode*, String> lifted_with_global_symbol_;
+  std::unordered_map<const FunctionNode*, Array<String>> lambda_location_;
+};
+
+}  // namespace
+
 /* The goal of this class is to lift out any nested functions into top-level
  * functions.
  *
@@ -42,17 +231,19 @@ namespace relax {
  */
 class LambdaLifter : public ExprMutator {
  public:
-  explicit LambdaLifter(const IRModule& module) : ExprMutator(module) { mod_ = module; }
+  explicit LambdaLifter(const IRModule& module)
+      : ExprMutator(module), mod_(module), lifted_names_(LambdaNameCollector::Collect(module)) {}
 
   using ExprMutator::VisitExpr_;
 
   void VisitBinding_(const VarBindingNode* binding) final {
-    bool is_lambda = false;
-    if (binding->value->IsInstance<FunctionNode>()) {
-      is_lambda = true;
+    bool is_lambda = binding->value->IsInstance<FunctionNode>();
+    if (is_lambda) {
       recur_vars_.push_back(binding->var);
     }
+
     Expr new_value = this->VisitExpr(binding->value);
+
     if (new_value->struct_info_.defined() &&
         !new_value->struct_info_.same_as(binding->var->struct_info_)) {
       binding->var->struct_info_ = GetStructInfo(new_value);
@@ -136,8 +327,15 @@ class LambdaLifter : public ExprMutator {
 
   Expr VisitExpr_(const FunctionNode* func_node) final {
     auto func = GetRef<Function>(func_node);
-    // TODO(@yongwww): consider appending inner func name into the lifted func name
-    String lift_func_name = "lifted_func_" + std::to_string(lift_func_num_++);
+    String lift_func_name = [&]() {
+      auto it = lifted_names_.find(func_node);
+      ICHECK(it != lifted_names_.end())
+          << "InternalError: "
+          << "Found lambda function during mutation step, "
+          << "but it wasn't found during the earlier name-generation step.";
+      return it->second;
+    }();
+
     auto global = GlobalVar(lift_func_name);
     Array<Var> free_vars = FreeVars(func);
     Array<Var> captured_vars;
@@ -309,7 +507,9 @@ class LambdaLifter : public ExprMutator {
   std::unordered_map<Var, Expr, ObjectPtrHash, ObjectPtrEqual> lambda_map_;
   Array<Var> recur_vars_;
   IRModule mod_;
-  size_t lift_func_num_ = 0;
+
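+  /*! \brief Names for the lifted lambda functions, keyed by FunctionNode.
+   *
+   * Populated by LambdaNameCollector before any mutation is performed.
+   */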
+  std::unordered_map<const FunctionNode*, String> lifted_names_;
+
   /*! \brief Cache ops that would be used later to reduce lookup overhead. */
   const Op& make_closure_op_ = Op::Get("relax.make_closure");
   const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure");
diff --git a/tests/python/relax/test_transform_lambda_lift.py b/tests/python/relax/test_transform_lambda_lift.py
index d67248417173..8f3daa06e200 100644
--- a/tests/python/relax/test_transform_lambda_lift.py
+++ b/tests/python/relax/test_transform_lambda_lift.py
@@ -19,7 +19,7 @@
 import tvm.testing
 from tvm import relax
 import tvm.script
-from tvm.script import relax as R, tir as T
+from tvm.script import relax as R, tir as T, ir as I
 from tvm.relax import transform
 from tvm.ir.base import assert_structural_equal
 
@@ -39,11 +39,13 @@ def _check_save_roundtrip(x):
 
 
 def test_basic():
+    """Functions can be lifted from local bindings to the IRModule"""
+
     # the target IRModule
     @tvm.script.ir_module
     class Expected:
         @R.function(private=True)
-        def lifted_func_0(
+        def main_inner(
             x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32")
         ) -> R.Tensor((10, 5), "float32"):
             s: R.Tensor((10, 5), "float32") = R.add(x2, y2)
@@ -53,7 +55,7 @@ def lifted_func_0(
         def main(
             x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32")
         ) -> R.Tensor((10, 5), "float32"):
-            inner = Expected.lifted_func_0
+            inner = Expected.main_inner
             gv1: R.Tensor((10, 5), "float32") = inner(x1, y1)
             return gv1
 
@@ -83,6 +85,8 @@ def inner(
 
 
 def test_closure():
+    """Lifting functions may require producing closures"""
+
     # the expected IRModule
     @tvm.script.ir_module
     class Expected:
@@ -90,7 +94,7 @@ class Expected:
         def main(
             x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")
         ) -> R.Tensor((2, 3), "float32"):
-            outer_func = Expected.lifted_func_0
+            outer_func = Expected.main_outer_func
             in_call = outer_func(x)
             res = R.invoke_pure_closure(
                 in_call, (y,), sinfo_args=(R.Tensor((2, 3), dtype="float32"))
@@ -98,13 +102,13 @@ def main(
             )
             return res
 
         @R.function(private=True)
-        def lifted_func_1(x1: R.Tensor((2, 3), "float32"), c1: R.Tensor((2, 3), "float32")):
+        def main_inner_func(x1: R.Tensor((2, 3), "float32"), c1: R.Tensor((2, 3), "float32")):
             r_1: R.Tensor((2, 3), "float32") = R.add(x1, c1)
             return r_1
 
         @R.function(private=True)
-        def lifted_func_0(y: R.Tensor((2, 3), "float32")) -> R.Object:
-            inner_func = R.make_closure(Expected.lifted_func_1, (y,))
+        def main_outer_func(y: R.Tensor((2, 3), "float32")) -> R.Object:
+            inner_func = R.make_closure(Expected.main_inner_func, (y,))
             return inner_func
 
     # IRModule to perform Lambda Lifting
     @tvm.script.ir_module
@@ -137,11 +141,13 @@ def inner_func(x1: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
 
 
 def test_recursive():
+    """The lifted function may be recursively defined"""
+
     # the expected IRModule
     @tvm.script.ir_module
     class Expected:
         @R.function(private=True)
-        def lifted_func_0(
+        def main_while_loop(
             i: R.Tensor((), "int32"), s: R.Tensor((2, 3), "float32"), x: R.Tensor((2, 3), "float32")
         ) -> R.Tensor((2, 3), "float32"):
             cond: R.Tensor((), "bool") = R.call_pure_packed(
@@ -151,7 +157,7 @@ def lifted_func_0(
             if cond:
                 new_i: R.Tensor((), "int32") = R.add(i, c)
                 new_s: R.Tensor((2, 3), "float32") = R.add(s, x)
-                new_r = Expected.lifted_func_0(new_i, new_s, x)
+                new_r = Expected.main_while_loop(new_i, new_s, x)
                 r = new_r
             else:
                 r = s
@@ -159,7 +165,7 @@ def lifted_func_0(
 
         @R.function
         def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), dtype="float32"):
-            while_loop = R.make_closure(Expected.lifted_func_0, (x,))
+            while_loop = R.make_closure(Expected.main_while_loop, (x,))
             gv: R.Tensor((2, 3), dtype="float32") = R.invoke_pure_closure(
                 while_loop,
                (R.const(0), x),
 
@@ -205,6 +211,12 @@
 
 
 def test_multi_func():
+    """Lifting may be required for multiple top-level functions
+
+    De-duplication of GlobalVar names in the IRModule is done by
+    prefixing each lifted function's name with the name of the
+    function from which it was lifted.
+    """
+
     # expected IRModule
     @tvm.script.ir_module
     class Expected:
@@ -212,7 +224,7 @@ class Expected:
        def glob_func_1(
            x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32")
        ) -> R.Tensor(None, "float32", ndim=2):
-            inner = Expected.lifted_func_0
+            inner = Expected.glob_func_1_inner
            gv1: R.Tensor((10, 5), "float32") = inner(x1, y1)
            return gv1
 
@@ -220,19 +232,19 @@ def glob_func_1(
        def glob_func_2(
            x11: R.Tensor((10, 5), "float32"), y11: R.Tensor((10, 5), "float32")
        ) -> R.Tensor(None, "float32", ndim=2):
-            inner = Expected.lifted_func_1
+            inner = Expected.glob_func_2_inner
            gv11: R.Tensor((10, 5), "float32") = inner(x11, y11)
            return gv11
 
        @R.function(private=True)
-        def lifted_func_0(
+        def glob_func_1_inner(
            x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32")
        ) -> R.Tensor((10, 5), "float32"):
            s: R.Tensor((10, 5), "float32") = R.add(x2, y2)
            return s
 
        @R.function(private=True)
-        def lifted_func_1(
+        def glob_func_2_inner(
            x21: R.Tensor((10, 5), "float32"), y21: R.Tensor((10, 5), "float32")
        ) -> R.Tensor((10, 5), "float32"):
            s1: R.Tensor((10, 5), "float32") = R.add(x21, y21)
@@ -309,13 +321,13 @@ def test_impure_function():
     @tvm.script.ir_module
     class Expected:
         @R.function(pure=False, private=True)
-        def lifted_func_0() -> R.Tuple:
+        def main_inner() -> R.Tuple:
             y = R.print(format="Wow!")
             return y
 
         @R.function(pure=False)
         def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
-            inner = Expected.lifted_func_0
+            inner = Expected.main_inner
             gv1 = inner()
             return x
 
@@ -339,5 +351,58 @@ def inner() -> R.Tuple:
     _check_save_roundtrip(after)
 
 
+def test_lambda_function_with_same_name_as_global():
+    """Lifted lambda names may not conflict with previous names
+
+    Like `test_basic`, but the module has an existing function
+    `main_inner`, which has the same name as LambdaLift's first
+    choice of name for the lifted function.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32")
+        ) -> R.Tensor((10, 5), "float32"):
+            @R.function
+            def inner(
+                x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32")
+            ) -> R.Tensor((10, 5), "float32"):
+                s: R.Tensor((10, 5), "float32") = R.add(x2, y2)
+                return s
+
+            gv1: R.Tensor((10, 5), "float32") = inner(x1, y1)
+            return gv1
+
+        @R.function
+        def main_inner():
+            return R.tuple()
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32")
+        ) -> R.Tensor((10, 5), "float32"):
+            inner = Expected.main_inner_0
+            gv1: R.Tensor((10, 5), "float32") = inner(x1, y1)
+            return gv1
+
+        @R.function(private=True)
+        def main_inner_0(
+            x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32")
+        ) -> R.Tensor((10, 5), "float32"):
+            s: R.Tensor((10, 5), "float32") = R.add(x2, y2)
+            return s
+
+        @R.function
+        def main_inner():
+            return R.tuple()
+
+    after = transform.LambdaLift()(Before)
+    assert_structural_equal(Expected, after)
+
+
 if __name__ == "__main__":
     tvm.testing.main()