Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion include/tvm/node/structural_hash.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,11 @@ class StructuralHash : public BaseValueHash {
/*!
* \brief Compute structural hashing value for an object.
* \param key The left operand.
* \param map_free_vars Whether to map free variables by their occurrence
* number. Otherwise the underlying pointer value is used.
* \return The hash value.
*/
TVM_DLL size_t operator()(const ObjectRef& key) const;
TVM_DLL size_t operator()(const ObjectRef& key, bool map_free_vars = false) const;
};

/*!
Expand Down
10 changes: 10 additions & 0 deletions include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,16 @@ struct ExprDeepEqual {
TVM_DLL bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const;
};

/*!
* \brief Hash of primexpr without var remapping.
*
* \sa ExprDeepEqual
*
* \param e PrimExpr to hash
* \return The hash of the PrimExpr
*/
int64_t ExprDeepHash(const PrimExpr& e);

/*!
* \brief Visit the PrimFuncs in the IRModule
* \tparam FLambda The type of the PrimFunc visitor
Expand Down
4 changes: 2 additions & 2 deletions src/node/structural_hash.cc
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,8 @@ TVM_REGISTER_GLOBAL("node.StructuralHash")
return static_cast<int64_t>(hashed_value);
});

size_t StructuralHash::operator()(const ObjectRef& object) const {
return VarCountingSHashHandler().Hash(object, false);
size_t StructuralHash::operator()(const ObjectRef& object, bool map_free_vars) const {
return VarCountingSHashHandler().Hash(object, map_free_vars);
}

// SEQualReduce traits for runtime containers.
Expand Down
3 changes: 3 additions & 0 deletions src/tir/analysis/deep_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
*/
#include <tvm/node/reflection.h>
#include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/analysis.h>

Expand Down Expand Up @@ -65,6 +66,8 @@ bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const {
return DeepCmpSEqualHandler().SEqualReduce(lhs, rhs, false);
}

int64_t ExprDeepHash(const PrimExpr& e) { return StructuralHash()(e, false); }

TVM_REGISTER_GLOBAL("tir.analysis.expr_deep_equal")
.set_body_typed([](const PrimExpr& lhs, const PrimExpr& rhs) {
return ExprDeepEqual()(lhs, rhs);
Expand Down
50 changes: 23 additions & 27 deletions src/tir/transforms/common_subexpr_elim_tools.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h> // For the declaration of the pass

#include <algorithm> // For std::find_if
#include <algorithm> // For std::find_if
#include <map>
#include <unordered_map> // For the hashtable datatype
#include <utility>
#include <vector>
Expand Down Expand Up @@ -730,6 +731,10 @@ bool EquivalentTerms(const PrimExpr& a, const PrimExpr& b) {
return EqualTerms(a, b);
}

struct ExprDeepHashStruct {
int64_t operator()(const PrimExpr& e) const { return ExprDeepHash(e); }
};

/*!
* \brief Transforms a hashtable of syntactic computations into a vector or pairs
(expression, counter) where equivalent computations are merged and their counters added.
Expand All @@ -742,41 +747,32 @@ bool EquivalentTerms(const PrimExpr& a, const PrimExpr& b) {
*/
std::vector<std::pair<PrimExpr, size_t>> SyntacticToSemanticComputations(
const ComputationTable& table) {
std::vector<std::pair<PrimExpr, size_t>> result;

// table.size() is an upper-bound of the number of elements in the resulting vector,
// as we might merge semantically equivalent computations.
// We do this reservation even if it might reserve slightly more space than is needed in the end
result.reserve(table.size());

// Traverse through map in a sorted order on keys to maintain deterministic behavior
// We do this by comparing the string repr of each PrimExpr to get a determinstic ordering
std::vector<std::pair<PrimExpr, size_t>> sorted_map_items(table.begin(), table.end());

sort(sorted_map_items.begin(), sorted_map_items.end(),
[](std::pair<PrimExpr, size_t> a, std::pair<PrimExpr, size_t> b) {
std::stringstream a_stream;
std::stringstream b_stream;
a_stream << a.first;
b_stream << b.first;
return a_stream.str().compare(b_stream.str()) < 0;
});
std::unordered_map<PrimExpr, size_t, ExprDeepHashStruct, ExprDeepEqual> equiv_computations;

// For each element in the hashtable
for (auto elem : sorted_map_items) {
for (auto elem : table) {
// We try to see if a semantically equivalent term is already in the resulting vector
auto it_found = std::find_if(result.begin(), result.end(),
[elem](std::pair<PrimExpr, size_t> already_seen) {
return EquivalentTerms(already_seen.first, elem.first);
});
auto it_found = equiv_computations.find(elem.first);
// And if so, we increase (by `elem.second`) its count
if (it_found != result.end()) {
if (it_found != equiv_computations.end()) {
it_found->second += elem.second;
} else {
// If we could not find a semantically equivalent term in the resulting vector, we add it
result.push_back(elem);
equiv_computations[elem.first] = elem.second;
}
}
std::vector<std::pair<PrimExpr, size_t>> result(equiv_computations.begin(),
equiv_computations.end());
// Sort results to maintain deterministic behavior
// We do this by comparing the string repr of each PrimExpr to get a determinstic ordering
std::sort(result.begin(), result.end(),
[](std::pair<PrimExpr, size_t> a, std::pair<PrimExpr, size_t> b) {
std::stringstream a_stream;
std::stringstream b_stream;
a_stream << a.first;
b_stream << b.first;
return a_stream.str().compare(b_stream.str()) < 0;
});
return result;
}

Expand Down