diff --git a/include/tvm/node/structural_hash.h b/include/tvm/node/structural_hash.h index a30a2c59d0d1..515d37d7c008 100644 --- a/include/tvm/node/structural_hash.h +++ b/include/tvm/node/structural_hash.h @@ -79,9 +79,11 @@ class StructuralHash : public BaseValueHash { /*! * \brief Compute structural hashing value for an object. * \param key The left operand. + * \param map_free_vars Whether to map free variables by their occurrence + * number. Otherwise the underlying pointer value is used. * \return The hash value. */ - TVM_DLL size_t operator()(const ObjectRef& key) const; + TVM_DLL size_t operator()(const ObjectRef& key, bool map_free_vars = false) const; }; /*! diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index 8306cb173e0a..57231b6ed28e 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -56,6 +56,16 @@ struct ExprDeepEqual { TVM_DLL bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const; }; +/*! + * \brief Hash of primexpr without var remapping. + * + * \sa ExprDeepEqual + * + * \param e PrimExpr to hash + * \return The hash of the PrimExpr + */ +int64_t ExprDeepHash(const PrimExpr& e); + /*! * \brief Visit the PrimFuncs in the IRModule * \tparam FLambda The type of the PrimFunc visitor diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index e97e5f41bfc2..023e5ba0fe55 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -261,8 +261,8 @@ TVM_REGISTER_GLOBAL("node.StructuralHash") return static_cast(hashed_value); }); -size_t StructuralHash::operator()(const ObjectRef& object) const { - return VarCountingSHashHandler().Hash(object, false); +size_t StructuralHash::operator()(const ObjectRef& object, bool map_free_vars) const { + return VarCountingSHashHandler().Hash(object, map_free_vars); } // SEQualReduce traits for runtime containers. diff --git a/src/tir/analysis/deep_equal.cc b/src/tir/analysis/deep_equal.cc index 7f48cc439234..10329a0114cd 100644 --- a/src/tir/analysis/deep_equal.cc +++ b/src/tir/analysis/deep_equal.cc @@ -23,6 +23,7 @@ */ #include #include +#include #include #include @@ -65,6 +66,8 @@ bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const { return DeepCmpSEqualHandler().SEqualReduce(lhs, rhs, false); } +int64_t ExprDeepHash(const PrimExpr& e) { return StructuralHash()(e, false); } + TVM_REGISTER_GLOBAL("tir.analysis.expr_deep_equal") .set_body_typed([](const PrimExpr& lhs, const PrimExpr& rhs) { return ExprDeepEqual()(lhs, rhs); diff --git a/src/tir/transforms/common_subexpr_elim_tools.cc b/src/tir/transforms/common_subexpr_elim_tools.cc index d39d211ba182..1a107a3d2846 100644 --- a/src/tir/transforms/common_subexpr_elim_tools.cc +++ b/src/tir/transforms/common_subexpr_elim_tools.cc @@ -35,7 +35,8 @@ #include #include // For the declaration of the pass -#include // For std::find_if +#include // For std::find_if +#include #include // For the hashtable datatype #include #include @@ -730,6 +731,10 @@ bool EquivalentTerms(const PrimExpr& a, const PrimExpr& b) { return EqualTerms(a, b); } +struct ExprDeepHashStruct { + int64_t operator()(const PrimExpr& e) const { return ExprDeepHash(e); } +}; + /*! * \brief Transforms a hashtable of syntactic computations into a vector or pairs (expression, counter) where equivalent computations are merged and their counters added. @@ -742,41 +747,32 @@ bool EquivalentTerms(const PrimExpr& a, const PrimExpr& b) { */ std::vector> SyntacticToSemanticComputations( const ComputationTable& table) { - std::vector> result; - - // table.size() is an upper-bound of the number of elements in the resulting vector, - // as we might merge semantically equivalent computations. - // We do this reservation even if it might reserve slightly more space than is needed in the end - result.reserve(table.size()); - - // Traverse through map in a sorted order on keys to maintain deterministic behavior - // We do this by comparing the string repr of each PrimExpr to get a determinstic ordering - std::vector> sorted_map_items(table.begin(), table.end()); - - sort(sorted_map_items.begin(), sorted_map_items.end(), - [](std::pair a, std::pair b) { - std::stringstream a_stream; - std::stringstream b_stream; - a_stream << a.first; - b_stream << b.first; - return a_stream.str().compare(b_stream.str()) < 0; - }); + std::unordered_map equiv_computations; // For each element in the hashtable - for (auto elem : sorted_map_items) { + for (auto elem : table) { // We try to see if a semantically equivalent term is already in the resulting vector - auto it_found = std::find_if(result.begin(), result.end(), - [elem](std::pair already_seen) { - return EquivalentTerms(already_seen.first, elem.first); - }); + auto it_found = equiv_computations.find(elem.first); // And if so, we increase (by `elem.second`) its count - if (it_found != result.end()) { + if (it_found != equiv_computations.end()) { it_found->second += elem.second; } else { // If we could not find a semantically equivalent term in the resulting vector, we add it - result.push_back(elem); + equiv_computations[elem.first] = elem.second; } } + std::vector> result(equiv_computations.begin(), + equiv_computations.end()); + // Sort results to maintain deterministic behavior + // We do this by comparing the string repr of each PrimExpr to get a determinstic ordering + std::sort(result.begin(), result.end(), + [](std::pair a, std::pair b) { + std::stringstream a_stream; + std::stringstream b_stream; + a_stream << a.first; + b_stream << b.first; + return a_stream.str().compare(b_stream.str()) < 0; + }); return result; }