Skip to content

Commit 589659b

Browse files
MarisaKirisamewweic
authored andcommitted
save save save upstream lint remove bad changes fix build save save please the ci god Update src/relay/pass/partial_eval.cc Co-Authored-By: Wei Chen <[email protected]> save fix test ci is ANGRY fix rebase problem fix rebase add test save save comment
1 parent c1cebea commit 589659b

File tree

7 files changed

+624
-171
lines changed

7 files changed

+624
-171
lines changed

include/tvm/relay/pass.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -296,13 +296,15 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const Module& mod);
296296
* For example, this pass should turn `let a = 1 in 2` into `2`,
297297
* as the value of the expression does not depend on a.
298298
*
299-
* As another example, `let a = 1 in a` will be optimized into 1.
299+
* As another example, `let a = 1 in a` will be optimized into 1,
300+
* if the flag is turned on.
300301
*
301302
* \param e the expression to optimize.
303+
* \param inline_once whether or not to inline binding used one.
302304
*
303305
* \return the optimized expression.
304306
*/
305-
TVM_DLL Expr DeadCodeElimination(const Expr& e);
307+
TVM_DLL Expr DeadCodeElimination(const Expr& e, bool inline_once = false);
306308

307309
/*!
308310
* \brief Fold constant expressions.
@@ -435,11 +437,12 @@ TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const Module& mod);
435437
* It has two benefit: remove runtime overhead, and allow more optimization (typically fusion).
436438
* As a side effect, code size will explode.
437439
*
438-
* \param e the expression,
440+
* \param e the expression
441+
* \param mod the module
439442
*
440443
* \return the optimized expression.
441444
*/
442-
TVM_DLL Expr PartialEval(const Expr& e);
445+
TVM_DLL Expr PartialEval(const Expr& e, const Module& mod);
443446

444447
/*!
445448
* \brief Bind the free variables to a Relay expression.

include/tvm/relay/transform.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,9 +356,11 @@ TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc<
356356
*
357357
* As another example, `let a = 1 in a` will be optimized into 1.
358358
*
359+
* \param inline_once whether or not to inline binding used one.
360+
*
359361
* \return the pass.
360362
*/
361-
TVM_DLL Pass DeadCodeElimination();
363+
TVM_DLL Pass DeadCodeElimination(bool inline_once = false);
362364

