Skip to content

Commit 761b764

Browse files
init
lint update address comment
1 parent 882ae12 commit 761b764

File tree

5 files changed

+73
-14
lines changed

5 files changed

+73
-14
lines changed

src/relay/ir/expr_functor.cc

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
* to you under the Apache License, Version 2.0 (the
77
* "License"); you may not use this file except in compliance
88
* with the License. You may obtain a copy of the License at
9-
*
9+
*
1010
* http://www.apache.org/licenses/LICENSE-2.0
11-
*
11+
*
1212
* Unless required by applicable law or agreed to in writing,
1313
* software distributed under the License is distributed on an
1414
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -18,13 +18,14 @@
1818
*/
1919

2020
/*!
21-
* Copyright (c) 2018 by Contributors
21+
* Copyright (c) 2019 by Contributors
2222
* \file src/tvm/relay/expr_mutator.cc
2323
* \brief A wrapper around ExprFunctor which functionally updates the AST.
2424
*
2525
* ExprMutator uses memoization and self return in order to amortize
2626
* the cost of using functional updates.
2727
*/
28+
#include <tvm/relay/analysis.h>
2829
#include <tvm/relay/expr_functor.h>
2930
#include "type_functor.h"
3031

@@ -400,11 +401,27 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
400401
new_params.size() == func->params.size()) {
401402
return expr;
402403
}
403-
return FunctionNode::make(new_params,
404-
new_body,
405-
func->ret_type,
406-
func->type_params,
407-
func->attrs);
404+
auto ret = FunctionNode::make(new_params,
405+
new_body,
406+
func->ret_type,
407+
func->type_params,
408+
func->attrs);
409+
std::unordered_set<Var, NodeHash, NodeEqual> set;
410+
for (const auto& v : FreeVars(expr)) {
411+
set.insert(v);
412+
}
413+
for (const auto& v : FreeVars(ret)) {
414+
if (set.count(v) == 0) {
415+
new_params.push_back(v);
416+
}
417+
}
418+
ret = FunctionNode::make(new_params,
419+
new_body,
420+
func->ret_type,
421+
func->type_params,
422+
func->attrs);
423+
CHECK_EQ(FreeVars(expr).size(), FreeVars(ret).size());
424+
return ret;
408425
} else {
409426
return ExprBinder(args_map).Mutate(expr);
410427
}

src/relay/ir/module.cc

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,46 @@ GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const {
8888
return (*it).second;
8989
}
9090

91+
template<typename T>
92+
tvm::Array<T> concat(const tvm::Array<T>& l, const tvm::Array<T>& r) {
93+
tvm::Array<T> ret(l);
94+
for (const T& t : r) {
95+
ret.push_back(t);
96+
}
97+
return ret;
98+
}
99+
91100
void ModuleNode::Add(const GlobalVar& var,
92101
const Function& f,
93102
bool update) {
94103
Function func = Downcast<Function>(DeDup(f));
95104
// Type check the item before we add it to the module.
96105
auto mod = GetRef<Module>(this);
106+
auto fv = FreeVars(func);
107+
auto ftv = FreeTypeVars(func, mod);
108+
if (fv.size() != 0) {
109+
LOG(WARNING)
110+
<< "There are free variables: "
111+
<< fv
112+
<< " in function: "
113+
<< AsText(func, false)
114+
<< std::endl;
115+
}
116+
if (ftv.size() != 0) {
117+
LOG(WARNING)
118+
<< "There are free type variables: "
119+
<< ftv
120+
<< " in function: "
121+
<< AsText(func, false)
122+
<< std::endl;
123+
}
124+
func =
125+
FunctionNode::make(concat(func->params, fv),
126+
func->body,
127+
func->ret_type,
128+
concat(func->type_params, ftv),
129+
func->attrs);
130+
// Type check the item before we add it to the module.
97131
Function checked_func = InferType(func, mod, var);
98132
auto type = checked_func->checked_type();
99133
CHECK(type.as<IncompleteTypeNode>() == nullptr);
@@ -174,7 +208,7 @@ Module ModuleNode::FromExpr(
174208
if (func_node) {
175209
func = GetRef<Function>(func_node);
176210
} else {
177-
func = FunctionNode::make({}, expr, Type(), {}, {});
211+
func = FunctionNode::make(FreeVars(expr), expr, Type(), FreeTypeVars(expr, mod), {});
178212
}
179213
mod->Add(mod->entry_func, func);
180214
return mod;

src/relay/pass/quantize.cc

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -674,8 +674,16 @@ Pass QuantizeAnnotate() {
674674

675675
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
676676
[=](Function f, Module m, PassContext pc) {
677-
return Downcast<Function>(
678-
ForwardRewrite(f, "FQAnnotateRewrite", fmulti_ref));
677+
auto func = Downcast<Function>(ForwardRewrite(f, "FQAnnotateRewrite", fmulti_ref));
678+
auto new_params = func->params;
679+
for (const auto& x : FreeVars(func)) {
680+
new_params.push_back(x);
681+
}
682+
return FunctionNode::make(new_params,
683+
func->body,
684+
func->ret_type,
685+
func->type_params,
686+
func->attrs);
679687
};
680688
return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {});
681689
}

tests/python/relay/test_typecall.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def test_id_type():
3939
make_id = relay.Var("make_id", relay.FuncType([b], id_type(b), [b]))
4040
t = relay.scalar_type("float32")
4141
b = relay.Var("b", t)
42-
mod[mod.entry_func] = relay.Function([], make_id(b))
42+
mod[mod.entry_func] = relay.Function([make_id, b], make_id(b))
4343
mod = transform.InferType()(mod)
4444
assert mod[mod.entry_func].body.checked_type == id_type(t)
4545

tests/python/unittest/test_graph_tuner_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def test_get_direct_ancestor():
106106
visited_dict = {}
107107
input_names = ["data"]
108108
out = get_direct_ancestor(node_list, visited_dict, target_ops, 5, input_names)
109-
assert out == [2, 0], "Output mismatch: expecting [2, 0] but got %s." % str(out)
109+
assert out == [0], "Output mismatch: expecting [0] but got %s." % str(out)
110110

111111

112112
def test_get_in_nodes():
@@ -125,7 +125,7 @@ def test_get_in_nodes():
125125
node_dict = {}
126126
expr2graph(net, target_ops, node_dict, node_list)
127127
out = get_in_nodes(node_list, target_ops, input_names)
128-
expected_out = {7: [3], 3: [2, 0], 2: [0]}
128+
expected_out = {3: [0], 4: [3, 0], 7: [4]}
129129
diff_set = set(out) ^ set(expected_out)
130130
if len(diff_set) != 0:
131131
raise RuntimeError("Output mismatch: expecting %s but got %s." % (str(expected_out), str(out)))

0 commit comments

Comments
 (0)