Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions ffi/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ project(

option(TVM_FFI_BUILD_TESTS "Adding test targets." OFF)
option(TVM_FFI_USE_LIBBACKTRACE "Enable libbacktrace" ON)
option(TVM_FFI_USE_EXTRA_CXX_API "Enable extra CXX API in shared lib" ON)
option(TVM_FFI_BACKTRACE_ON_SEGFAULT "Set signal handler to print traceback on segfault" ON)

include(cmake/Utils/CxxWarning.cmake)
Expand All @@ -47,7 +48,8 @@ target_include_directories(tvm_ffi_header INTERFACE "${CMAKE_CURRENT_SOURCE_DIR}
target_link_libraries(tvm_ffi_header INTERFACE dlpack_header)

########## Target: `tvm_ffi` ##########
add_library(tvm_ffi_objs OBJECT

set(tvm_ffi_objs_sources
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/traceback.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/traceback_win.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/object.cc"
Expand All @@ -57,10 +59,18 @@ add_library(tvm_ffi_objs OBJECT
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/dtype.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/testing.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/container.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/access_path.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/structural_equal.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/structural_hash.cc"
)

if (TVM_FFI_USE_EXTRA_CXX_API)
list(APPEND tvm_ffi_objs_sources
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/access_path.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/structural_equal.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/structural_hash.cc"
)
endif()

add_library(tvm_ffi_objs OBJECT ${tvm_ffi_objs_sources})

set_target_properties(
tvm_ffi_objs PROPERTIES
POSITION_INDEPENDENT_CODE ON
Expand Down
25 changes: 24 additions & 1 deletion ffi/include/tvm/ffi/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,27 @@ typedef enum {
* is only an unique copy of each value.
*/
kTVMFFISEqHashKindUniqueInstance = 5,
/*!
* \brief provide custom __s_equal__ and __s_hash__ functions through TypeAttrColumn.
*
* The function signatures are(defined via ffi::Function)
*
* \code
* bool __s_equal__(
* ObjectRefType self, ObjectRefType other,
* ffi::TypedFunction<bool(AnyView, AnyView, bool def_region, string field_name)> cmp,
* );
*
* uint64_t __s_hash__(
* ObjectRefType self, uint64_t type_key_hash,
* ffi::TypedFunction<uint64_t(AnyView, bool def_region)> hash
* );
* \endcode
*
* Where the extra string field in cmp is the name of the field that is being compared.
* The function should be registered through TVMFFITypeRegisterAttr via reflection::TypeAttrDef.
*/
kTVMFFISEqHashKindCustomTreeNode = 6,
#ifdef __cplusplus
};
#else
Expand Down Expand Up @@ -539,7 +560,9 @@ typedef struct {
/*
* \brief Column array that stores extra attributes about types
*
* The attributes stored in column arrays that can be looked up by type index.
* The attributes stored in a column array that can be looked up by type index.
* Note that the TypeAttr behaves like type_traits so column[T] so not contain
* attributes from base classes.
*
* \note
* \sa TVMFFIRegisterTypeAttr
Expand Down
2 changes: 2 additions & 0 deletions ffi/src/ffi/object.cc
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,8 @@ class TypeTable {
column_index = type_attr_columns_.size();
type_attr_columns_.emplace_back(std::make_unique<TypeAttrColumnData>());
type_attr_name_to_column_index_.Set(name_str, column_index);
} else {
column_index = (*it).second;
}
TypeAttrColumnData* column = type_attr_columns_[column_index].get();
if (column->data_.size() < static_cast<size_t>(type_index + 1)) {
Expand Down
62 changes: 53 additions & 9 deletions ffi/src/ffi/reflection/structural_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,7 @@ class StructEqualHandler {
}

bool success = true;
if (structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar) {
// we are in a free var case that is not yet mapped.
// in this case, either map_free_vars_ should be set to true, or map_free_vars_ should be set
if (!lhs.same_as(rhs) && !map_free_vars_) {
success = false;
}
} else {
if (structural_eq_hash_kind != kTVMFFISEqHashKindCustomTreeNode) {
// We recursively compare the fields the object
ForEachFieldInfoWithEarlyStop(type_info, [&](const TVMFFIFieldInfo* field_info) {
// skip fields that are marked as structural eq hash ignore
Expand Down Expand Up @@ -158,11 +152,57 @@ class StructEqualHandler {
return false;
}
});
} else {
static reflection::TypeAttrColumn custom_s_equal = reflection::TypeAttrColumn("__s_equal__");
// run custom equal function defined via __s_equal__ type attribute
if (s_equal_callback_ == nullptr) {
s_equal_callback_ = ffi::Function::FromTyped(
[this](AnyView lhs, AnyView rhs, bool def_region, AnyView field_name) {
// NOTE: we explicitly make field_name as AnyView to avoid copy overhead initially
// and only cast to string if mismatch happens
bool success = true;
if (def_region) {
bool allow_free_var = true;
std::swap(allow_free_var, map_free_vars_);
success = CompareAny(lhs, rhs);
std::swap(allow_free_var, map_free_vars_);
} else {
success = CompareAny(lhs, rhs);
}
if (!success) {
if (mismatch_lhs_reverse_path_ != nullptr) {
String field_name_str = field_name.cast<String>();
mismatch_lhs_reverse_path_->emplace_back(AccessStep::ObjectField(field_name_str));
mismatch_rhs_reverse_path_->emplace_back(AccessStep::ObjectField(field_name_str));
}
}
return success;
});
}
TVM_FFI_ICHECK(custom_s_equal[type_info->type_index] != nullptr)
<< "TypeAttr `__s_equal__` is not registered for type `" << String(type_info->type_key)
<< "`";
success = custom_s_equal[type_info->type_index]
.cast<ffi::Function>()(lhs, rhs, s_equal_callback_)
.cast<bool>();
}

if (success) {
if (structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar) {
// we are in a free var case that is not yet mapped.
// in this case, either map_free_vars_ should be set to true, or map_free_vars_ should be
// set
if (lhs.same_as(rhs) || map_free_vars_) {
// record the equality
equal_map_lhs_[lhs] = rhs;
equal_map_rhs_[rhs] = lhs;
return true;
} else {
return false;
}
}
// if we have a success mapping and in graph/var mode, record the equality mapping
if (structural_eq_hash_kind == kTVMFFISEqHashKindDAGNode ||
structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar) {
if (structural_eq_hash_kind == kTVMFFISEqHashKindDAGNode) {
// record the equality
equal_map_lhs_[lhs] = rhs;
equal_map_rhs_[rhs] = lhs;
Expand Down Expand Up @@ -306,6 +346,8 @@ class StructEqualHandler {
// the root lhs for result printing
std::vector<AccessStep>* mismatch_lhs_reverse_path_ = nullptr;
std::vector<AccessStep>* mismatch_rhs_reverse_path_ = nullptr;
// lazily initialize custom equal function
ffi::Function s_equal_callback_ = nullptr;
// map from lhs to rhs
std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual> equal_map_lhs_;
// map from rhs to lhs
Expand Down Expand Up @@ -342,6 +384,8 @@ TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("ffi.reflection.GetFirstStructuralMismatch",
StructuralEqual::GetFirstMismatch);
// ensure the type attribute column is presented in the system even if it is empty.
refl::EnsureTypeAttrColumn("__s_equal__");
});

} // namespace reflection
Expand Down
52 changes: 39 additions & 13 deletions ffi/src/ffi/reflection/structural_hash.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,7 @@ class StructuralHashHandler {

// compute the hash value
uint64_t hash_value = obj->GetTypeKeyHash();
if (structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar) {
if (map_free_vars_) {
// use lexical order of free var and its type
hash_value = details::StableHashCombine(hash_value, free_var_counter_++);
} else {
// Fallback to pointer hash, we are not mapping free var.
return std::hash<const Object*>()(obj.get());
}
} else {
if (structural_eq_hash_kind != kTVMFFISEqHashKindCustomTreeNode) {
// go over the content and hash the fields
ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) {
// skip fields that are marked as structural eq hash ignore
Expand All @@ -126,12 +118,43 @@ class StructuralHashHandler {
}
}
});
// if it is a DAG node, also record the lexical order of graph counter
// this helps to distinguish DAG from trees.
if (structural_eq_hash_kind == kTVMFFISEqHashKindDAGNode) {
hash_value = details::StableHashCombine(hash_value, graph_node_counter_++);
} else {
static reflection::TypeAttrColumn custom_s_hash = reflection::TypeAttrColumn("__s_hash__");
TVM_FFI_ICHECK(custom_s_hash[type_info->type_index] != nullptr)
<< "TypeAttr `__s_hash__` is not registered for type `" << String(type_info->type_key)
<< "`";
if (s_hash_callback_ == nullptr) {
s_hash_callback_ = ffi::Function::FromTyped([this](AnyView val, bool def_region) {
if (def_region) {
bool allow_free_var = true;
std::swap(allow_free_var, map_free_vars_);
uint64_t hash_value = HashAny(val);
std::swap(allow_free_var, map_free_vars_);
return hash_value;
} else {
return HashAny(val);
}
});
}
hash_value = custom_s_hash[type_info->type_index]
.cast<ffi::Function>()(obj, hash_value, s_hash_callback_)
.cast<uint64_t>();
}

if (structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar) {
if (map_free_vars_) {
// use lexical order of free var and its type
hash_value = details::StableHashCombine(hash_value, free_var_counter_++);
} else {
// Fallback to pointer hash, we are not mapping free var.
hash_value = std::hash<const Object*>()(obj.get());
}
}
// if it is a DAG node, also record the lexical order of graph counter
// this helps to distinguish DAG from trees.
if (structural_eq_hash_kind == kTVMFFISEqHashKindDAGNode) {
hash_value = details::StableHashCombine(hash_value, graph_node_counter_++);
}
// record the hash value for this object
hash_memo_[obj] = hash_value;
return hash_value;
Expand Down Expand Up @@ -244,6 +267,8 @@ class StructuralHashHandler {
uint32_t free_var_counter_{0};
// graph node counter.
uint32_t graph_node_counter_{0};
// lazily initialize custom hash function
ffi::Function s_hash_callback_ = nullptr;
// map from lhs to rhs
std::unordered_map<ObjectRef, uint64_t, ObjectPtrHash, ObjectPtrEqual> hash_memo_;
};
Expand All @@ -258,6 +283,7 @@ uint64_t StructuralHash::Hash(const Any& value, bool map_free_vars, bool skip_nd
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("ffi.reflection.StructuralHash", StructuralHash::Hash);
refl::EnsureTypeAttrColumn("__s_hash__");
});

} // namespace reflection
Expand Down
6 changes: 6 additions & 0 deletions ffi/tests/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
file(GLOB _test_sources "${CMAKE_CURRENT_SOURCE_DIR}/test*.cc")
file(GLOB _test_extra_sources "${CMAKE_CURRENT_SOURCE_DIR}/extra/test*.cc")

if (TVM_FFI_USE_EXTRA_CXX_API)
list(APPEND _test_sources ${_test_extra_sources})
endif()

add_executable(
tvm_ffi_tests
EXCLUDE_FROM_ALL
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
#include <tvm/ffi/reflection/structural_hash.h>
#include <tvm/ffi/string.h>

#include "./testing_object.h"
#include "../testing_object.h"

namespace {

Expand Down Expand Up @@ -169,4 +169,30 @@ TEST(StructuralEqualHash, FuncDefAndIgnoreField) {
EXPECT_TRUE(refl::StructuralEqual()(diff_fa_fc, expected_diff_fa_fc));
}

TEST(StructuralEqualHash, CustomTreeNode) {
TVar x = TVar("x");
TVar y = TVar("y");
// comment fields are ignored
TCustomFunc fa = TCustomFunc({x}, {TInt(1), x}, "comment a");
TCustomFunc fb = TCustomFunc({y}, {TInt(1), y}, "comment b");

TCustomFunc fc = TCustomFunc({x}, {TInt(1), TInt(2)}, "comment c");

EXPECT_TRUE(refl::StructuralEqual()(fa, fb));
EXPECT_EQ(refl::StructuralHash()(fa), refl::StructuralHash()(fb));

EXPECT_FALSE(refl::StructuralEqual()(fa, fc));
auto diff_fa_fc = refl::StructuralEqual::GetFirstMismatch(fa, fc);
auto expected_diff_fa_fc = refl::AccessPathPair(refl::AccessPath({
refl::AccessStep::ObjectField("body"),
refl::AccessStep::ArrayIndex(1),
}),
refl::AccessPath({
refl::AccessStep::ObjectField("body"),
refl::AccessStep::ArrayIndex(1),
}));
EXPECT_TRUE(diff_fa_fc.has_value());
EXPECT_TRUE(refl::StructuralEqual()(diff_fa_fc, expected_diff_fa_fc));
}

} // namespace
1 change: 1 addition & 0 deletions ffi/tests/cpp/test_reflection_accessor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
TPrimExprObj::RegisterReflection();
TVarObj::RegisterReflection();
TFuncObj::RegisterReflection();
TCustomFuncObj::RegisterReflection();

refl::ObjectDef<TestObjA>().def_ro("x", &TestObjA::x).def_rw("y", &TestObjA::y);
refl::ObjectDef<TestObjADerived>().def_ro("z", &TestObjADerived::z);
Expand Down
57 changes: 56 additions & 1 deletion ffi/tests/cpp/testing_object.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ class TVarObj : public Object {

static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<TVarObj>().def_ro("name", &TVarObj::name);
refl::ObjectDef<TVarObj>().def_ro("name", &TVarObj::name,
refl::AttachFieldFlag::SEqHashIgnore());
}

static constexpr const char* _type_key = "test.Var";
Expand Down Expand Up @@ -204,6 +205,60 @@ class TFunc : public ObjectRef {
TVM_FFI_DEFINE_OBJECT_REF_METHODS(TFunc, ObjectRef, TFuncObj);
};

class TCustomFuncObj : public Object {
public:
Array<TVar> params;
Array<ObjectRef> body;
String comment;

TCustomFuncObj(Array<TVar> params, Array<ObjectRef> body, String comment)
: params(params), body(body), comment(comment) {}

bool SEqual(const TCustomFuncObj* other,
ffi::TypedFunction<bool(AnyView, AnyView, bool, AnyView)> cmp) const {
if (!cmp(params, other->params, true, "params")) {
std::cout << "custom s_equal failed params" << std::endl;
return false;
}
if (!cmp(body, other->body, false, "body")) {
std::cout << "custom s_equal failed body" << std::endl;
return false;
}
return true;
}

uint64_t SHash(uint64_t type_key_hash, ffi::TypedFunction<uint64_t(AnyView, bool)> hash) const {
uint64_t hash_value = type_key_hash;
hash_value = tvm::ffi::details::StableHashCombine(hash_value, hash(params, true));
hash_value = tvm::ffi::details::StableHashCombine(hash_value, hash(body, false));
return hash_value;
}

static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<TCustomFuncObj>()
.def_ro("params", &TCustomFuncObj::params)
.def_ro("body", &TCustomFuncObj::body)
.def_ro("comment", &TCustomFuncObj::comment);
refl::TypeAttrDef<TCustomFuncObj>()
.def("__s_equal__", &TCustomFuncObj::SEqual)
.def("__s_hash__", &TCustomFuncObj::SHash);
}

static constexpr const char* _type_key = "test.CustomFunc";
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindCustomTreeNode;
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TCustomFuncObj, Object);
};

class TCustomFunc : public ObjectRef {
public:
explicit TCustomFunc(Array<TVar> params, Array<ObjectRef> body, String comment) {
data_ = make_object<TCustomFuncObj>(params, body, comment);
}

TVM_FFI_DEFINE_OBJECT_REF_METHODS(TCustomFunc, ObjectRef, TCustomFuncObj);
};

} // namespace testing

template <>
Expand Down
Loading