diff --git a/src/tir/transforms/common_subexpr_elim_tools.cc b/src/tir/transforms/common_subexpr_elim_tools.cc index 218667c331a5..d39d211ba182 100644 --- a/src/tir/transforms/common_subexpr_elim_tools.cc +++ b/src/tir/transforms/common_subexpr_elim_tools.cc @@ -743,13 +743,27 @@ bool EquivalentTerms(const PrimExpr& a, const PrimExpr& b) { std::vector> SyntacticToSemanticComputations( const ComputationTable& table) { std::vector> result; + // table.size() is an upper-bound of the number of elements in the resulting vector, // as we might merge semantically equivalent computations. // We do this reservation even if it might reserve slightly more space than is needed in the end result.reserve(table.size()); + // Traverse through map in a sorted order on keys to maintain deterministic behavior + // We do this by comparing the string repr of each PrimExpr to get a determinstic ordering + std::vector> sorted_map_items(table.begin(), table.end()); + + sort(sorted_map_items.begin(), sorted_map_items.end(), + [](std::pair a, std::pair b) { + std::stringstream a_stream; + std::stringstream b_stream; + a_stream << a.first; + b_stream << b.first; + return a_stream.str().compare(b_stream.str()) < 0; + }); + // For each element in the hashtable - for (auto elem : table) { + for (auto elem : sorted_map_items) { // We try to see if a semantically equivalent term is already in the resulting vector auto it_found = std::find_if(result.begin(), result.end(), [elem](std::pair already_seen) { @@ -763,7 +777,6 @@ std::vector> SyntacticToSemanticComputations( result.push_back(elem); } } - return result; } diff --git a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py index 17c0cbdd99c6..c12e27a46e3f 100644 --- a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py +++ b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py @@ -14,8 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import hashlib + import tvm from tvm import te +from tvm.ir.base import save_json +from tvm.ir.module import IRModule + # A test program which gives the opportunity for the CSE pass to introduce two new variables, at two different levels def test_cse(): @@ -133,6 +138,49 @@ def test_cse(): assert isinstance(body.body, tvm.tir.BufferStore) +def test_deterministic_cse(): + import random + + """Test deterministic allocation of CSE vars + + We expect something like + + result = (x + 1) + (x + 2) + (x + 3) + (x + 1) + (x + 2) + (x + 3) + --> + cse_var_3 = (x + 1) + cse_var_2 = (x + 2) + cse_var_1 = (x + 3) + result = cse_var_3 + cse_var_2 + cse_var_1 + cse_var_3 + cse_var_2 + cse_var_1 + """ + NUM_TERMS = 10 + REPEATS = 10 + + x = te.var("x") + result = te.var("result") + + offsets = sorted([i + 1 for i in range(NUM_TERMS)]) + inc1 = [(x + offsets[i]) for i in range(NUM_TERMS)] + inc2 = [(x + offsets[i]) for i in range(NUM_TERMS)] + + expression = x + for add in inc1 + inc2: + expression = expression + add + let_stmt = tvm.tir.LetStmt(result, expression, tvm.tir.Evaluate(result)) + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([x], let_stmt)) + + initial_hash = None + for _ in range(REPEATS): + body = tvm.tir.transform.CommonSubexprElimTIR()(mod)["main"] + + # Hash and ensure serialize json is the same every time + json_val = save_json(body) + json_hash = hashlib.sha256(json_val.encode()).hexdigest() + + if initial_hash is None: + initial_hash = json_hash + assert json_hash == initial_hash + + # First specific test for if nodes : Some duplicated computations appear only in one branch (here the Then branch), not in both branches. # In this case, the CSE pass should introduce the redundant computation at the top if the Then branch, not before the whole If # (otherwise that would lead to some computations being computed for nothing when it is the Else branch that is executed).