From 6187d48377c5c3022332f706973d70baff3b5ef7 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Thu, 7 Dec 2023 16:15:28 +0000
Subject: [PATCH 1/3] [Unity][Transform] Update LambdaLift to use name of
 lifted lambda

Prior to this commit, the `LambdaLift` pass named each function as
`"lifted_func_" + i`, in incremental order of occurrence.  This
provided a unique name for each function, but the names were
difficult to read, and gave no easy way to refer to a specific
lifted function.

This commit updates the naming scheme to use the location at which
the lifted lambda occurs to generate a unique name for the new
`GlobalVar`.
---
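[Editor's note: the sketch below is an illustration, not part of the
original patch.  It shows the effect of the new naming scheme,
assuming the public `tvm.relax.transform.LambdaLift` API; the module
mirrors `test_basic` further down.]

    import tvm
    from tvm import relax
    from tvm.script import ir as I, relax as R

    @I.ir_module
    class Module:
        @R.function
        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
            @R.function
            def inner(y: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
                z: R.Tensor((2, 3), "float32") = R.add(y, y)
                return z

            gv: R.Tensor((2, 3), "float32") = inner(x)
            return gv

    lifted = relax.transform.LambdaLift()(Module)

    # Previously the lifted lambda was named "lifted_func_0"; with this
    # patch it is named after its location, i.e. "main_inner".
    print([gv.name_hint for gv in lifted.functions])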
" + << "Lifting the " << public_name << " function out of both " << name_stack_.front() + << " and " << it->second.front() + << " would require violating one of these two conditions."; + + new_public_names_.insert({public_name, name_stack_}); + lifted_with_global_symbol_.insert({func, public_name}); + } + + name_stack_.push_back(binding->var->name_hint()); + lambda_location_.insert({func, name_stack_}); + ExprVisitor::VisitBinding_(binding, func); + name_stack_.pop_back(); + } + + // De-duplication of collected names + std::unordered_map Finalize() const { + // The functions which still must be assigned a name + std::unordered_map> remaining_to_name = lambda_location_; + + // Collecting the functions that now have a name. + std::unordered_map lifted_names; + + // A lookup for names that are unavailable for use. + std::unordered_set unavailable_names = previous_global_vars_; + + // A helper function to check de-duplicate names + auto use_if_unique = [&](const auto& generate_proposed_name) { + if (remaining_to_name.empty()) { + return; + } + + std::unordered_map new_names; + for (const auto& [func, location] : remaining_to_name) { + if (Optional opt_proposed_name = generate_proposed_name(func, location)) { + auto proposed_name = opt_proposed_name.value(); + + if (unavailable_names.count(proposed_name)) { + // The name is already used, either from a GlobalVar, or + // from a previous round of attempted names. + } else if (auto it = new_names.find(proposed_name); it != new_names.end()) { + // The name is not unique within the current attempt. Mark + // the function as nullptr to previous any use of this name + it->second = nullptr; + } else { + // The name is unique so far. Track it for use. + new_names.insert({proposed_name, func}); + } + } + } + + for (const auto& [name, func] : new_names) { + if (func) { + lifted_names.insert({func, name}); + remaining_to_name.erase(func); + } + } + }; + + // 1. Start with any publicly explosed names from kGlobalSymbol + use_if_unique([&](const auto& func, const auto&) -> Optional { + if (auto it = lifted_with_global_symbol_.find(func); it != lifted_with_global_symbol_.end()) { + return it->second; + } else { + return NullOpt; + } + }); + + // 2. Try concatenating the name of the relax variable with the + // name of the function that contains it. + use_if_unique([&](const auto&, const auto& location) -> String { + std::stringstream stream; + stream << location.front() << "_" << location.back(); + return stream.str(); + }); + + // 3. Try concatenating the entire path together. Don't include + // paths of length 2, as they would already be attempted earlier. + use_if_unique([&](const auto&, const auto& location) -> Optional { + if (location.size() == 2) return NullOpt; + + std::stringstream stream; + bool is_first = true; + for (const auto& loc : location) { + if (is_first) { + is_first = false; + } else { + stream << "_"; + } + stream << loc; + } + return String(stream.str()); + }); + + // 4. Fallback. Count the number of times a relax variable with + // that name was used. 
diff --git a/tests/python/relax/test_transform_lambda_lift.py b/tests/python/relax/test_transform_lambda_lift.py
index d67248417173..830db5113a38 100644
--- a/tests/python/relax/test_transform_lambda_lift.py
+++ b/tests/python/relax/test_transform_lambda_lift.py
@@ -19,7 +19,7 @@
 import tvm.testing
 from tvm import relax
 import tvm.script
-from tvm.script import relax as R, tir as T
+from tvm.script import relax as R, tir as T, ir as I
 from tvm.relax import transform
 from tvm.ir.base import assert_structural_equal
 
@@ -39,11 +39,13 @@ def _check_save_roundtrip(x):
 
 
 def test_basic():
+    """Functions can be lifted from local bindings to the IRModule"""
+
     # the target IRModule
     @tvm.script.ir_module
     class Expected:
         @R.function(private=True)
-        def lifted_func_0(
+        def main_inner(
             x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32")
         ) -> R.Tensor((10, 5), "float32"):
             s: R.Tensor((10, 5), "float32") = R.add(x2, y2)
             return s
@@ -53,7 +55,7 @@ def main(
             x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32")
         ) -> R.Tensor((10, 5), "float32"):
-            inner = Expected.lifted_func_0
+            inner = Expected.main_inner
             gv1: R.Tensor((10, 5), "float32") = inner(x1, y1)
             return gv1
 
@@ -83,6 +85,8 @@ def inner(
 
 
 def test_closure():
+    """Lifting functions may require producing closures"""
+
     # the expected IRModule
     @tvm.script.ir_module
     class Expected:
@@ -90,7 +94,7 @@ class Expected:
         def main(
             x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")
         ) -> R.Tensor((2, 3), "float32"):
-            outer_func = Expected.lifted_func_0
+            outer_func = Expected.main_outer_func
             in_call = outer_func(x)
             res = R.invoke_pure_closure(
                 in_call, (y,), sinfo_args=(R.Tensor((2, 3), dtype="float32"))
             )
             return res
 
@@ -98,13 +102,13 @@ def main(
         @R.function(private=True)
-        def lifted_func_1(x1: R.Tensor((2, 3), "float32"), c1: R.Tensor((2, 3), "float32")):
+        def main_inner_func(x1: R.Tensor((2, 3), "float32"), c1: R.Tensor((2, 3), "float32")):
             r_1: R.Tensor((2, 3), "float32") = R.add(x1, c1)
             return r_1
 
         @R.function(private=True)
-        def lifted_func_0(y: R.Tensor((2, 3), "float32")) -> R.Object:
-            inner_func = R.make_closure(Expected.lifted_func_1, (y,))
+        def main_outer_func(y: R.Tensor((2, 3), "float32")) -> R.Object:
+            inner_func = R.make_closure(Expected.main_inner_func, (y,))
             return inner_func
 
     # IRModule to perform Lambda Lifting
@@ -137,11 +141,13 @@ def inner_func(x1: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
 
 
 def test_recursive():
+    """The lifted function may be recursively defined"""
+
     # the expected IRModule
     @tvm.script.ir_module
     class Expected:
         @R.function(private=True)
-        def lifted_func_0(
+        def main_while_loop(
             i: R.Tensor((), "int32"), s: R.Tensor((2, 3), "float32"), x: R.Tensor((2, 3), "float32")
         ) -> R.Tensor((2, 3), "float32"):
             cond: R.Tensor((), "bool") = R.call_pure_packed(
@@ -151,7 +157,7 @@
             if cond:
                 new_i: R.Tensor((), "int32") = R.add(i, c)
                 new_s: R.Tensor((2, 3), "float32") = R.add(s, x)
-                new_r = Expected.lifted_func_0(new_i, new_s, x)
+                new_r = Expected.main_while_loop(new_i, new_s, x)
                 r = new_r
             else:
                 r = s
@@ -159,7 +165,7 @@
 
         @R.function
         def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), dtype="float32"):
-            while_loop = R.make_closure(Expected.lifted_func_0, (x,))
+            while_loop = R.make_closure(Expected.main_while_loop, (x,))
             gv: R.Tensor((2, 3), dtype="float32") = R.invoke_pure_closure(
                 while_loop,
                 (R.const(0), x),
                 sinfo_args=(R.Tensor((2, 3), dtype="float32")),
             )
             return gv
@@ -205,6 +211,12 @@ def while_loop(
 
 
 def test_multi_func():
+    """Lifting may be required for multiple top-level functions
+
+    De-duplication of GlobalVar names at the IRModule level is done
+    by appending the name of the function from which they were
+    lifted.
+    """
+
     # expected IRModule
     @tvm.script.ir_module
     class Expected:
@@ -212,7 +224,7 @@ class Expected:
         @R.function
         def glob_func_1(
             x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32")
         ) -> R.Tensor(None, "float32", ndim=2):
-            inner = Expected.lifted_func_0
+            inner = Expected.glob_func_1_inner
             gv1: R.Tensor((10, 5), "float32") = inner(x1, y1)
             return gv1
 
@@ -220,19 +232,19 @@
         def glob_func_2(
             x11: R.Tensor((10, 5), "float32"), y11: R.Tensor((10, 5), "float32")
         ) -> R.Tensor(None, "float32", ndim=2):
-            inner = Expected.lifted_func_1
+            inner = Expected.glob_func_2_inner
             gv11: R.Tensor((10, 5), "float32") = inner(x11, y11)
             return gv11
 
         @R.function(private=True)
-        def lifted_func_0(
+        def glob_func_1_inner(
             x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32")
         ) -> R.Tensor((10, 5), "float32"):
             s: R.Tensor((10, 5), "float32") = R.add(x2, y2)
             return s
 
         @R.function(private=True)
-        def lifted_func_1(
+        def glob_func_2_inner(
             x21: R.Tensor((10, 5), "float32"), y21: R.Tensor((10, 5), "float32")
         ) -> R.Tensor((10, 5), "float32"):
             s1: R.Tensor((10, 5), "float32") = R.add(x21, y21)
             return s1
 
@@ -309,13 +321,13 @@ def test_impure_function():
     @tvm.script.ir_module
     class Expected:
         @R.function(pure=False, private=True)
-        def lifted_func_0() -> R.Tuple:
+        def main_inner() -> R.Tuple:
             y = R.print(format="Wow!")
             return y
 
         @R.function(pure=False)
         def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
-            inner = Expected.lifted_func_0
+            inner = Expected.main_inner
             gv1 = inner()
             return x
 
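[Editor's note: none of the tests above exercise the collector's
first scheme.  A sketch of it follows, assuming TVMScript accepts a
`global_symbol` attribute on a nested function; the name
`exported_add` is hypothetical.]

    import tvm
    from tvm import relax
    from tvm.script import ir as I, relax as R

    @I.ir_module
    class WithSymbol:
        @R.function
        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
            @R.function
            def inner(y: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
                R.func_attr({"global_symbol": "exported_add"})
                z: R.Tensor((2, 3), "float32") = R.add(y, y)
                return z

            gv: R.Tensor((2, 3), "float32") = inner(x)
            return gv

    lifted = relax.transform.LambdaLift()(WithSymbol)
    # Scheme 1: the lifted function must keep the exact name
    # "exported_add"; if that name were already taken, the CHECKs in
    # LambdaNameCollector would fire instead.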
From 63de59357e6520a168a1e61928798f02ac9bbcde Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Tue, 9 Jan 2024 16:59:02 +0000
Subject: [PATCH 2/3] Update variable names and comments for unique function
 naming
---
 src/relax/transform/lambda_lift.cc | 31 +++++++++++++++++++++++-------
 1 file changed, 24 insertions(+), 7 deletions(-)

diff --git a/src/relax/transform/lambda_lift.cc b/src/relax/transform/lambda_lift.cc
index 2b09abaae6be..c7caeab05596 100644
--- a/src/relax/transform/lambda_lift.cc
+++ b/src/relax/transform/lambda_lift.cc
@@ -111,15 +111,32 @@ class LambdaNameCollector : ExprVisitor {
     // A lookup for names that are unavailable for use.
     std::unordered_set<String> unavailable_names = previous_global_vars_;
 
-    // A helper function to de-duplicate proposed names
-    auto use_if_unique = [&](const auto& generate_proposed_name) {
+    // A helper function to generate de-duplicated names.  The
+    // `proposed_name_generation_func` should be a function with
+    // signature:
+    //
+    //     Optional<String> func(const FunctionNode*, const Array<String>&)
+    //
+    // The first argument will be the lambda function being lifted.
+    // The second argument will be the nested location where that
+    // lambda function was found.  The function should return the
+    // proposed name for the lifted lambda function.  The proposed
+    // name will be accepted if it does not conflict with any previous
+    // names, and is unique for all lambda functions being lifted.
+    //
+    // This helper function is used to apply several different schemes
+    // to generate the name of the lifted lambda function.  The
+    // overall goal is to provide names that are unique (required by
+    // IRModule), deterministic (required for unit testing), and
+    // human-readable.
+    auto attempt_name_generation = [&](const auto& proposed_name_generation_func) {
       if (remaining_to_name.empty()) {
         return;
       }
 
       std::unordered_map<String, const FunctionNode*> new_names;
       for (const auto& [func, location] : remaining_to_name) {
-        if (Optional<String> opt_proposed_name = generate_proposed_name(func, location)) {
+        if (Optional<String> opt_proposed_name = proposed_name_generation_func(func, location)) {
           auto proposed_name = opt_proposed_name.value();
 
           if (unavailable_names.count(proposed_name)) {
@@ -145,7 +162,7 @@ class LambdaNameCollector : ExprVisitor {
     };
 
     // 1. Start with any publicly exposed names from kGlobalSymbol
-    use_if_unique([&](const auto& func, const auto&) -> Optional<String> {
+    attempt_name_generation([&](const FunctionNode* func, const auto&) -> Optional<String> {
       if (auto it = lifted_with_global_symbol_.find(func); it != lifted_with_global_symbol_.end()) {
         return it->second;
       } else {
@@ -155,7 +172,7 @@ class LambdaNameCollector : ExprVisitor {
 
     // 2. Try concatenating the name of the relax variable with the
     // name of the function that contains it.
-    use_if_unique([&](const auto&, const auto& location) -> String {
+    attempt_name_generation([&](const FunctionNode*, const auto& location) -> String {
       std::stringstream stream;
       stream << location.front() << "_" << location.back();
       return stream.str();
@@ -163,7 +180,7 @@ class LambdaNameCollector : ExprVisitor {
 
     // 3. Try concatenating the entire path together.  Don't include
     // paths of length 2, as they would already be attempted earlier.
-    use_if_unique([&](const auto&, const auto& location) -> Optional<String> {
+    attempt_name_generation([&](const FunctionNode*, const auto& location) -> Optional<String> {
       if (location.size() == 2) return NullOpt;
 
       std::stringstream stream;
@@ -182,7 +199,7 @@ class LambdaNameCollector : ExprVisitor {
     // 4. Fallback.  Count the number of times a relax variable with
     // that name was used.
     std::unordered_map<std::string, int> usage_count;
-    use_if_unique([&](const auto&, const auto& location) -> String {
+    attempt_name_generation([&](const FunctionNode*, const auto& location) -> String {
       std::stringstream stream;
       stream << location.front() << "_" << location.back();
       int usage = usage_count[stream.str()]++;
From 128ccab270f3b3eff47bbc9a4809533b1aa807fc Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Tue, 9 Jan 2024 18:10:12 +0000
Subject: [PATCH 3/3] Add unit test for conflicting name
---
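[Editor's note: tracing the collector's schemes for the test below,
the pre-existing GlobalVar `main_inner` makes scheme 2's proposal
unavailable, scheme 3 has nothing to offer for a two-element path,
and scheme 4's zero-based counter then yields `main_inner_0`.  The
same trace, run through the hypothetical `finalize` sketch from the
notes on patch 1:]

    lifted = finalize(
        lambda_location={"func_A": ["main", "inner"]},
        global_symbols={},
        unavailable_names={"main", "main_inner"},
    )
    assert lifted == {"func_A": "main_inner_0"}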
 .../relax/test_transform_lambda_lift.py       | 53 +++++++++++++++++++
 1 file changed, 53 insertions(+)

diff --git a/tests/python/relax/test_transform_lambda_lift.py b/tests/python/relax/test_transform_lambda_lift.py
index 830db5113a38..8f3daa06e200 100644
--- a/tests/python/relax/test_transform_lambda_lift.py
+++ b/tests/python/relax/test_transform_lambda_lift.py
@@ -351,5 +351,58 @@ def inner() -> R.Tuple:
     _check_save_roundtrip(after)
 
 
+def test_lambda_function_with_same_name_as_global():
+    """Lifted lambda names may not conflict with previous names
+
+    Like `test_basic`, but the module has an existing function
+    `main_inner`, which has the same name as LambdaLift's first
+    choice of name for the hoisted function.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32")
+        ) -> R.Tensor((10, 5), "float32"):
+            @R.function
+            def inner(
+                x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32")
+            ) -> R.Tensor((10, 5), "float32"):
+                s: R.Tensor((10, 5), "float32") = R.add(x2, y2)
+                return s
+
+            gv1: R.Tensor((10, 5), "float32") = inner(x1, y1)
+            return gv1
+
+        @R.function
+        def main_inner():
+            return R.tuple()
+
+ """ + + @I.ir_module + class Before: + @R.function + def main( + x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32") + ) -> R.Tensor((10, 5), "float32"): + @R.function + def inner( + x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32") + ) -> R.Tensor((10, 5), "float32"): + s: R.Tensor((10, 5), "float32") = R.add(x2, y2) + return s + + gv1: R.Tensor((10, 5), "float32") = inner(x1, y1) + return gv1 + + @R.function + def main_inner(): + return R.tuple() + + @I.ir_module + class Expected: + @R.function + def main( + x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32") + ) -> R.Tensor((10, 5), "float32"): + inner = Expected.main_inner_0 + gv1: R.Tensor((10, 5), "float32") = inner(x1, y1) + return gv1 + + @R.function(private=True) + def main_inner_0( + x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32") + ) -> R.Tensor((10, 5), "float32"): + s: R.Tensor((10, 5), "float32") = R.add(x2, y2) + return s + + @R.function + def main_inner(): + return R.tuple() + + after = transform.LambdaLift()(Before) + assert_structural_equal(Expected, after) + + if __name__ == "__main__": tvm.testing.main()