Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
561e8c4
Restricting the equivalence relation to be deciding by a normalizatio…
FranckQC May 31, 2022
dbb9a03
Perfs tests being done
FranckQC Jun 1, 2022
bcd46b5
Better, without using two type of vectors
FranckQC Jun 2, 2022
b3ab028
Tests for equivalence added, copied from @yuanfz98 in his PR: [TIR] S…
FranckQC Jun 2, 2022
4264e05
Improve comment.
FranckQC Jun 2, 2022
5091d17
Avoid having to do two sorts (one for having a deterministic order, a…
FranckQC Jun 2, 2022
18f1453
Added booleans to turn on/off the identification of terms modulo an e…
FranckQC Jun 2, 2022
5448bf2
Reordering tests
FranckQC Jun 2, 2022
8ce8080
We really need the ordering of the hashtable before iterating through…
FranckQC Jun 2, 2022
8da33b4
We need one sorting before each time we iterate through a hashtable. …
FranckQC Jun 2, 2022
7782c37
Formatting, comments, style, etc
FranckQC Jun 3, 2022
1f80d6f
Add way to activate the equivalence of terms in CSE from the outside …
FranckQC Jun 3, 2022
25a7e7c
Remove printing of perfs and make the second test on determinism a re…
FranckQC Jun 3, 2022
078a914
Fixe formatting issues reported by linter on CI
FranckQC Jun 4, 2022
6fc7c03
Another linter issue fixed
FranckQC Jun 4, 2022
5fc570e
Final linter issues.
FranckQC Jun 4, 2022
6cd97f6
final final linter issues, perhaps.
FranckQC Jun 4, 2022
e12ccb2
is it the end?
FranckQC Jun 4, 2022
b5ca16f
Some more!
FranckQC Jun 4, 2022
77a8da6
One more, again!
FranckQC Jun 4, 2022
c701e51
Transform directly the hashtable into a vector when the boolean ident…
FranckQC Jun 7, 2022
e733e5a
Curly braces around then block
FranckQC Jun 8, 2022
f605363
Update common_subexpr_elim.cc
FranckQC Jun 8, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -470,9 +470,10 @@ TVM_DLL Pass LowerVtcmAlloc();
* \brief Implements a Common Subexpression Elimination (CSE) for TIR
* which introduces let-in bindings for duplicated sub-expressions.
* \param enable_cse_tir Whether common subexpression elimination is enabled.
* \param identify_equiv_terms Whether equivalent terms should be identified.
* \return The pass.
*/
TVM_DLL Pass CommonSubexprElimTIR(bool enable_cse_tir = true);
TVM_DLL Pass CommonSubexprElimTIR(bool enable_cse_tir = true, bool identify_equiv_terms = false);

/*!
* \brief Unify all the thread bindings for "blockIdx.x/y/z", "threadIdx.x/y/z", and
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,15 +324,15 @@ def BF16TypeLowering():
return _ffi_api.BF16TypeLowering() # type: ignore


def CommonSubexprElimTIR(enable_cse_tir: bool = True):
def CommonSubexprElimTIR(enable_cse_tir: bool = True, identify_equiv_terms: bool = False):
"""Replace redundant computations by new variables.

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.CommonSubexprElimTIR(enable_cse_tir) # type: ignore
return _ffi_api.CommonSubexprElimTIR(enable_cse_tir, identify_equiv_terms) # type: ignore


def RewriteUnsafeSelect():
Expand Down
6 changes: 5 additions & 1 deletion src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_cse_tir", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_equiv_terms_in_cse_tir", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
Expand Down Expand Up @@ -198,6 +199,8 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
bool instrument_bound_checkers =
pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
bool disable_cse_tir = pass_ctx->GetConfig<Bool>("tir.disable_cse_tir", Bool(false)).value();
bool enable_equiv_terms_in_cse_tir =
pass_ctx->GetConfig<Bool>("tir.enable_equiv_terms_in_cse_tir", Bool(false)).value();

// Get any user-added passes
Array<Array<ObjectRef>> add_lower_pass =
Expand Down Expand Up @@ -289,7 +292,8 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
pass_list.push_back(tir::transform::InstrumentBoundCheckers());
}

pass_list.push_back(tir::transform::CommonSubexprElimTIR(!disable_cse_tir));
pass_list.push_back(
tir::transform::CommonSubexprElimTIR(!disable_cse_tir, enable_equiv_terms_in_cse_tir));

return pass_list;
}
Expand Down
96 changes: 70 additions & 26 deletions src/tir/transforms/common_subexpr_elim.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ namespace tir {
to collect them for the CSE pass, but we also won't even want to collect computations
that contain them.
The reason is that reusing such computations would change the semantics of the program,
and therefore before doing any introduction of variable or any reuse of already introduced
and therefore before doing any introduction of var or any reuse of already introduced
variables, we will make sure that the computation being considered is not forbidden, and
that it does not even contain a forbidden computation.
* \param expr The expression to check
Expand Down Expand Up @@ -120,6 +120,42 @@ bool CommonSubexpressionEliminator::CanContainEligibleComputations(const PrimExp
return true;
}

/*!
* \brief Implements an order on pairs (expression,frequency). First attempts to compare them
using the size of the expression. If it is the same, decides something else still
deterministic.
* \param a The first pair
* \param b The second pair
* \return A boolean telling if the first pair `a` comes before the second pair `b`
* \note We need this order to be deterministic in order to have a fully deterministic pass,
* as we will deal with elements that are coming from a hashtable, but the order in which
* they appeared in the hashtable was based on some runtime addresses, so it can potentially
* change with every execution.
*/
bool CommonSubexpressionEliminator::OrderOnExprAndFrequency(std::pair<PrimExpr, size_t> a,
std::pair<PrimExpr, size_t> b) {
size_t a_size = CalculateExprComplexity(a.first);
size_t b_size = CalculateExprComplexity(b.first);

// Criteria 1 - Size of the expression comes first
// `a` comes before `b` if the size of `a` is bigger
if (a_size > b_size) {
return true;
}
// `a` does NOT come before `b` if the size of `b` is bigger
if (b_size > a_size) {
return false;
}

// Criteria 2 - If they had the same size, use the lexicographic order as a last resort
// as we need a deterministic order
std::stringstream a_stream;
std::stringstream b_stream;
a_stream << a.first;
b_stream << b.first;
return (a_stream.str().compare(b_stream.str()) < 0);
}

/*!
* \brief Generates a new fresh variable, whose name will be cse_var_i.
* \param type_annotation The type of the new variable to generate
Expand Down Expand Up @@ -166,10 +202,12 @@ int CommonSubexpressionEliminator::GetNbVarGenerated() { return nb_var_; }
of the function being analyzed
* \return A new statement where CSE has been performed
*/
Stmt CommonSubexpressionEliminator::PerformCSE(const Stmt& stmt, const Context& context_init) {
Stmt CommonSubexpressionEliminator::PerformCSE(const Stmt& stmt, const Context& context_init,
bool identify_equiv_terms) {
// As this function is being called for each PrimFunc definition, we create a new instance
// for the one we are having now.
CommonSubexpressionEliminator common_subexpression_eliminator(stmt, context_init);
CommonSubexpressionEliminator common_subexpression_eliminator(stmt, context_init,
identify_equiv_terms);
return common_subexpression_eliminator.VisitStmt(stmt);
}

Expand All @@ -179,8 +217,9 @@ Stmt CommonSubexpressionEliminator::PerformCSE(const Stmt& stmt, const Context&
formal parameters of the function that will be analyzed
*/
CommonSubexpressionEliminator::CommonSubexpressionEliminator(const Stmt& stmt,
const Context& context_init)
: initial_body_(stmt), context_(context_init) {}
const Context& context_init,
bool identify_equiv_terms)
: initial_body_(stmt), context_(context_init), identify_equiv_terms_(identify_equiv_terms) {}

/*!
* \brief The method which overrides the generic dispatcher of StmtExprMutator.
Expand All @@ -200,39 +239,40 @@ PrimExpr CommonSubexpressionEliminator::VisitExpr(const PrimExpr& expr) {
// Transform the hashtable of *syntactic* eligible computations into a vector of pairs
// containing *semantic* entities, i.e. where equivalent computations are merged.
std::vector<std::pair<PrimExpr, size_t>> semantic_comp_done_by_expr =
SyntacticToSemanticComputations(table_syntactic_comp_done_by_expr);
SyntacticToSemanticComputations(table_syntactic_comp_done_by_expr, identify_equiv_terms_);

// Sort the vector of semantic entities by decreasing size
std::sort(semantic_comp_done_by_expr.begin(), semantic_comp_done_by_expr.end(),
[](std::pair<PrimExpr, size_t> a, std::pair<PrimExpr, size_t> b) {
return (CalculateExprComplexity(a.first) > CalculateExprComplexity(b.first));
});
OrderOnExprAndFrequency);

// For each computation done (considering them from biggest to smallest)
for (size_t i = 0; i < semantic_comp_done_by_expr.size(); i++) {
std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_expr[i];

bool ident_equiv_terms = identify_equiv_terms_; // To avoid the capture of "this"

// The predicate later used (when doing replacements) to select expressions that are
// equivalent to the current computation (`computation_and_nb.first`)
std::function<bool(const PrimExpr&)> predicate_selector =
[computation_and_nb](const PrimExpr& current_expr) {
[computation_and_nb, ident_equiv_terms](const PrimExpr& current_expr) {
// `current_expr` should be equivalent to `computation_and_nb.first`, but we also check
// that `current_expr` is an eligible computation even if we know that
// `computation_and_nb.first` is eligible by construction, in case that one day the
// equivalence relation would not preserve the eligibility any more (even though that
// would probably be a very weird equivalence).
return (EquivalentTerms(current_expr, computation_and_nb.first) &&
return (EquivalentTerms(current_expr, computation_and_nb.first, ident_equiv_terms) &&
IsEligibleComputation(current_expr));
};

// See if there is a pair (`var`, `value`) in the context where `value` is semantically
// equivalent to `computation_and_nb.first`
auto it_on_var = std::find_if(
context_.begin(), context_.end(),
[computation_and_nb](const std::pair<Var, MaybeValue>& var_and_value) {
[computation_and_nb, ident_equiv_terms](const std::pair<Var, MaybeValue>& var_and_value) {
// Note : safe to call value() as we check has_value() just before
return (var_and_value.second.has_value() &&
EquivalentTerms(var_and_value.second.value(), computation_and_nb.first));
EquivalentTerms(var_and_value.second.value(), computation_and_nb.first,
ident_equiv_terms));
});

// Case where we have a perfectly equivalent computation already available in a variable
Expand Down Expand Up @@ -298,7 +338,8 @@ PrimExpr CommonSubexpressionEliminator::VisitExpr(const PrimExpr& expr) {
// The following insertion will maintain `semantic_comp_done_by_expr` sorted (by
// decreasing size/complexity), and it will only insert at locations > i as the
// direct subexprs are necessarily smaller than the current computation.
InsertVectorToSortedSemanticComputations(&semantic_comp_done_by_expr, direct_subexprs);
InsertVectorToSortedSemanticComputations(&semantic_comp_done_by_expr, direct_subexprs,
identify_equiv_terms_);
}
}
// Note : we do not remove the current element, as we never look back in the local vector
Expand Down Expand Up @@ -378,39 +419,40 @@ Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) {
// Transform the hashtable of *syntactic* eligible computations into a vector of pairs
// containing *semantic* entities, i.e. where equivalent computations are merged.
std::vector<std::pair<PrimExpr, size_t>> semantic_comp_done_by_stmt =
SyntacticToSemanticComputations(table_syntactic_comp_done_by_stmt);
SyntacticToSemanticComputations(table_syntactic_comp_done_by_stmt, identify_equiv_terms_);

// Sort the vector of semantic entities by decreasing size
std::sort(semantic_comp_done_by_stmt.begin(), semantic_comp_done_by_stmt.end(),
[](std::pair<PrimExpr, size_t> a, std::pair<PrimExpr, size_t> b) {
return (CalculateExprComplexity(a.first) > CalculateExprComplexity(b.first));
});
OrderOnExprAndFrequency);

// For each computation done (considering them from biggest to smallest)
for (size_t i = 0; i < semantic_comp_done_by_stmt.size(); i++) {
std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_stmt[i];

bool ident_equiv_terms = identify_equiv_terms_; // To avoid the capture of "this"

// The predicate later used (when doing replacements) to select expressions that are
// equivalent to the current computation (`computation_and_nb.first`)
std::function<bool(const PrimExpr&)> predicate_selector =
[computation_and_nb](const PrimExpr& current_expr) {
[computation_and_nb, ident_equiv_terms](const PrimExpr& current_expr) {
// `current_expr` should be equivalent to `computation_and_nb.first`, but we also check
// that `current_expr` is an eligible computation even if we know that
// `computation_and_nb.first` is eligible by construction, in case that one day the
// equivalence relation would not preserve the eligibility any more (even though that
// would probably be a very weird equivalence).
return (EquivalentTerms(current_expr, computation_and_nb.first) &&
return (EquivalentTerms(current_expr, computation_and_nb.first, ident_equiv_terms) &&
IsEligibleComputation(current_expr));
};

// See if there is a pair (`var`, `value`) in the context where `value` is semantically
// equivalent to `computation_and_nb.first`
auto it_on_var = std::find_if(
context_.begin(), context_.end(),
[computation_and_nb](const std::pair<Var, MaybeValue>& var_and_value) {
[computation_and_nb, ident_equiv_terms](const std::pair<Var, MaybeValue>& var_and_value) {
// Note : safe to call value() as we check has_value() just before
return (var_and_value.second.has_value() &&
EquivalentTerms(var_and_value.second.value(), computation_and_nb.first));
EquivalentTerms(var_and_value.second.value(), computation_and_nb.first,
ident_equiv_terms));
});

// Case where we have a perfectly equivalent computation already available in a variable
Expand Down Expand Up @@ -477,7 +519,8 @@ Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) {
// The following insertion will maintain `semantic_comp_done_by_stmt` sorted (by
// decreasing size/complexity), and it will only insert at locations > i as the
// direct subexprs are necessarily smaller than the current computation.
InsertVectorToSortedSemanticComputations(&semantic_comp_done_by_stmt, direct_subexprs);
InsertVectorToSortedSemanticComputations(&semantic_comp_done_by_stmt, direct_subexprs,
identify_equiv_terms_);
}
}
// Note : we do not remove the current element, as we never look back in the local vector
Expand Down Expand Up @@ -587,8 +630,8 @@ namespace transform {
* \brief The function which returns the pass for the Common Subexpression Elimination.
* \return The pass for performing CSE.
*/
Pass CommonSubexprElimTIR(bool enable_cse_tir) {
auto pass_func = [enable_cse_tir](PrimFunc f, IRModule m, PassContext ctx) {
Pass CommonSubexprElimTIR(bool enable_cse_tir, bool identify_equiv_terms) {
auto pass_func = [enable_cse_tir, identify_equiv_terms](PrimFunc f, IRModule m, PassContext ctx) {
if (enable_cse_tir) {
auto* n = f.CopyOnWrite();
Context context_init;
Expand All @@ -603,7 +646,8 @@ Pass CommonSubexprElimTIR(bool enable_cse_tir) {

// Do the Common Subexpression Elimination on the body of the function, with the initial
// context that we have prepared
n->body = CommonSubexpressionEliminator::PerformCSE(std::move(f->body), context_init);
n->body = CommonSubexpressionEliminator::PerformCSE(std::move(f->body), context_init,
identify_equiv_terms);
}

return f;
Expand Down
8 changes: 6 additions & 2 deletions src/tir/transforms/common_subexpr_elim.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ using Context = std::vector<std::pair<Var, MaybeValue>>;
class CommonSubexpressionEliminator : public StmtExprMutator {
public:
// Toplevel (static) function
static Stmt PerformCSE(const Stmt& stmt, const Context& context_init);
static Stmt PerformCSE(const Stmt& stmt, const Context& context_init, bool identify_equiv_terms);

PrimExpr VisitExpr(const PrimExpr& expr) override;
Stmt VisitStmt(const Stmt& stmt) override;
Expand All @@ -64,7 +64,8 @@ class CommonSubexpressionEliminator : public StmtExprMutator {

protected:
// Constructor
CommonSubexpressionEliminator(const Stmt& stmt, const Context& context_init);
CommonSubexpressionEliminator(const Stmt& stmt, const Context& context_init,
bool identify_equiv_terms);

PrimExpr VisitExpr_(const LetNode* op) override;

Expand All @@ -77,9 +78,12 @@ class CommonSubexpressionEliminator : public StmtExprMutator {
int num_last_try_ = 0; // Number of the last variable tried
int nb_var_ = 0; // Number of variables introduced by the CSE pass

bool identify_equiv_terms_ = false;

static bool ForbiddenComputation(const PrimExpr& expr);
static bool IsEligibleComputation(const PrimExpr& expr);
static bool CanContainEligibleComputations(const PrimExpr& expr);
static bool OrderOnExprAndFrequency(std::pair<PrimExpr, size_t> a, std::pair<PrimExpr, size_t> b);
Var GenerateNewVar(DataType type_annotation);
};

Expand Down
Loading