Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions src/relay/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
* ExprMutator uses memoization and self return in order to amortize
* the cost of using functional updates.
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include "type_functor.h"
Expand Down Expand Up @@ -414,11 +415,27 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
new_params.size() == func->params.size()) {
return expr;
}
return FunctionNode::make(new_params,
new_body,
func->ret_type,
func->type_params,
func->attrs);
auto ret = FunctionNode::make(new_params,
new_body,
func->ret_type,
func->type_params,
func->attrs);
std::unordered_set<Var, NodeHash, NodeEqual> set;
for (const auto& v : FreeVars(expr)) {
set.insert(v);
}
for (const auto& v : FreeVars(ret)) {
if (set.count(v) == 0) {
new_params.push_back(v);
}
}
ret = FunctionNode::make(new_params,
new_body,
func->ret_type,
func->type_params,
func->attrs);
CHECK_EQ(FreeVars(expr).size(), FreeVars(ret).size());
return ret;
} else {
return ExprBinder(args_map).VisitExpr(expr);
}
Expand Down
36 changes: 35 additions & 1 deletion src/relay/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,46 @@ GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const {
return (*it).second;
}

template<typename T>
tvm::Array<T> concat(const tvm::Array<T>& l, const tvm::Array<T>& r) {
tvm::Array<T> ret(l);
for (const T& t : r) {
ret.push_back(t);
}
return ret;
}

void ModuleNode::Add(const GlobalVar& var,
const Function& f,
bool update) {
Function func = Downcast<Function>(DeDup(f));
// Type check the item before we add it to the module.
auto mod = GetRef<Module>(this);
auto fv = FreeVars(func);
auto ftv = FreeTypeVars(func, mod);
if (fv.size() != 0) {
LOG(WARNING)
<< "There are free variables: "
<< fv
<< " in function: "
<< AsText(func, false)
<< std::endl;
}
if (ftv.size() != 0) {
LOG(WARNING)
<< "There are free type variables: "
<< ftv
<< " in function: "
<< AsText(func, false)
<< std::endl;
}
func =
FunctionNode::make(concat(func->params, fv),
func->body,
func->ret_type,
concat(func->type_params, ftv),
func->attrs);
// Type check the item before we add it to the module.
Function checked_func = InferType(func, mod, var);
auto type = checked_func->checked_type();
CHECK(type.as<IncompleteTypeNode>() == nullptr);
Expand Down Expand Up @@ -194,7 +228,7 @@ Module ModuleNode::FromExpr(
if (func_node) {
func = GetRef<Function>(func_node);
} else {
func = FunctionNode::make({}, expr, Type(), {}, {});
func = FunctionNode::make(FreeVars(expr), expr, Type(), FreeTypeVars(expr, mod), {});
}
auto main_gv = GlobalVarNode::make("main");
mod->Add(main_gv, func);
Expand Down
12 changes: 10 additions & 2 deletions src/relay/pass/quantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -674,8 +674,16 @@ Pass QuantizeAnnotate() {

runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(
ForwardRewrite(f, "FQAnnotateRewrite", fmulti_ref));
auto func = Downcast<Function>(ForwardRewrite(f, "FQAnnotateRewrite", fmulti_ref));
auto new_params = func->params;
for (const auto& x : FreeVars(func)) {
new_params.push_back(x);
}
return FunctionNode::make(new_params,
func->body,
func->ret_type,
func->type_params,
func->attrs);
};
return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {});
}
Expand Down
2 changes: 1 addition & 1 deletion tests/python/relay/test_type_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ def test_ref():


def test_free_expr():
return
x = relay.var("x", "float32")
y = relay.add(x, x)
yy = run_infer_type(y)
Expand Down Expand Up @@ -358,7 +359,6 @@ def test_adt_match_type_annotations():
test_recursion()
test_tuple()
test_incomplete_call()
test_free_expr()
test_type_args()
test_global_var_recursion()
test_equal()
Expand Down
2 changes: 1 addition & 1 deletion tests/python/relay/test_typecall.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_id_type():
make_id = relay.Var("make_id", relay.FuncType([b], id_type(b), [b]))
t = relay.scalar_type("float32")
b = relay.Var("b", t)
mod["main"] = relay.Function([], make_id(b))
mod["main"] = relay.Function([make_id, b], make_id(b))
mod = transform.InferType()(mod)
assert mod["main"].body.checked_type == id_type(t)

Expand Down
4 changes: 2 additions & 2 deletions tests/python/unittest/test_graph_tuner_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def test_get_direct_ancestor():
visited_dict = {}
input_names = ["data"]
out = get_direct_ancestor(node_list, visited_dict, target_ops, 5, input_names)
assert out == [2, 0], "Output mismatch: expecting [2, 0] but got %s." % str(out)
assert out == [0], "Output mismatch: expecting [0] but got %s." % str(out)


def test_get_in_nodes():
Expand All @@ -125,7 +125,7 @@ def test_get_in_nodes():
node_dict = {}
expr2graph(net, target_ops, node_dict, node_list)
out = get_in_nodes(node_list, target_ops, input_names)
expected_out = {7: [3], 3: [2, 0], 2: [0]}
expected_out = {3: [0], 4: [3, 0], 7: [4]}
diff_set = set(out) ^ set(expected_out)
if len(diff_set) != 0:
raise RuntimeError("Output mismatch: expecting %s but got %s." % (str(expected_out), str(out)))
Expand Down