Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions src/relax/transform/lift_transform_params.cc
Original file line number Diff line number Diff line change
Expand Up @@ -705,8 +705,28 @@ std::vector<std::pair<GlobalVar, Function>> GetTargetFunctions(
std::vector<std::pair<GlobalVar, Function>> target_functions;
if (shared_transform.as<Array<String>>().value_or(Array<String>{}).size()) {
for (const auto& name : shared_transform.as<Array<String>>().value()) {
auto gvar = mod->GetGlobalVar(name);
target_functions.push_back({gvar, Downcast<Function>(mod->Lookup(gvar))});
auto gvar = mod->global_var_map_.Get(name);
CHECK(gvar) << "When LiftTransformParams is called with a list of function names, "
<< "all function names must occur within the IRModule. "
<< "However, the IRModule does not contain a function names '" << name << "'";

auto base_func = mod->functions.Get(gvar.value());
ICHECK(base_func) << "Ill-formed IRModule. "
<< "The map from name to GlobalVar found " << gvar.value()
<< " for the function name '" << name
<< "', but this GlobalVar does not appear in the IRModule";

auto func = base_func.as<Function>();
CHECK(func) << "When LiftTransformParams is called with a list of function names, "
<< "only functions in the list must be relax functions. "
<< "However, the function " << name << " is of type " << base_func->GetTypeKey();
CHECK(func.value()->GetAttr<Integer>(attr::kNumInput))
<< "When LiftTransformParams is called with a list of function names, "
<< "all functions in the list must have the kNumInput ('" << attr::kNumInput
<< "') attribute. "
<< "However, the function " << name << " does not have the kNumInput attribute";

target_functions.push_back({gvar.value(), func.value()});
}
} else {
// Get all the functions that have the `num_input` attribute.
Expand Down