Skip to content

Commit 4440228

Browse files
init
quickfix
1 parent 3116eee commit 4440228

File tree

2 files changed

+27
-0
lines changed

2 files changed

+27
-0
lines changed

python/tvm/relay/transform.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,30 @@ def __init__(self,
248248
passes, opt_level, name, required)
249249

250250

251+
def infer_type(expr, mod=None):
252+
"""Infer the type of an expr.
253+
Adding Function into a Module will change it's binding,
254+
and some passes need type inference to work without binding modification.
255+
However, InferType() work by putting stuff into a Module, thus changing all the binding.
256+
257+
This is an escape patch that allow type inference without binding changing.
258+
259+
Parameters
260+
----------
261+
expr : tvm.relay.Expr
262+
The input expression.
263+
264+
mod : Optional[tvm.relay.Module]
265+
The input module
266+
267+
Returns
268+
-------
269+
ret : tvm.relay.Expr
270+
The output expression.
271+
"""
272+
return _transform.infer_type(expr, mod)
273+
274+
251275
def InferType():
252276
"""Infer the type of an expr.
253277

src/relay/pass/type_infer.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -824,6 +824,9 @@ Function InferType(const Function& func,
824824
return Downcast<Function>(func_ret);
825825
}
826826

827+
TVM_REGISTER_API("relay._transform.infer_type")
828+
.set_body_typed<Expr(Expr, Module)>([](Expr l, Module r) { return InferType(l, r); });
829+
827830
namespace transform {
828831

829832
Pass InferType() {

0 commit comments

Comments
 (0)