Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 154 additions & 25 deletions src/tir/analysis/deep_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,48 +25,177 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/node/object_path.h>
#include <tvm/node/reflection.h>
#include <tvm/node/structural_equal.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/expr_functor.h>

namespace tvm {
namespace tir {

class DeepCmpSEqualHandler : public SEqualReducer::Handler {
#define DEFINE_DEEP_EQUAL_BIN_EXPR(OpNode) \
bool VisitExpr_(const OpNode* plhs, const PrimExpr& rhs) final { \
const auto* prhs = rhs.as<OpNode>(); \
return plhs->dtype == prhs->dtype && VisitExpr(plhs->a, prhs->a) && \
VisitExpr(plhs->b, prhs->b); \
}

#define DEFINE_DEEP_EQUAL_IMM_EXPR(OpNode) \
bool VisitExpr_(const OpNode* plhs, const PrimExpr& rhs) final { \
const auto* prhs = rhs.as<OpNode>(); \
return plhs->dtype == prhs->dtype && plhs->value == prhs->value; \
}

class ExprDeepEqualChecker : private ExprFunctor<bool(const PrimExpr&, const PrimExpr&)> {
public:
// use direct recursion.
bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
const Optional<ObjectPathPair>&) final {
static bool Check(const PrimExpr& lhs, const PrimExpr& rhs) {
// quick path without constructing the object
if (lhs.same_as(rhs)) return true;
if (!lhs.defined() && rhs.defined()) return false;
if (!rhs.defined() && lhs.defined()) return false;
if (lhs->type_index() != rhs->type_index()) return false;
return vtable_->SEqualReduce(lhs.get(), rhs.get(), SEqualReducer(this, nullptr, false)) &&
!fail_;
if (auto* plhs = lhs.as<IntImmNode>()) {
auto* prhs = rhs.as<IntImmNode>();
return plhs->dtype == prhs->dtype && plhs->value == prhs->value;
}
return ExprDeepEqualChecker().VisitExpr(lhs, rhs);
}

void DeferFail(const ObjectPathPair&) final { fail_ = true; }
bool IsFailDeferralEnabled() final { return false; }

ObjectRef MapLhsToRhs(const ObjectRef& lhs) final { return lhs; }
void MarkGraphNode() final {}
bool VisitExpr(const PrimExpr& lhs, const PrimExpr& rhs) final {
if (lhs.same_as(rhs)) return true;
if (!lhs.defined() && rhs.defined()) return false;
if (!rhs.defined() && lhs.defined()) return false;
if (lhs->type_index() != rhs->type_index()) return false;
return ExprFunctor::VisitExpr(lhs, rhs);
}

private:
// reflection vtable
ReflectionVTable* vtable_ = ReflectionVTable::Global();
bool fail_ = false;
bool ArrayDeepEqual(const Array<PrimExpr>& lhs, const Array<PrimExpr>& rhs) {
if (lhs.size() != rhs.size()) return false;
for (size_t i = 0; i < lhs.size(); i++) {
if (!VisitExpr(lhs[i], rhs[i])) return false;
}
return true;
}

bool ArrayDeepEqual(const Array<IterVar>& lhs, const Array<IterVar>& rhs) {
// for iter var, we require pointer equality
if (lhs.size() != rhs.size()) return false;
for (size_t i = 0; i < lhs.size(); i++) {
if (!lhs[i].same_as(rhs[i])) return true;
}
return true;
}

bool OptionalDeepEqual(const Optional<PrimExpr>& lhs, const Optional<PrimExpr>& rhs) {
if (lhs.same_as(rhs)) return true;
if (!lhs.defined() && rhs.defined()) return false;
if (lhs.defined() && !rhs.defined()) return false;
return VisitExpr(*lhs, *rhs);
}

bool VisitExpr_(const VarNode* plhs, const PrimExpr& rhs) final {
// for var, we require pointer equality
return plhs == rhs.get();
}

bool VisitExpr_(const SizeVarNode* plhs, const PrimExpr& rhs) final {
// for var, we require pointer equality
return plhs == rhs.get();
}

bool VisitExpr_(const BufferLoadNode* plhs, const PrimExpr& rhs) final {
const auto* prhs = rhs.as<BufferLoadNode>();
// we run pointer comparison of the buffer
return plhs->dtype == prhs->dtype && plhs->buffer.same_as(prhs->buffer) &&
ArrayDeepEqual(plhs->indices, prhs->indices) &&
OptionalDeepEqual(plhs->predicate, prhs->predicate);
}

bool VisitExpr_(const ProducerLoadNode* plhs, const PrimExpr& rhs) final {
const auto* prhs = rhs.as<ProducerLoadNode>();
// run shallow pointer comparison of the producer
return plhs->dtype == prhs->dtype && plhs->producer.same_as(prhs->producer) &&
ArrayDeepEqual(plhs->indices, prhs->indices);
}

bool VisitExpr_(const LetNode* plhs, const PrimExpr& rhs) final {
const auto* prhs = rhs.as<LetNode>();
return plhs->dtype == prhs->dtype && VisitExpr(plhs->var, prhs->var) &&
VisitExpr(plhs->value, prhs->value) && VisitExpr(plhs->body, prhs->body);
}

bool VisitExpr_(const CallNode* plhs, const PrimExpr& rhs) final {
const auto* prhs = rhs.as<CallNode>();
return plhs->dtype == prhs->dtype && plhs->op.same_as(prhs->op) &&
ArrayDeepEqual(plhs->args, prhs->args);
}

bool VisitExpr_(const ReduceNode* plhs, const PrimExpr& rhs) final {
const auto* prhs = rhs.as<ReduceNode>();
return plhs->dtype == prhs->dtype && plhs->combiner.same_as(prhs->combiner) &&
ArrayDeepEqual(plhs->source, prhs->source) && ArrayDeepEqual(plhs->init, prhs->init) &&
ArrayDeepEqual(plhs->axis, prhs->axis) && VisitExpr(plhs->condition, prhs->condition) &&
plhs->value_index == prhs->value_index;
}

bool VisitExpr_(const CastNode* plhs, const PrimExpr& rhs) final {
const auto* prhs = rhs.as<CastNode>();
return plhs->dtype == prhs->dtype && VisitExpr(plhs->value, prhs->value);
}

bool VisitExpr_(const NotNode* plhs, const PrimExpr& rhs) final {
const auto* prhs = rhs.as<NotNode>();
return plhs->dtype == prhs->dtype && VisitExpr(plhs->a, prhs->a);
}

bool VisitExpr_(const SelectNode* plhs, const PrimExpr& rhs) final {
const auto* prhs = rhs.as<SelectNode>();
return plhs->dtype == prhs->dtype && VisitExpr(plhs->condition, prhs->condition) &&
VisitExpr(plhs->true_value, prhs->true_value) &&
VisitExpr(plhs->false_value, prhs->false_value);
}

bool VisitExpr_(const RampNode* plhs, const PrimExpr& rhs) final {
const auto* prhs = rhs.as<RampNode>();
return plhs->dtype == prhs->dtype && VisitExpr(plhs->base, prhs->base) &&
VisitExpr(plhs->stride, prhs->stride) && VisitExpr(plhs->lanes, prhs->lanes);
}

bool VisitExpr_(const ShuffleNode* plhs, const PrimExpr& rhs) final {
const auto* prhs = rhs.as<ShuffleNode>();
return plhs->dtype == prhs->dtype && ArrayDeepEqual(plhs->vectors, prhs->vectors) &&
ArrayDeepEqual(plhs->indices, prhs->indices);
}

bool VisitExpr_(const BroadcastNode* plhs, const PrimExpr& rhs) final {
const auto* prhs = rhs.as<BroadcastNode>();
return plhs->dtype == prhs->dtype && VisitExpr(plhs->value, prhs->value) &&
VisitExpr(plhs->lanes, prhs->lanes);
}

DEFINE_DEEP_EQUAL_BIN_EXPR(AddNode)
DEFINE_DEEP_EQUAL_BIN_EXPR(SubNode)
DEFINE_DEEP_EQUAL_BIN_EXPR(MulNode)
DEFINE_DEEP_EQUAL_BIN_EXPR(DivNode)
DEFINE_DEEP_EQUAL_BIN_EXPR(ModNode)
DEFINE_DEEP_EQUAL_BIN_EXPR(FloorDivNode)
DEFINE_DEEP_EQUAL_BIN_EXPR(FloorModNode)
DEFINE_DEEP_EQUAL_BIN_EXPR(MinNode)
DEFINE_DEEP_EQUAL_BIN_EXPR(MaxNode)
DEFINE_DEEP_EQUAL_BIN_EXPR(EQNode)
DEFINE_DEEP_EQUAL_BIN_EXPR(NENode)
DEFINE_DEEP_EQUAL_BIN_EXPR(LTNode)
DEFINE_DEEP_EQUAL_BIN_EXPR(LENode)
DEFINE_DEEP_EQUAL_BIN_EXPR(GTNode)
DEFINE_DEEP_EQUAL_BIN_EXPR(GENode)
DEFINE_DEEP_EQUAL_BIN_EXPR(AndNode)
DEFINE_DEEP_EQUAL_BIN_EXPR(OrNode)
DEFINE_DEEP_EQUAL_IMM_EXPR(IntImmNode)
DEFINE_DEEP_EQUAL_IMM_EXPR(FloatImmNode)
DEFINE_DEEP_EQUAL_IMM_EXPR(StringImmNode)
};

bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const {
// quick path
if (lhs.same_as(rhs)) return true;
if (!lhs.defined() && rhs.defined()) return false;
if (!rhs.defined() && lhs.defined()) return false;
if (lhs->type_index() != rhs->type_index()) return false;
if (auto* plhs = lhs.as<IntImmNode>()) {
auto* prhs = rhs.as<IntImmNode>();
return plhs->dtype == prhs->dtype && plhs->value == prhs->value;
}
return DeepCmpSEqualHandler().SEqualReduce(lhs, rhs, false, std::nullopt);
return ExprDeepEqualChecker::Check(lhs, rhs);
}

TVM_FFI_STATIC_INIT_BLOCK({
Expand Down
Loading