diff --git a/ffi/CMakeLists.txt b/ffi/CMakeLists.txt index ba0237f0e434..d6c56c8112b0 100644 --- a/ffi/CMakeLists.txt +++ b/ffi/CMakeLists.txt @@ -26,6 +26,7 @@ project( option(TVM_FFI_BUILD_TESTS "Adding test targets." OFF) option(TVM_FFI_USE_LIBBACKTRACE "Enable libbacktrace" ON) +option(TVM_FFI_USE_EXTRA_CXX_API "Enable extra CXX API in shared lib" ON) option(TVM_FFI_BACKTRACE_ON_SEGFAULT "Set signal handler to print traceback on segfault" ON) include(cmake/Utils/CxxWarning.cmake) @@ -47,7 +48,8 @@ target_include_directories(tvm_ffi_header INTERFACE "${CMAKE_CURRENT_SOURCE_DIR} target_link_libraries(tvm_ffi_header INTERFACE dlpack_header) ########## Target: `tvm_ffi` ########## -add_library(tvm_ffi_objs OBJECT + +set(tvm_ffi_objs_sources "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/traceback.cc" "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/traceback_win.cc" "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/object.cc" @@ -57,10 +59,18 @@ add_library(tvm_ffi_objs OBJECT "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/dtype.cc" "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/testing.cc" "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/container.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/access_path.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/structural_equal.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/structural_hash.cc" ) + +if (TVM_FFI_USE_EXTRA_CXX_API) + list(APPEND tvm_ffi_objs_sources + "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/access_path.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/structural_equal.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/structural_hash.cc" + ) +endif() + +add_library(tvm_ffi_objs OBJECT ${tvm_ffi_objs_sources}) + set_target_properties( tvm_ffi_objs PROPERTIES POSITION_INDEPENDENT_CODE ON diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h index 59b687759846..e2de610a5df7 100644 --- a/ffi/include/tvm/ffi/c_api.h +++ b/ffi/include/tvm/ffi/c_api.h @@ -424,6 +424,27 @@ typedef enum { * is only an unique copy of each value. */ kTVMFFISEqHashKindUniqueInstance = 5, + /*! + * \brief provide custom __s_equal__ and __s_hash__ functions through TypeAttrColumn. + * + * The function signatures are(defined via ffi::Function) + * + * \code + * bool __s_equal__( + * ObjectRefType self, ObjectRefType other, + * ffi::TypedFunction cmp, + * ); + * + * uint64_t __s_hash__( + * ObjectRefType self, uint64_t type_key_hash, + * ffi::TypedFunction hash + * ); + * \endcode + * + * Where the extra string field in cmp is the name of the field that is being compared. + * The function should be registered through TVMFFITypeRegisterAttr via reflection::TypeAttrDef. + */ + kTVMFFISEqHashKindCustomTreeNode = 6, #ifdef __cplusplus }; #else @@ -539,7 +560,9 @@ typedef struct { /* * \brief Column array that stores extra attributes about types * - * The attributes stored in column arrays that can be looked up by type index. + * The attributes stored in a column array that can be looked up by type index. + * Note that the TypeAttr behaves like type_traits so column[T] so not contain + * attributes from base classes. * * \note * \sa TVMFFIRegisterTypeAttr diff --git a/ffi/src/ffi/object.cc b/ffi/src/ffi/object.cc index 863b43b3e0cf..4abe933d4db8 100644 --- a/ffi/src/ffi/object.cc +++ b/ffi/src/ffi/object.cc @@ -247,6 +247,8 @@ class TypeTable { column_index = type_attr_columns_.size(); type_attr_columns_.emplace_back(std::make_unique()); type_attr_name_to_column_index_.Set(name_str, column_index); + } else { + column_index = (*it).second; } TypeAttrColumnData* column = type_attr_columns_[column_index].get(); if (column->data_.size() < static_cast(type_index + 1)) { diff --git a/ffi/src/ffi/reflection/structural_equal.cc b/ffi/src/ffi/reflection/structural_equal.cc index 671664435098..03cbdd95bee9 100644 --- a/ffi/src/ffi/reflection/structural_equal.cc +++ b/ffi/src/ffi/reflection/structural_equal.cc @@ -119,13 +119,7 @@ class StructEqualHandler { } bool success = true; - if (structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar) { - // we are in a free var case that is not yet mapped. - // in this case, either map_free_vars_ should be set to true, or map_free_vars_ should be set - if (!lhs.same_as(rhs) && !map_free_vars_) { - success = false; - } - } else { + if (structural_eq_hash_kind != kTVMFFISEqHashKindCustomTreeNode) { // We recursively compare the fields the object ForEachFieldInfoWithEarlyStop(type_info, [&](const TVMFFIFieldInfo* field_info) { // skip fields that are marked as structural eq hash ignore @@ -158,11 +152,57 @@ class StructEqualHandler { return false; } }); + } else { + static reflection::TypeAttrColumn custom_s_equal = reflection::TypeAttrColumn("__s_equal__"); + // run custom equal function defined via __s_equal__ type attribute + if (s_equal_callback_ == nullptr) { + s_equal_callback_ = ffi::Function::FromTyped( + [this](AnyView lhs, AnyView rhs, bool def_region, AnyView field_name) { + // NOTE: we explicitly make field_name as AnyView to avoid copy overhead initially + // and only cast to string if mismatch happens + bool success = true; + if (def_region) { + bool allow_free_var = true; + std::swap(allow_free_var, map_free_vars_); + success = CompareAny(lhs, rhs); + std::swap(allow_free_var, map_free_vars_); + } else { + success = CompareAny(lhs, rhs); + } + if (!success) { + if (mismatch_lhs_reverse_path_ != nullptr) { + String field_name_str = field_name.cast(); + mismatch_lhs_reverse_path_->emplace_back(AccessStep::ObjectField(field_name_str)); + mismatch_rhs_reverse_path_->emplace_back(AccessStep::ObjectField(field_name_str)); + } + } + return success; + }); + } + TVM_FFI_ICHECK(custom_s_equal[type_info->type_index] != nullptr) + << "TypeAttr `__s_equal__` is not registered for type `" << String(type_info->type_key) + << "`"; + success = custom_s_equal[type_info->type_index] + .cast()(lhs, rhs, s_equal_callback_) + .cast(); } + if (success) { + if (structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar) { + // we are in a free var case that is not yet mapped. + // in this case, either map_free_vars_ should be set to true, or map_free_vars_ should be + // set + if (lhs.same_as(rhs) || map_free_vars_) { + // record the equality + equal_map_lhs_[lhs] = rhs; + equal_map_rhs_[rhs] = lhs; + return true; + } else { + return false; + } + } // if we have a success mapping and in graph/var mode, record the equality mapping - if (structural_eq_hash_kind == kTVMFFISEqHashKindDAGNode || - structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar) { + if (structural_eq_hash_kind == kTVMFFISEqHashKindDAGNode) { // record the equality equal_map_lhs_[lhs] = rhs; equal_map_rhs_[rhs] = lhs; @@ -306,6 +346,8 @@ class StructEqualHandler { // the root lhs for result printing std::vector* mismatch_lhs_reverse_path_ = nullptr; std::vector* mismatch_rhs_reverse_path_ = nullptr; + // lazily initialize custom equal function + ffi::Function s_equal_callback_ = nullptr; // map from lhs to rhs std::unordered_map equal_map_lhs_; // map from rhs to lhs @@ -342,6 +384,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("ffi.reflection.GetFirstStructuralMismatch", StructuralEqual::GetFirstMismatch); + // ensure the type attribute column is presented in the system even if it is empty. + refl::EnsureTypeAttrColumn("__s_equal__"); }); } // namespace reflection diff --git a/ffi/src/ffi/reflection/structural_hash.cc b/ffi/src/ffi/reflection/structural_hash.cc index fc4479044f16..ba47de5146d4 100644 --- a/ffi/src/ffi/reflection/structural_hash.cc +++ b/ffi/src/ffi/reflection/structural_hash.cc @@ -99,15 +99,7 @@ class StructuralHashHandler { // compute the hash value uint64_t hash_value = obj->GetTypeKeyHash(); - if (structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar) { - if (map_free_vars_) { - // use lexical order of free var and its type - hash_value = details::StableHashCombine(hash_value, free_var_counter_++); - } else { - // Fallback to pointer hash, we are not mapping free var. - return std::hash()(obj.get()); - } - } else { + if (structural_eq_hash_kind != kTVMFFISEqHashKindCustomTreeNode) { // go over the content and hash the fields ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) { // skip fields that are marked as structural eq hash ignore @@ -126,12 +118,43 @@ class StructuralHashHandler { } } }); - // if it is a DAG node, also record the lexical order of graph counter - // this helps to distinguish DAG from trees. - if (structural_eq_hash_kind == kTVMFFISEqHashKindDAGNode) { - hash_value = details::StableHashCombine(hash_value, graph_node_counter_++); + } else { + static reflection::TypeAttrColumn custom_s_hash = reflection::TypeAttrColumn("__s_hash__"); + TVM_FFI_ICHECK(custom_s_hash[type_info->type_index] != nullptr) + << "TypeAttr `__s_hash__` is not registered for type `" << String(type_info->type_key) + << "`"; + if (s_hash_callback_ == nullptr) { + s_hash_callback_ = ffi::Function::FromTyped([this](AnyView val, bool def_region) { + if (def_region) { + bool allow_free_var = true; + std::swap(allow_free_var, map_free_vars_); + uint64_t hash_value = HashAny(val); + std::swap(allow_free_var, map_free_vars_); + return hash_value; + } else { + return HashAny(val); + } + }); + } + hash_value = custom_s_hash[type_info->type_index] + .cast()(obj, hash_value, s_hash_callback_) + .cast(); + } + + if (structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar) { + if (map_free_vars_) { + // use lexical order of free var and its type + hash_value = details::StableHashCombine(hash_value, free_var_counter_++); + } else { + // Fallback to pointer hash, we are not mapping free var. + hash_value = std::hash()(obj.get()); } } + // if it is a DAG node, also record the lexical order of graph counter + // this helps to distinguish DAG from trees. + if (structural_eq_hash_kind == kTVMFFISEqHashKindDAGNode) { + hash_value = details::StableHashCombine(hash_value, graph_node_counter_++); + } // record the hash value for this object hash_memo_[obj] = hash_value; return hash_value; @@ -244,6 +267,8 @@ class StructuralHashHandler { uint32_t free_var_counter_{0}; // graph node counter. uint32_t graph_node_counter_{0}; + // lazily initialize custom hash function + ffi::Function s_hash_callback_ = nullptr; // map from lhs to rhs std::unordered_map hash_memo_; }; @@ -258,6 +283,7 @@ uint64_t StructuralHash::Hash(const Any& value, bool map_free_vars, bool skip_nd TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("ffi.reflection.StructuralHash", StructuralHash::Hash); + refl::EnsureTypeAttrColumn("__s_hash__"); }); } // namespace reflection diff --git a/ffi/tests/cpp/CMakeLists.txt b/ffi/tests/cpp/CMakeLists.txt index 429683600bf9..0c820fc80ea8 100644 --- a/ffi/tests/cpp/CMakeLists.txt +++ b/ffi/tests/cpp/CMakeLists.txt @@ -1,4 +1,10 @@ file(GLOB _test_sources "${CMAKE_CURRENT_SOURCE_DIR}/test*.cc") +file(GLOB _test_extra_sources "${CMAKE_CURRENT_SOURCE_DIR}/extra/test*.cc") + +if (TVM_FFI_USE_EXTRA_CXX_API) + list(APPEND _test_sources ${_test_extra_sources}) +endif() + add_executable( tvm_ffi_tests EXCLUDE_FROM_ALL diff --git a/ffi/tests/cpp/test_reflection_structural_equal_hash.cc b/ffi/tests/cpp/extra/test_reflection_structural_equal_hash.cc similarity index 85% rename from ffi/tests/cpp/test_reflection_structural_equal_hash.cc rename to ffi/tests/cpp/extra/test_reflection_structural_equal_hash.cc index 8646c43c6197..d3353b782d33 100644 --- a/ffi/tests/cpp/test_reflection_structural_equal_hash.cc +++ b/ffi/tests/cpp/extra/test_reflection_structural_equal_hash.cc @@ -26,7 +26,7 @@ #include #include -#include "./testing_object.h" +#include "../testing_object.h" namespace { @@ -169,4 +169,30 @@ TEST(StructuralEqualHash, FuncDefAndIgnoreField) { EXPECT_TRUE(refl::StructuralEqual()(diff_fa_fc, expected_diff_fa_fc)); } +TEST(StructuralEqualHash, CustomTreeNode) { + TVar x = TVar("x"); + TVar y = TVar("y"); + // comment fields are ignored + TCustomFunc fa = TCustomFunc({x}, {TInt(1), x}, "comment a"); + TCustomFunc fb = TCustomFunc({y}, {TInt(1), y}, "comment b"); + + TCustomFunc fc = TCustomFunc({x}, {TInt(1), TInt(2)}, "comment c"); + + EXPECT_TRUE(refl::StructuralEqual()(fa, fb)); + EXPECT_EQ(refl::StructuralHash()(fa), refl::StructuralHash()(fb)); + + EXPECT_FALSE(refl::StructuralEqual()(fa, fc)); + auto diff_fa_fc = refl::StructuralEqual::GetFirstMismatch(fa, fc); + auto expected_diff_fa_fc = refl::AccessPathPair(refl::AccessPath({ + refl::AccessStep::ObjectField("body"), + refl::AccessStep::ArrayIndex(1), + }), + refl::AccessPath({ + refl::AccessStep::ObjectField("body"), + refl::AccessStep::ArrayIndex(1), + })); + EXPECT_TRUE(diff_fa_fc.has_value()); + EXPECT_TRUE(refl::StructuralEqual()(diff_fa_fc, expected_diff_fa_fc)); +} + } // namespace diff --git a/ffi/tests/cpp/test_reflection_accessor.cc b/ffi/tests/cpp/test_reflection_accessor.cc index 6450bd67c1ec..aa3dfc5e923c 100644 --- a/ffi/tests/cpp/test_reflection_accessor.cc +++ b/ffi/tests/cpp/test_reflection_accessor.cc @@ -55,6 +55,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ TPrimExprObj::RegisterReflection(); TVarObj::RegisterReflection(); TFuncObj::RegisterReflection(); + TCustomFuncObj::RegisterReflection(); refl::ObjectDef().def_ro("x", &TestObjA::x).def_rw("y", &TestObjA::y); refl::ObjectDef().def_ro("z", &TestObjADerived::z); diff --git a/ffi/tests/cpp/testing_object.h b/ffi/tests/cpp/testing_object.h index 37fb7417c80d..63c2b42d4f77 100644 --- a/ffi/tests/cpp/testing_object.h +++ b/ffi/tests/cpp/testing_object.h @@ -158,7 +158,8 @@ class TVarObj : public Object { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("name", &TVarObj::name); + refl::ObjectDef().def_ro("name", &TVarObj::name, + refl::AttachFieldFlag::SEqHashIgnore()); } static constexpr const char* _type_key = "test.Var"; @@ -204,6 +205,60 @@ class TFunc : public ObjectRef { TVM_FFI_DEFINE_OBJECT_REF_METHODS(TFunc, ObjectRef, TFuncObj); }; +class TCustomFuncObj : public Object { + public: + Array params; + Array body; + String comment; + + TCustomFuncObj(Array params, Array body, String comment) + : params(params), body(body), comment(comment) {} + + bool SEqual(const TCustomFuncObj* other, + ffi::TypedFunction cmp) const { + if (!cmp(params, other->params, true, "params")) { + std::cout << "custom s_equal failed params" << std::endl; + return false; + } + if (!cmp(body, other->body, false, "body")) { + std::cout << "custom s_equal failed body" << std::endl; + return false; + } + return true; + } + + uint64_t SHash(uint64_t type_key_hash, ffi::TypedFunction hash) const { + uint64_t hash_value = type_key_hash; + hash_value = tvm::ffi::details::StableHashCombine(hash_value, hash(params, true)); + hash_value = tvm::ffi::details::StableHashCombine(hash_value, hash(body, false)); + return hash_value; + } + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("params", &TCustomFuncObj::params) + .def_ro("body", &TCustomFuncObj::body) + .def_ro("comment", &TCustomFuncObj::comment); + refl::TypeAttrDef() + .def("__s_equal__", &TCustomFuncObj::SEqual) + .def("__s_hash__", &TCustomFuncObj::SHash); + } + + static constexpr const char* _type_key = "test.CustomFunc"; + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindCustomTreeNode; + TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TCustomFuncObj, Object); +}; + +class TCustomFunc : public ObjectRef { + public: + explicit TCustomFunc(Array params, Array body, String comment) { + data_ = make_object(params, body, comment); + } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS(TCustomFunc, ObjectRef, TCustomFuncObj); +}; + } // namespace testing template <>