diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc
index 937cb8702952..76df48430592 100644
--- a/src/relax/transform/lift_transform_params.cc
+++ b/src/relax/transform/lift_transform_params.cc
@@ -119,7 +119,10 @@ struct BaseCollectInfo {
     Function func(params, body, GetStructInfo(tuple_var));
     func = WithAttr(func, attr::kNumInput, Integer(0));
     func = CopyWithNewVars(func);
+    func = BundleModelParams(func);
     func = Downcast<Function>(CanonicalizeBindings(func));
+    func = Downcast<Function>(RemoveAllUnused(func));
+
     return func;
   }
 };
@@ -725,11 +728,12 @@ std::vector<std::pair<GlobalVar, Function>> GetTargetFunctions(
       target_functions.push_back({gvar.value(), func.value()});
     }
   } else {
-    // Get all the functions that have the `num_input` attribute.
+    // Get all the functions that have the `num_input` attribute, and
+    // are not already the result of `LiftTransformParams`.
     for (const auto& [gvar, func] : mod->functions) {
       if (func->IsInstance<relax::FunctionNode>()) {
         auto opt_num_input = func->GetAttr<Integer>(attr::kNumInput);
-        if (opt_num_input) {
+        if (opt_num_input && !ends_with(gvar->name_hint, "transform_params")) {
           target_functions.emplace_back(gvar, Downcast<Function>(func));
         }
       }
@@ -748,7 +752,6 @@ namespace transform {
 
 Pass PartitionTransformParams(Variant<Bool, Array<String>> shared_transform) {
   auto pass_func = [=](IRModule mod, PassContext pc) {
-    IRModule updates;
     std::optional<GlobalCollectInfo> global_collect_info;
 
     CHECK(shared_transform.defined()) << "shared_transform is not defined";
@@ -772,24 +775,41 @@ Pass PartitionTransformParams(Variant<Bool, Array<String>> shared_transform) {
       local_collect_info[gvar] = info;
     }
 
+    IRModule updated_runtime_functions;
+
     for (const auto& [gvar, info] : local_collect_info) {
       auto new_runtime_func = info.MakeRuntimeFunction();
-      updates->Add(gvar, new_runtime_func);
+      updated_runtime_functions->Add(gvar, new_runtime_func);
     }
 
+    Map<String, Function> lifted_transform_functions;
     if (global_collect_info.has_value()) {
       auto global_transform = global_collect_info.value().MakeCompileTimeFunc();
-      updates->Add(GlobalVar("transform_params"), global_transform);
+      lifted_transform_functions.Set("transform_params", global_transform);
     } else {
       for (const auto& [gvar, info] : local_collect_info) {
         // transform_params is emitted for each function if global lifting is not enabled
-        updates->Add(GlobalVar(gvar->name_hint + "_transform_params"),
-                     info.MakeCompileTimeFunction());
+        lifted_transform_functions.Set(gvar->name_hint + "_transform_params",
+                                       info.MakeCompileTimeFunction());
       }
     }
 
-    if (updates->functions.size()) {
-      mod.CopyOnWrite()->Update(updates);
+    if (updated_runtime_functions->functions.size() || lifted_transform_functions.size()) {
+      auto write_ptr = mod.CopyOnWrite();
+      write_ptr->Update(updated_runtime_functions);
+
+      for (auto [name, transform] : lifted_transform_functions) {
+        if (auto opt = write_ptr->global_var_map_.Get(name)) {
+          auto old_gvar = opt.value();
+          auto old_transform = Downcast<Function>(write_ptr->Lookup(old_gvar));
+          write_ptr->Remove(old_gvar);
+
+          transform = ComposeFunctions(old_transform, transform);
+        }
+        GlobalVar new_gvar(name);
+        UpdateStructInfo(new_gvar, GetStructInfo(transform));
+        write_ptr->Add(new_gvar, transform);
+      }
     }
 
     return mod;
@@ -817,7 +837,6 @@ Pass LiftTransformParams(Variant<Bool, Array<String>> shared_transform) {
       std::string func_name = gvar->name_hint;
       if (ends_with(func_name, "transform_params")) {
         func = WithAttr(func, tvm::attr::kGlobalSymbol, gvar->name_hint);
-        func = BundleModelParams(func);
        if (pc->GetConfig<Bool>(kLiftTransformConsumeParams).value_or(Bool(false))) {
           func = Downcast<Function>(ConsumeBundledParams()(func));
         }
diff --git a/src/relax/transform/utils.cc b/src/relax/transform/utils.cc
index c0fde3bd4cb9..19e93bbc0c0e 100644
--- a/src/relax/transform/utils.cc
+++ b/src/relax/transform/utils.cc
@@ -19,6 +19,8 @@
 
 #include "utils.h"
 
+#include
+
 namespace tvm {
 namespace relax {
 
@@ -41,5 +43,54 @@ bool IsNestedTensor(const StructInfo& sinfo) {
 
 bool IsNestedTensor(const Expr& expr) { return IsNestedTensor(GetStructInfo(expr)); }
 
+Function ComposeFunctions(Function func_a, Function func_b) {
+  Array<Binding> bindings;
+
+  Var func_a_output("func_a_output", func_a->ret_struct_info);
+
+  bindings.push_back(VarBinding(func_a_output, func_a->body));
+
+  auto func_a_outputs = [&]() -> Array<Expr> {
+    if (auto func_a_output_tuple = func_a->ret_struct_info.as<TupleStructInfoNode>()) {
+      Array<Expr> outputs;
+      for (size_t i = 0; i < func_a_output_tuple->fields.size(); i++) {
+        outputs.push_back(TupleGetItem(func_a_output, i));
+      }
+      return outputs;
+    } else {
+      return {func_a_output};
+    }
+  }();
+
+  if (func_b->params.size() == 1 && func_b->params[0]->struct_info_.as<TupleStructInfoNode>()) {
+    // Special case where the output of the first function is a tuple
+    // that should be provided as-is to the second function, and
+    // should not be unpacked into individual elements.
+    auto param = func_b->params[0];
+    bindings.push_back(MatchCast(param, func_a_output, GetStructInfo(param)));
+  } else {
+    CHECK_EQ(func_a_outputs.size(), func_b->params.size())
+        << "ValueError: "
+        << "Cannot compose functions together. "
+        << "First function produces " << func_a_outputs.size() << " values, "
+        << "but second function expects " << func_b->params.size() << " parameters as input";
+    for (size_t i = 0; i < func_a_outputs.size(); i++) {
+      auto param = func_b->params[i];
+      bindings.push_back(MatchCast(param, func_a_outputs[i], GetStructInfo(param)));
+    }
+  }
+
+  auto new_body = SeqExpr({BindingBlock(bindings)}, func_b->body);
+
+  auto new_function = Function(func_a->params, new_body, func_b->ret_struct_info,
+                               func_a->is_pure && func_b->is_pure, func_a->attrs);
+
+  new_function = CopyWithNewVars(new_function);
+  new_function = Downcast<Function>(CanonicalizeBindings(new_function));
+  new_function = Downcast<Function>(RemoveAllUnused(new_function));
+
+  return new_function;
+}
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h
index 932dca30a110..55e355b4bac2 100644
--- a/src/relax/transform/utils.h
+++ b/src/relax/transform/utils.h
@@ -437,6 +437,20 @@ Expr CanonicalizeBindings(Expr expr);
  */
 Function BundleModelParams(const Function& func, Optional<String> param_tuple_name = NullOpt);
 
+/*! \brief Compose two functions
+ *
+ * Given two functions `func_a` and `func_b`, produce `func_c` such
+ * that `func_c(x)` is equivalent to `func_b(func_a(x))`.
+ *
+ * If the output of `func_a` is not usable as the input of `func_b`,
+ * an error will be raised.
+ *
+ * \param func_a The first function to be composed.
+ * \param func_b The second function to be composed.
+ * \return The composed function
+ */
+TVM_DLL Function ComposeFunctions(Function func_a, Function func_b);
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/tests/python/relax/test_transform_lift_transform_params.py b/tests/python/relax/test_transform_lift_transform_params.py
index 508664f1ef54..90f2050f7898 100644
--- a/tests/python/relax/test_transform_lift_transform_params.py
+++ b/tests/python/relax/test_transform_lift_transform_params.py
@@ -112,7 +112,7 @@ def transform_layout_IOHW_to_OIHW(
         def main_transform_params(
             params: R.Tuple(
                 R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32")
-            )
+            ),
         ) -> R.Tuple(
             R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32")
         ):
@@ -185,7 +185,7 @@ def transform_layout_IOHW_to_OIHW(
         def main_transform_params(
             params: R.Tuple(
                 R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32")
-            )
+            ),
         ) -> R.Tuple(
             R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32")
         ):
@@ -290,18 +290,15 @@ def main(
 
         @R.function
         def main_transform_params(
-            params: R.Tuple(R.Tensor((16, 16, 3, 3), dtype="float32"))
+            params: R.Tuple(R.Tensor((16, 16, 3, 3), dtype="float32")),
         ) -> R.Tuple(
             R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32")
         ):
             R.func_attr({"num_input": 0})
             with R.dataflow():
-                lv = params[0]
-                lv0 = (lv,)
-                lv1 = (lv0,)
-                lv2 = params[0]
-                lv3 = params[0]
-                gv = (lv2, lv3)
+                l3 = params[0]
+                w1 = params[0]
+                gv = (w1, l3)
                 R.output(gv)
             return gv
 
@@ -340,24 +337,14 @@ def main_transform_params(
                 R.Tensor((16, 16, 3, 3), dtype="float32"),
                 R.Tensor((16, 16, 3, 3), dtype="float32"),
                 R.Tensor((), dtype="bool"),
-            )
+            ),
         ) -> R.Tuple(
             R.Tensor((16, 16, 3, 3), dtype="float32"),
             R.Tensor((16, 16, 3, 3), dtype="float32"),
             R.Tensor((), dtype="bool"),
         ):
             R.func_attr({"num_input": 0})
-            with R.dataflow():
-                lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[0]
-                lv1: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
-                lv2: R.Tensor((), dtype="bool") = params[2]
-                gv: R.Tuple(
-                    R.Tensor((16, 16, 3, 3), dtype="float32"),
-                    R.Tensor((16, 16, 3, 3), dtype="float32"),
-                    R.Tensor((), dtype="bool"),
-                ) = (lv, lv1, lv2)
-                R.output(gv)
-            return gv
+            return params
 
         @R.function
         def main(
@@ -434,7 +421,7 @@ def func1(
 
         @R.function
         def func1_transform_params(
-            params: R.Tuple(R.Tensor((256, 256), dtype="float32"))
+            params: R.Tuple(R.Tensor((256, 256), dtype="float32")),
         ) -> R.Tuple(R.Tensor((256, 256), dtype="float32")):
             R.func_attr({"num_input": 0})
             with R.dataflow():
@@ -457,7 +444,7 @@ def func2(
 
         @R.function
         def func2_transform_params(
-            params: R.Tuple(R.Tensor((128, 256), dtype="float32"))
+            params: R.Tuple(R.Tensor((128, 256), dtype="float32")),
         ) -> R.Tuple(R.Tensor((256, 128), dtype="float32")):
             R.func_attr({"num_input": 0})
             with R.dataflow():
@@ -531,7 +518,7 @@ def transform_params(
             params: R.Tuple(
                 R.Tensor((256, 256), dtype="float32"),
                 R.Tensor((256, 256), dtype="float32"),
-            )
+            ),
         ):
             R.func_attr({"num_input": 0})
             with R.dataflow():
@@ -769,7 +756,7 @@ def transform_params(
             params: R.Tuple(
                 R.Tensor((256, 256), dtype="float32"),
                 R.Tensor((256, 256), dtype="float32"),
-            )
+            ),
         ):
             R.func_attr({"num_input": 0})
             with R.dataflow():
@@ -884,7 +871,7 @@ def transform_params(
             params: R.Tuple(
                 R.Tensor((256, 256), dtype="float32"),
                 R.Tensor((256, 256), dtype="float32"),
-            )
+            ),
         ):
             R.func_attr({"num_input": 0})
             with R.dataflow():
@@ -979,7 +966,7 @@ def transform_params(
             params: R.Tuple(
                 R.Tensor((256, 256), dtype="float32"),
dtype="float32"), R.Tensor((256, 256), dtype="float32"), - ) + ), ): R.func_attr({"num_input": 0}) with R.dataflow(): @@ -1103,7 +1090,7 @@ def transform_params( params: R.Tuple( R.Tensor((256, 256), dtype="float32"), R.Tensor((256, 256), dtype="float32"), - ) + ), ): R.func_attr({"num_input": 0}) with R.dataflow(): @@ -1226,7 +1213,7 @@ def transform_params( params: R.Tuple( R.Tensor((256, 256), dtype="float32"), R.Tensor((256, 256), dtype="float32"), - ) + ), ): R.func_attr({"num_input": 0}) with R.dataflow(): @@ -1322,7 +1309,7 @@ def transform_params( params: R.Tuple( R.Tensor((256, 256), dtype="float32"), R.Tensor((256, 256), dtype="float32"), - ) + ), ): R.func_attr({"num_input": 0}) with R.dataflow(): @@ -1395,7 +1382,7 @@ def func1( @R.function def func1_transform_params( - params: R.Tuple(R.Tensor((256, 256), dtype="float32")) + params: R.Tuple(R.Tensor((256, 256), dtype="float32")), ) -> R.Tuple(R.Tensor((256, 256), dtype="float32")): R.func_attr({"num_input": 0}) with R.dataflow(): @@ -1426,9 +1413,6 @@ class Expected: @R.function def main_transform_params(params: R.Tuple) -> R.Tuple: R.func_attr({"num_input": 0}) - with R.dataflow(): - gv: R.Tuple = R.tuple() - R.output() # All instance of the empty tuple are normalized to be # in-line. return R.tuple() @@ -1492,9 +1476,6 @@ def zeros(var_T_full: T.handle): @R.function def main_transform_params(params: R.Tuple) -> R.Tuple: R.func_attr({"num_input": 0}) - with R.dataflow(): - gv: R.Tuple = R.tuple() - R.output() return R.tuple() @R.function @@ -1579,7 +1560,7 @@ def main( @R.function def main_transform_params( - params: R.Tuple(R.Tensor([16, 16], "int32"), R.Shape(["slice_index"])) + params: R.Tuple(R.Tensor([16, 16], "int32"), R.Shape(["slice_index"])), ): R.func_attr({"num_input": 0}) slice_index = T.int64() @@ -1643,7 +1624,7 @@ def main_transform_params( params: R.Tuple( R.Tensor((16, "m", 3, 3), dtype="float32"), R.Tensor((16, "m", 3, 3), dtype="float32"), - ) + ), ) -> R.Tuple( R.Tensor((16, "m", 3, 3), dtype="float32"), R.Tensor((16, "m", 3, 3), dtype="float32") ): @@ -1821,5 +1802,75 @@ def main_transform_params(params: R.Tuple([R.Tensor([16], "int32")])): tvm.ir.assert_structural_equal(after, Expected) +@pytest.mark.parametrize("shared_transform", [True, False]) +def test_lift_transform_is_idempotent(shared_transform): + """Multiple applicates of LiftTransformParams are allowed""" + + @I.ir_module + class Module: + @R.function + def main( + state: R.Tensor(["batch_size", 4096], "float16"), + base_weights: R.Tensor([4096, 4096], "float16"), + lora_A: R.Tensor([4096, "lora_rank"], "float16"), + lora_B: R.Tensor(["lora_rank", 4096], "float16"), + ): + R.func_attr({"num_input": 1}) + folded_weights = base_weights + R.matmul(lora_A, lora_B) + output = R.matmul(state, folded_weights) + return output + + transform = relax.transform.LiftTransformParams(shared_transform=shared_transform) + + AfterOneRound = transform(Module) + assert len(AfterOneRound.functions) == 2 + + AfterTwoRounds = transform(AfterOneRound) + assert len(AfterTwoRounds.functions) == 2 + + tvm.ir.assert_structural_equal(AfterOneRound, AfterTwoRounds) + + +def test_lift_transform_when_one_already_exists(): + """If the module already contains `transform_params`, the + functions are composed together""" + + @I.ir_module + class Module: + @R.function + def main( + state: R.Tensor(["batch_size", 4096], "float16"), + base_weights: R.Tensor([4096, 4096], "float16"), + lora_A: R.Tensor([4096, "lora_rank"], "float16"), + lora_B: R.Tensor(["lora_rank", 4096], 
"float16"), + ): + R.func_attr({"num_input": 1}) + folded_weights = base_weights + R.matmul(lora_A, lora_B) + output = R.matmul(state, folded_weights) + return output + + @R.function + def main_transform_params( + model_params: R.Tuple( + R.Tensor([4096, 4096], "float16"), + R.Tensor([4096, "lora_rank"], "float16"), + R.Tensor(["lora_rank", 4096], "float16"), + ), + ): + R.func_attr({"num_input": 0}) + return model_params + + transform = relax.transform.LiftTransformParams(shared_transform=False) + after_lift_with_previous_identity_function = transform(Module) + + del Module["main_transform_params"] + after_lift_without_previous_identity_function = transform(Module) + + tvm.ir.assert_structural_equal( + after_lift_without_previous_identity_function, + after_lift_with_previous_identity_function, + ) + + if __name__ == "__main__": tvm.testing.main()