Commit d8678a6

[TIR] CSE pass : Restrict the equivalence to be decided by a normal form - avoids comparison of terms (#11574)
The CSE pass was designed to allow comparing (and commoning) equivalent terms, such as (x+y)+z and x+(y+z), where **the notion of equivalence was customizable, and no assumption was made about it**. This means that the equivalence test function `EquivalentTerms()`, which at the time simply called the syntactic equality test `EqualTerms()`, could be replaced later by a cleverer equality test.

However, having such a generic way of comparing elements meant that in the function `SyntacticToSemanticComputations()`, where we go from a hashtable of syntactic entities to what I called a vector of "semantic entities" (which are just canonical forms, i.e. representatives of equivalence classes of terms), **the only option was to compare each pair**. This made the function quadratic, and there was no way around it: in order to merge equivalent entities into their equivalence class, we had to compare them.

**This PR essentially does the following:**

- When computing the equivalence classes of terms (that is, when transforming a ComputationTable, i.e. a hashtable, into a vector of equivalence classes): **instead of comparing each pair of terms, it relies on a normalization procedure to obtain a normal form for each of them**. This turns a small part of the algorithm from quadratic into n log n. However, improvements are difficult to see in practice, in particular for average-sized programs, because that part went from a "small" quadratic to a "big" n log n (finding things in a hashtable, copying it to a vector, etc.). It probably went from a complexity of ~O(((n²-n)/2) + n log n) to a complexity of ~O(3n + n log n), so gains should only be expected for very large programs.
- Gives the user full control to turn the semantic comparison of terms ON or OFF. It is OFF by default (as compilation takes noticeably longer with it ON, unsurprisingly), which means that by default the equivalence coincides with the (syntactic) equality of terms. As the pass was written with the possibility of doing these additional commonings (like (x+y)+z and x+(y+z)), it was a good time to plug that in completely, all the way up to the Python user, who can now turn it ON if they want to. But again, it is OFF by default, so there is no real change there. To run with it ON, simply use `with tvm.transform.PassContext(config={'tir.enable_equiv_terms_in_cse_tir':True}):` before calling `build()`.
- When this boolean is set to ON, it uses a simple implementation of the normalization function with equivalences that relies on `arith::Analyzer::Simplify`, as noted in #10544. Note that this is not a real normalization procedure, as it is incomplete (i.e. it is not guaranteed to converge to the normal form), but it is correct, and it works well with most properties: associativity of +, distributivity of * over +, etc.
- Clarifies and enhances the test base for the pass. In particular, it adds the tests that were written in #10544 but did not make it through.
- Also adds the test ( https://github.com/AndrewZhaoLuo/TVM-Sandbox/blob/19284ddbd6bb28af61c0c2aa8bb334c5c53731a7/tir/test_inconsistent_tir_lowering.py#L1 ) demonstrating the (older) non-deterministic lowering and turns it into a proper test, as I found it useful for making sure that this does not happen again. It has been copied from #10663 and only slightly adapted (in particular to compare the hashes automatically instead of printing them and relying on a human to compare them).
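To illustrate the first point, here is a minimal sketch in plain Python (not the TVM implementation) of merging terms by normal form instead of comparing each pair. Everything in it is an assumption made for illustration: terms are nested `("+", lhs, rhs)` tuples, `normalize` is a toy normal form that flattens and sorts the leaves of a sum (so it also treats + as commutative), and `syntactic_to_semantic` only borrows its name from `SyntacticToSemanticComputations()`. It uses a hash map keyed by the normal form, which is one possible way to avoid the pairwise comparisons.

```python
from collections import OrderedDict


def flatten_sum(term):
    """Collect the leaves of a nested ("+", lhs, rhs) tree, left to right."""
    if isinstance(term, tuple) and term[0] == "+":
        return flatten_sum(term[1]) + flatten_sum(term[2])
    return [term]


def normalize(term):
    """Toy normal form (an assumption): the sorted tuple of the leaves of a sum."""
    return tuple(sorted(flatten_sum(term)))


def syntactic_to_semantic(table, identify_equiv_terms=False):
    """Merge the counts of equivalent terms without comparing pairs of terms.

    `table` maps a syntactic term to its number of occurrences. With the flag
    off, every term is its own class. With the flag on, terms that share a
    normal form are merged, and the first term seen represents the class.
    """
    classes = OrderedDict()  # key -> [representative, total count]
    for term, count in table.items():
        key = normalize(term) if identify_equiv_terms else term
        if key in classes:
            classes[key][1] += count
        else:
            classes[key] = [term, count]
    return [(rep, count) for rep, count in classes.values()]


table = {
    ("+", ("+", "x", "y"), "z"): 2,  # (x+y)+z
    ("+", "x", ("+", "y", "z")): 1,  # x+(y+z)
    "w": 3,
}
merged = syntactic_to_semantic(table, identify_equiv_terms=True)
unmerged = syntactic_to_semantic(table, identify_equiv_terms=False)
```

With the flag on, the two sums end up in a single class with a combined count; with the flag off, the three entries stay separate, which matches the default behavior described above.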
1 parent 236eea0 commit d8678a6

8 files changed: +409 -123 lines changed

include/tvm/tir/transform.h

Lines changed: 2 additions & 1 deletion
@@ -470,9 +470,10 @@ TVM_DLL Pass LowerVtcmAlloc();
  * \brief Implements a Common Subexpression Elimination (CSE) for TIR
  *        which introduces let-in bindings for duplicated sub-expressions.
  * \param enable_cse_tir Whether common subexpression elimination is enabled.
+ * \param identify_equiv_terms Whether equivalent terms should be identified.
  * \return The pass.
  */
-TVM_DLL Pass CommonSubexprElimTIR(bool enable_cse_tir = true);
+TVM_DLL Pass CommonSubexprElimTIR(bool enable_cse_tir = true, bool identify_equiv_terms = false);

 /*!
  * \brief Unify all the thread bindings for "blockIdx.x/y/z", "threadIdx.x/y/z", and

python/tvm/tir/transform/transform.py

Lines changed: 2 additions & 2 deletions
@@ -324,15 +324,15 @@ def BF16TypeLowering():
     return _ffi_api.BF16TypeLowering()  # type: ignore


-def CommonSubexprElimTIR(enable_cse_tir: bool = True):
+def CommonSubexprElimTIR(enable_cse_tir: bool = True, identify_equiv_terms: bool = False):
     """Replace redundant computations by new variables.

     Returns
     -------
     fpass : tvm.transform.Pass
         The result pass
     """
-    return _ffi_api.CommonSubexprElimTIR(enable_cse_tir)  # type: ignore
+    return _ffi_api.CommonSubexprElimTIR(enable_cse_tir, identify_equiv_terms)  # type: ignore


 def RewriteUnsafeSelect():

src/driver/driver_api.cc

Lines changed: 5 additions & 1 deletion
@@ -45,6 +45,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_cse_tir", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_equiv_terms_in_cse_tir", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
@@ -198,6 +199,8 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
   bool instrument_bound_checkers =
       pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
   bool disable_cse_tir = pass_ctx->GetConfig<Bool>("tir.disable_cse_tir", Bool(false)).value();
+  bool enable_equiv_terms_in_cse_tir =
+      pass_ctx->GetConfig<Bool>("tir.enable_equiv_terms_in_cse_tir", Bool(false)).value();

   // Get any user-added passes
   Array<Array<ObjectRef>> add_lower_pass =
@@ -289,7 +292,8 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
     pass_list.push_back(tir::transform::InstrumentBoundCheckers());
   }

-  pass_list.push_back(tir::transform::CommonSubexprElimTIR(!disable_cse_tir));
+  pass_list.push_back(
+      tir::transform::CommonSubexprElimTIR(!disable_cse_tir, enable_equiv_terms_in_cse_tir));

   return pass_list;
 }
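The way the two options combine can be summarized with a small sketch in plain Python. It assumes the pass configuration is available as an ordinary dict, and `resolve_cse_arguments` is a hypothetical helper that is not part of TVM; only the two option names and their `false` defaults come from the diff above.

```python
def resolve_cse_arguments(config):
    """Return (enable_cse_tir, identify_equiv_terms) from a dict of options.

    Mirrors the lookups in CreatePassList: both options default to False, and
    the first argument handed to the pass is the negation of the "disable" flag.
    """
    disable_cse_tir = bool(config.get("tir.disable_cse_tir", False))
    enable_equiv_terms = bool(config.get("tir.enable_equiv_terms_in_cse_tir", False))
    return (not disable_cse_tir, enable_equiv_terms)


default_args = resolve_cse_arguments({})
opted_in = resolve_cse_arguments({"tir.enable_equiv_terms_in_cse_tir": True})
cse_disabled = resolve_cse_arguments({"tir.disable_cse_tir": True})
```

With an empty configuration the pass runs with syntactic equality only, so existing builds keep their behavior unless the new option is set explicitly.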

src/tir/transforms/common_subexpr_elim.cc

Lines changed: 70 additions & 26 deletions
@@ -60,7 +60,7 @@ namespace tir {
           to collect them for the CSE pass, but we also won't even want to collect computations
           that contain them.
           The reason is that reusing such computations would change the semantics of the program,
-          and therefore before doing any introduction of variable or any reuse of already introduced
+          and therefore before doing any introduction of var or any reuse of already introduced
           variables, we will make sure that the computation being considered is not forbidden, and
           that it does not even contain a forbidden computation.
  * \param expr The expression to check
@@ -120,6 +120,42 @@ bool CommonSubexpressionEliminator::CanContainEligibleComputations(const PrimExp
   return true;
 }

+/*!
+ * \brief Implements an order on pairs (expression,frequency). First attempts to compare them
+          using the size of the expression. If it is the same, decides something else still
+          deterministic.
+ * \param a The first pair
+ * \param b The second pair
+ * \return A boolean telling if the first pair `a` comes before the second pair `b`
+ * \note We need this order to be deterministic in order to have a fully deterministic pass,
+ *       as we will deal with elements that are coming from a hashtable, but the order in which
+ *       they appeared in the hashtable was based on some runtime addresses, so it can potentially
+ *       change with every execution.
+ */
+bool CommonSubexpressionEliminator::OrderOnExprAndFrequency(std::pair<PrimExpr, size_t> a,
+                                                            std::pair<PrimExpr, size_t> b) {
+  size_t a_size = CalculateExprComplexity(a.first);
+  size_t b_size = CalculateExprComplexity(b.first);
+
+  // Criteria 1 - Size of the expression comes first
+  // `a` comes before `b` if the size of `a` is bigger
+  if (a_size > b_size) {
+    return true;
+  }
+  // `a` does NOT come before `b` if the size of `b` is bigger
+  if (b_size > a_size) {
+    return false;
+  }
+
+  // Criteria 2 - If they had the same size, use the lexicographic order as a last resort
+  // as we need a deterministic order
+  std::stringstream a_stream;
+  std::stringstream b_stream;
+  a_stream << a.first;
+  b_stream << b.first;
+  return (a_stream.str().compare(b_stream.str()) < 0);
+}
+
 /*!
  * \brief Generates a new fresh variable, whose name will be cse_var_i.
  * \param type_annotation The type of the new variable to generate
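The two-criteria order added in this hunk can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the pass itself: expressions are toy `(op, lhs, rhs)` tuples, and `expr_complexity` and `expr_to_text` are hypothetical stand-ins for `CalculateExprComplexity` and for printing a `PrimExpr` to a string stream.

```python
from functools import cmp_to_key


def expr_complexity(expr):
    """Hypothetical stand-in for CalculateExprComplexity: count the nodes."""
    if isinstance(expr, tuple):
        return 1 + sum(expr_complexity(child) for child in expr[1:])
    return 1


def expr_to_text(expr):
    """Hypothetical stand-in for streaming the expression into a string."""
    if isinstance(expr, tuple):
        op, lhs, rhs = expr
        return "(" + expr_to_text(lhs) + " " + op + " " + expr_to_text(rhs) + ")"
    return str(expr)


def order_on_expr_and_frequency(a, b):
    """Three-way comparison of (expression, frequency) pairs."""
    a_size, b_size = expr_complexity(a[0]), expr_complexity(b[0])
    # Criterion 1: the bigger expression comes first
    if a_size != b_size:
        return -1 if a_size > b_size else 1
    # Criterion 2: same size, fall back to the lexicographic order of the text
    a_text, b_text = expr_to_text(a[0]), expr_to_text(b[0])
    return (a_text > b_text) - (a_text < b_text)


pairs = [("y", 4), (("+", "x", "y"), 2), ("x", 1), (("*", "x", "y"), 2)]
ordered = sorted(pairs, key=cmp_to_key(order_on_expr_and_frequency))
ordered_from_reversed = sorted(reversed(pairs), key=cmp_to_key(order_on_expr_and_frequency))
```

Because ties on size are broken by the printed text rather than by the input position, the result does not depend on the order in which the pairs came out of the hashtable, which is the property the `\note` above asks for.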
@@ -166,10 +202,12 @@ int CommonSubexpressionEliminator::GetNbVarGenerated() { return nb_var_; }
           of the function being analyzed
  * \return A new statement where CSE has been performed
  */
-Stmt CommonSubexpressionEliminator::PerformCSE(const Stmt& stmt, const Context& context_init) {
+Stmt CommonSubexpressionEliminator::PerformCSE(const Stmt& stmt, const Context& context_init,
+                                               bool identify_equiv_terms) {
   // As this function is being called for each PrimFunc definition, we create a new instance
   // for the one we are having now.
-  CommonSubexpressionEliminator common_subexpression_eliminator(stmt, context_init);
+  CommonSubexpressionEliminator common_subexpression_eliminator(stmt, context_init,
+                                                                identify_equiv_terms);
   return common_subexpression_eliminator.VisitStmt(stmt);
 }

@@ -179,8 +217,9 @@ Stmt CommonSubexpressionEliminator::PerformCSE(const Stmt& stmt, const Context&
           formal parameters of the function that will be analyzed
  */
 CommonSubexpressionEliminator::CommonSubexpressionEliminator(const Stmt& stmt,
-                                                             const Context& context_init)
-    : initial_body_(stmt), context_(context_init) {}
+                                                             const Context& context_init,
+                                                             bool identify_equiv_terms)
+    : initial_body_(stmt), context_(context_init), identify_equiv_terms_(identify_equiv_terms) {}

 /*!
  * \brief The method which overrides the generic dispatcher of StmtExprMutator.
@@ -200,39 +239,40 @@ PrimExpr CommonSubexpressionEliminator::VisitExpr(const PrimExpr& expr) {
   // Transform the hashtable of *syntactic* eligible computations into a vector of pairs
   // containing *semantic* entities, i.e. where equivalent computations are merged.
   std::vector<std::pair<PrimExpr, size_t>> semantic_comp_done_by_expr =
-      SyntacticToSemanticComputations(table_syntactic_comp_done_by_expr);
+      SyntacticToSemanticComputations(table_syntactic_comp_done_by_expr, identify_equiv_terms_);

   // Sort the vector of semantic entities by decreasing size
   std::sort(semantic_comp_done_by_expr.begin(), semantic_comp_done_by_expr.end(),
-            [](std::pair<PrimExpr, size_t> a, std::pair<PrimExpr, size_t> b) {
-              return (CalculateExprComplexity(a.first) > CalculateExprComplexity(b.first));
-            });
+            OrderOnExprAndFrequency);

   // For each computation done (considering them from biggest to smallest)
   for (size_t i = 0; i < semantic_comp_done_by_expr.size(); i++) {
     std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_expr[i];

+    bool ident_equiv_terms = identify_equiv_terms_;  // To avoid the capture of "this"
+
     // The predicate later used (when doing replacements) to select expressions that are
     // equivalent to the current computation (`computation_and_nb.first`)
     std::function<bool(const PrimExpr&)> predicate_selector =
-        [computation_and_nb](const PrimExpr& current_expr) {
+        [computation_and_nb, ident_equiv_terms](const PrimExpr& current_expr) {
           // `current_expr` should be equivalent to `computation_and_nb.first`, but we also check
           // that `current_expr` is an eligible computation even if we know that
           // `computation_and_nb.first` is eligible by construction, in case that one day the
           // equivalence relation would not preserve the eligibility any more (even though that
           // would probably be a very weird equivalence).
-          return (EquivalentTerms(current_expr, computation_and_nb.first) &&
+          return (EquivalentTerms(current_expr, computation_and_nb.first, ident_equiv_terms) &&
                   IsEligibleComputation(current_expr));
         };

     // See if there is a pair (`var`, `value`) in the context where `value` is semantically
     // equivalent to `computation_and_nb.first`
     auto it_on_var = std::find_if(
         context_.begin(), context_.end(),
-        [computation_and_nb](const std::pair<Var, MaybeValue>& var_and_value) {
+        [computation_and_nb, ident_equiv_terms](const std::pair<Var, MaybeValue>& var_and_value) {
           // Note : safe to call value() as we check has_value() just before
           return (var_and_value.second.has_value() &&
-                  EquivalentTerms(var_and_value.second.value(), computation_and_nb.first));
+                  EquivalentTerms(var_and_value.second.value(), computation_and_nb.first,
+                                  ident_equiv_terms));
         });

     // Case where we have a perfectly equivalent computation already available in a variable
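The context lookup in this hunk can be sketched in plain Python as follows. This is only a toy model: the context is a list of `(var, value)` pairs where `None` plays the role of an empty `MaybeValue`, terms are nested `("+", lhs, rhs)` tuples, and `toy_normalize` is an assumed normal form for sums that stands in for the real normalization. The name `equivalent_terms` echoes `EquivalentTerms()` but the body is hypothetical.

```python
def sum_leaves(term):
    """Collect the leaves of a nested ("+", lhs, rhs) tree, left to right."""
    if isinstance(term, tuple) and term[0] == "+":
        return sum_leaves(term[1]) + sum_leaves(term[2])
    return [term]


def toy_normalize(term):
    """Assumed normal form for sums: the sorted tuple of their leaves."""
    return tuple(sorted(sum_leaves(term)))


def equivalent_terms(a, b, identify_equiv_terms):
    """Syntactic equality when the flag is off, equality of normal forms when on."""
    if identify_equiv_terms:
        return toy_normalize(a) == toy_normalize(b)
    return a == b


def find_var_for(context, computation, identify_equiv_terms):
    """Return the first variable whose known value is equivalent to `computation`."""
    for var, value in context:
        # Variables without a known value (None here) are skipped before comparing
        if value is not None and equivalent_terms(value, computation, identify_equiv_terms):
            return var
    return None


context = [("a", None), ("cse_var_1", ("+", "x", ("+", "y", "z")))]
computation = ("+", ("+", "x", "y"), "z")
found_with_flag = find_var_for(context, computation, True)
found_without_flag = find_var_for(context, computation, False)
```

With the flag on, the existing variable is reused for the re-associated sum; with the flag off, the lookup finds nothing because the two trees differ syntactically.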
@@ -298,7 +338,8 @@ PrimExpr CommonSubexpressionEliminator::VisitExpr(const PrimExpr& expr) {
         // The following insertion will maintain `semantic_comp_done_by_expr` sorted (by
         // decreasing size/complexity), and it will only insert at locations > i as the
         // direct subexprs are necessarily smaller than the current computation.
-        InsertVectorToSortedSemanticComputations(&semantic_comp_done_by_expr, direct_subexprs);
+        InsertVectorToSortedSemanticComputations(&semantic_comp_done_by_expr, direct_subexprs,
+                                                 identify_equiv_terms_);
       }
     }
     // Note : we do not remove the current element, as we never look back in the local vector
@@ -378,39 +419,40 @@ Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) {
   // Transform the hashtable of *syntactic* eligible computations into a vector of pairs
   // containing *semantic* entities, i.e. where equivalent computations are merged.
   std::vector<std::pair<PrimExpr, size_t>> semantic_comp_done_by_stmt =
-      SyntacticToSemanticComputations(table_syntactic_comp_done_by_stmt);
+      SyntacticToSemanticComputations(table_syntactic_comp_done_by_stmt, identify_equiv_terms_);

   // Sort the vector of semantic entities by decreasing size
   std::sort(semantic_comp_done_by_stmt.begin(), semantic_comp_done_by_stmt.end(),
-            [](std::pair<PrimExpr, size_t> a, std::pair<PrimExpr, size_t> b) {
-              return (CalculateExprComplexity(a.first) > CalculateExprComplexity(b.first));
-            });
+            OrderOnExprAndFrequency);

   // For each computation done (considering them from biggest to smallest)
   for (size_t i = 0; i < semantic_comp_done_by_stmt.size(); i++) {
     std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_stmt[i];

+    bool ident_equiv_terms = identify_equiv_terms_;  // To avoid the capture of "this"
+
     // The predicate later used (when doing replacements) to select expressions that are
     // equivalent to the current computation (`computation_and_nb.first`)
     std::function<bool(const PrimExpr&)> predicate_selector =
-        [computation_and_nb](const PrimExpr& current_expr) {
+        [computation_and_nb, ident_equiv_terms](const PrimExpr& current_expr) {
           // `current_expr` should be equivalent to `computation_and_nb.first`, but we also check
           // that `current_expr` is an eligible computation even if we know that
           // `computation_and_nb.first` is eligible by construction, in case that one day the
           // equivalence relation would not preserve the eligibility any more (even though that
           // would probably be a very weird equivalence).
-          return (EquivalentTerms(current_expr, computation_and_nb.first) &&
+          return (EquivalentTerms(current_expr, computation_and_nb.first, ident_equiv_terms) &&
                   IsEligibleComputation(current_expr));
         };

     // See if there is a pair (`var`, `value`) in the context where `value` is semantically
     // equivalent to `computation_and_nb.first`
     auto it_on_var = std::find_if(
         context_.begin(), context_.end(),
-        [computation_and_nb](const std::pair<Var, MaybeValue>& var_and_value) {
+        [computation_and_nb, ident_equiv_terms](const std::pair<Var, MaybeValue>& var_and_value) {
           // Note : safe to call value() as we check has_value() just before
           return (var_and_value.second.has_value() &&
-                  EquivalentTerms(var_and_value.second.value(), computation_and_nb.first));
+                  EquivalentTerms(var_and_value.second.value(), computation_and_nb.first,
+                                  ident_equiv_terms));
         });

     // Case where we have a perfectly equivalent computation already available in a variable
@@ -477,7 +519,8 @@ Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) {
         // The following insertion will maintain `semantic_comp_done_by_stmt` sorted (by
         // decreasing size/complexity), and it will only insert at locations > i as the
         // direct subexprs are necessarily smaller than the current computation.
-        InsertVectorToSortedSemanticComputations(&semantic_comp_done_by_stmt, direct_subexprs);
+        InsertVectorToSortedSemanticComputations(&semantic_comp_done_by_stmt, direct_subexprs,
+                                                 identify_equiv_terms_);
       }
     }
     // Note : we do not remove the current element, as we never look back in the local vector
@@ -587,8 +630,8 @@ namespace transform {
  * \brief The function which returns the pass for the Common Subexpression Elimination.
  * \return The pass for performing CSE.
  */
-Pass CommonSubexprElimTIR(bool enable_cse_tir) {
-  auto pass_func = [enable_cse_tir](PrimFunc f, IRModule m, PassContext ctx) {
+Pass CommonSubexprElimTIR(bool enable_cse_tir, bool identify_equiv_terms) {
+  auto pass_func = [enable_cse_tir, identify_equiv_terms](PrimFunc f, IRModule m, PassContext ctx) {
     if (enable_cse_tir) {
       auto* n = f.CopyOnWrite();
       Context context_init;
@@ -603,7 +646,8 @@ Pass CommonSubexprElimTIR(bool enable_cse_tir) {

       // Do the Common Subexpression Elimination on the body of the function, with the initial
       // context that we have prepared
-      n->body = CommonSubexpressionEliminator::PerformCSE(std::move(f->body), context_init);
+      n->body = CommonSubexpressionEliminator::PerformCSE(std::move(f->body), context_init,
+                                                          identify_equiv_terms);
     }

     return f;

src/tir/transforms/common_subexpr_elim.h

Lines changed: 6 additions & 2 deletions
@@ -55,7 +55,7 @@ using Context = std::vector<std::pair<Var, MaybeValue>>;
 class CommonSubexpressionEliminator : public StmtExprMutator {
  public:
   // Toplevel (static) function
-  static Stmt PerformCSE(const Stmt& stmt, const Context& context_init);
+  static Stmt PerformCSE(const Stmt& stmt, const Context& context_init, bool identify_equiv_terms);

   PrimExpr VisitExpr(const PrimExpr& expr) override;
   Stmt VisitStmt(const Stmt& stmt) override;
@@ -64,7 +64,8 @@ class CommonSubexpressionEliminator : public StmtExprMutator {

  protected:
   // Constructor
-  CommonSubexpressionEliminator(const Stmt& stmt, const Context& context_init);
+  CommonSubexpressionEliminator(const Stmt& stmt, const Context& context_init,
+                                bool identify_equiv_terms);

   PrimExpr VisitExpr_(const LetNode* op) override;

@@ -77,9 +78,12 @@ class CommonSubexpressionEliminator : public StmtExprMutator {
   int num_last_try_ = 0;  // Number of the last variable tried
   int nb_var_ = 0;        // Number of variables introduced by the CSE pass

+  bool identify_equiv_terms_ = false;
+
   static bool ForbiddenComputation(const PrimExpr& expr);
   static bool IsEligibleComputation(const PrimExpr& expr);
   static bool CanContainEligibleComputations(const PrimExpr& expr);
+  static bool OrderOnExprAndFrequency(std::pair<PrimExpr, size_t> a, std::pair<PrimExpr, size_t> b);
   Var GenerateNewVar(DataType type_annotation);
 };