363365
/*!
364366
* \brief Fold constant expressions.

python/tvm/relay/ir_pass.py

Lines changed: 42 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def well_formed(expr):
129129
130130
Parameters
131131
----------
132-
expr: tvm.relay.Expr
132+
expr : tvm.relay.Expr
133133
The input expression
134134
135135
Returns
@@ -175,7 +175,7 @@ def free_vars(expr):
175175
176176
Parameters
177177
----------
178-
expr: tvm.relay.Expr
178+
expr : tvm.relay.Expr
179179
The input expression
180180
181181
Returns
@@ -197,7 +197,7 @@ def bound_vars(expr):
197197
198198
Parameters
199199
----------
200-
expr: tvm.relay.Expr
200+
expr : tvm.relay.Expr
201201
The input expression
202202
203203
Returns
@@ -213,7 +213,7 @@ def all_vars(expr):
213213
214214
Parameters
215215
----------
216-
expr: tvm.relay.Expr
216+
expr : tvm.relay.Expr
217217
The input expression
218218
219219
Returns
@@ -229,9 +229,10 @@ def free_type_vars(expr, mod=None):
229229
230230
Parameters
231231
----------
232-
expr: Union[tvm.relay.Expr,tvm.relay.Type]
232+
expr : Union[tvm.relay.Expr,tvm.relay.Type]
233233
The input expression/type
234-
mod: tvm.relay.Module, optional
234+
235+
mod : Optional[tvm.relay.Module]
235236
The global module
236237
237238
Returns
@@ -248,9 +249,10 @@ def bound_type_vars(expr, mod=None):
248249
249250
Parameters
250251
----------
251-
expr: Union[tvm.relay.Expr,tvm.relay.Type]
252+
expr : Union[tvm.relay.Expr,tvm.relay.Type]
252253
The input expression/type
253-
mod: tvm.relay.Module, optional
254+
255+
mod : Optional[tvm.relay.Module]
254256
The global module
255257
256258
Returns
@@ -267,9 +269,9 @@ def all_type_vars(expr, mod=None):
267269
268270
Parameters
269271
----------
270-
expr: Union[tvm.relay.Expr,tvm.relay.Type]
272+
expr : Union[tvm.relay.Expr,tvm.relay.Type]
271273
The input expression/type
272-
mod: tvm.relay.Module, optional
274+
mod : Optional[tvm.relay.Module]
273275
The global module
274276
275277
Returns
@@ -286,12 +288,12 @@ def simplify_inference(expr):
286288
287289
Parameters
288290
----------
289-
e: tvm.relay.Expr
291+
expr : tvm.relay.Expr
290292
The input Expression
291293
292294
Returns
293295
-------
294-
result: tvm.relay.Expr
296+
result : tvm.relay.Expr
295297
An expression which is semantically equal to the input expression,
296298
but with some simplification
297299
"""
@@ -304,48 +306,50 @@ def canonicalize_ops(expr):
304306
305307
Parameters
306308
----------
307-
e: tvm.relay.Expr
309+
expr : tvm.relay.Expr
308310
The input Expression
309311
310312
Returns
311313
-------
312-
result: tvm.relay.Expr
314+
result : tvm.relay.Expr
313315
An expression without bias_add
314316
"""
315317
return _ir_pass.canonicalize_ops(expr)
316318

317319

318-
def dead_code_elimination(expr):
320+
def dead_code_elimination(expr, inline_once=False):
319321
""" Remove expressions which does not effect the program result (dead code).
320322
321323
Parameters
322324
----------
323-
e: tvm.relay.Expr
325+
expr : tvm.relay.Expr
324326
The input Expression
325327
328+
inline_once : Optional[Bool]
329+
Whether to inline binding that occur only once.
326330
Returns
327331
-------
328-
result: tvm.relay.Expr
332+
result : tvm.relay.Expr
329333
An expression which is semantically equal to the input expression,
330334
but with dead code removed.
331335
"""
332-
return _ir_pass.dead_code_elimination(expr)
336+
return _ir_pass.dead_code_elimination(expr, inline_once)
333337

334338

335339
def alpha_equal(lhs, rhs):
336340
"""Compare two Relay expr for structural equivalence (alpha equivalence).
337341
338342
Parameters
339343
----------
340-
lhs: tvm.relay.Expr
344+
lhs : tvm.relay.Expr
341345
One of the input Expression.
342346
343-
rhs: tvm.relay.Expr
347+
rhs : tvm.relay.Expr
344348
One of the input Expression.
345349
346350
Returns
347351
-------
348-
result: bool
352+
result : bool
349353
True iff lhs is alpha equal to rhs.
350354
"""
351355
return bool(_make._alpha_equal(lhs, rhs))
@@ -359,15 +363,15 @@ def graph_equal(lhs, rhs):
359363
360364
Parameters
361365
----------
362-
lhs: tvm.relay.Expr
366+
lhs : tvm.relay.Expr
363367
One of the input Expression.
364368
365-
rhs: tvm.relay.Expr
369+
rhs : tvm.relay.Expr
366370
One of the input Expression.
367371
368372
Returns
369373
-------
370-
result: bool
374+
result : bool
371375
True iff lhs is data-flow equivalent to rhs.
372376
"""
373377
return bool(_make._graph_equal(lhs, rhs))
@@ -378,12 +382,12 @@ def structural_hash(value):
378382
379383
Parameters
380384
----------
381-
expr: tvm.relay.Expr or tvm.relay.Type
385+
expr : Union[tvm.relay.Expr, tvm.relay.Type]
382386
The expression to hash.
383387
384388
Returns
385389
-------
386-
result: int
390+
result : int
387391
The hash value
388392
"""
389393
if isinstance(value, Expr):
@@ -544,12 +548,12 @@ def to_a_normal_form(expr, mod=None):
544548
expr : tvm.relay.Expr
545549
The input expression.
546550
547-
mod: Optional[tvm.relay.Module]
551+
mod : Optional[tvm.relay.Module]
548552
The global module.
549553
550554
Returns
551555
-------
552-
expr: tvm.relay.Expr
556+
result : tvm.relay.Expr
553557
The output expression.
554558
"""
555559
return _ir_pass.to_a_normal_form(expr, mod)
@@ -563,7 +567,7 @@ def to_graph_normal_form(expr):
563567
The input expression
564568
Returns
565569
-------
566-
expr : tvm.relay.Expr
570+
result : tvm.relay.Expr
567571
The output expression
568572
"""
569573
return _ir_pass.to_graph_normal_form(expr)
@@ -612,7 +616,7 @@ def get_total_mac_number(expr):
612616
613617
Returns
614618
-------
615-
ret : int64
619+
result : int64
616620
The number of MACs (multiply-accumulate) of a model
617621
"""
618622
return _ir_pass.GetTotalMacNumber(expr)
@@ -627,17 +631,17 @@ def eliminate_common_subexpr(expr, fskip=None):
627631
expr : tvm.relay.Expr
628632
The input expression.
629633
630-
fskip: function
634+
fskip : function
631635
The callback function that decides whether an expression should be skipped.
632636
633637
Returns
634638
-------
635-
expr : tvm.relay.Expr
639+
result : tvm.relay.Expr
636640
The output expression.
637641
"""
638642
return _ir_pass.eliminate_common_subexpr(expr, fskip)
639643

640-
def partial_evaluate(expr):
644+
def partial_evaluate(expr, mod=None):
641645
"""
642646
Evaluate the static fragment of the code.
643647
@@ -646,12 +650,15 @@ def partial_evaluate(expr):
646650
expr : tvm.relay.Expr
647651
The input expression.
648652
653+
mod : Optional[tvm.relay.Module]
654+
The global module
655+
649656
Returns
650657
-------
651-
expr : tvm.relay.Expr
658+
result : tvm.relay.Expr
652659
The output expression.
653660
"""
654-
return _ir_pass.partial_evaluate(expr)
661+
return _ir_pass.partial_evaluate(expr, mod)
655662

656663
def unmatched_cases(match, mod=None):
657664
"""

src/relay/ir/expr.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -220,8 +220,8 @@ TVM_REGISTER_API("relay._make.Call")
220220

221221
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
222222
.set_dispatch<CallNode>([](const CallNode* node, tvm::IRPrinter* p) {
223-
p->stream << "CallNode(" << node->op << ", " << node->args << ", "
224-
<< node->attrs << ", " << node->type_args << ")";
223+
p->stream << "CallNode(" << node->op << ", " << node->args << ", "
224+
<< node->attrs << ", " << node->type_args << ")";
225225
});
226226

227227
Let LetNode::make(Var var, Expr value, Expr body) {
@@ -324,7 +324,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
324324

325325
TVM_REGISTER_API("relay._expr.TempExprRealize")
326326
.set_body_typed<Expr(TempExpr)>([](TempExpr temp) {
327-
return temp->Realize();
327+
return temp->Realize();
328328
});
329329

330330
} // namespace relay

src/relay/pass/dead_code.cc

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@ namespace relay {
3838
// calculate the dependency graph from expression
3939
class CalcDep : private ExprVisitor {
4040
public:
41-
static Expr Eliminate(const Expr& e) {
41+
static Expr Eliminate(const Expr& e, bool inline_once) {
4242
CalcDep cd;
4343
cd.Calculate(e);
44-
Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_);
44+
Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_, inline_once);
4545
return el(e);
4646
}
4747

@@ -117,15 +117,23 @@ class CalcDep : private ExprVisitor {
117117
VarMap<Expr> expr_map_;
118118
VarMap<size_t> use_map_;
119119
VarSet letrec_set_;
120+
bool inline_once_;
120121
explicit Eliminator(const VarMap<Expr>& expr_map,
121122
const VarMap<size_t>& use_map,
122-
const VarSet& letrec_set) :
123-
expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set) { }
123+
const VarSet& letrec_set,
124+
bool inline_once) :
125+
expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set), inline_once_(inline_once) { }
124126
friend CalcDep;
125127

126128
bool HasLet(const Var& v) {
127-
// TODO(@jroesch): MK fix me
128-
return (use_map_[v] > 0 || (use_map_[v] != 0 && letrec_set_.count(v) != 0));
129+
switch (use_map_[v]) {
130+
case 0:
131+
return false;
132+
case 1:
133+
return letrec_set_.count(v) > 0 || !inline_once_;
134+
default:
135+
return true;
136+
}
129137
}
130138

131139
Expr VisitExpr_(const VarNode* op) final {
@@ -144,19 +152,19 @@ class CalcDep : private ExprVisitor {
144152
};
145153
};
146154

147-
Expr DeadCodeElimination(const Expr& e) {
148-
return CalcDep::Eliminate(e);
155+
Expr DeadCodeElimination(const Expr& e, bool inline_once) {
156+
return CalcDep::Eliminate(e, inline_once);
149157
}
150158

151159
TVM_REGISTER_API("relay._ir_pass.dead_code_elimination")
152160
.set_body_typed(DeadCodeElimination);
153161

154162
namespace transform {
155163

156-
Pass DeadCodeElimination() {
164+
Pass DeadCodeElimination(bool inline_once) {
157165
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
158166
[=](Function f, Module m, PassContext pc) {
159-
return Downcast<Function>(DeadCodeElimination(f));
167+
return Downcast<Function>(DeadCodeElimination(f, inline_once));
160168
};
161169
return CreateFunctionPass(pass_func, 1, "DeadCodeElimination", {});
162170
}

0 commit comments

Comments
 (0)