Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions ffi/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,13 @@ set(tvm_ffi_objs_sources
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/dtype.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/testing.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/container.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/access_path.cc"
)

if (TVM_FFI_USE_EXTRA_CXX_API)
list(APPEND tvm_ffi_objs_sources
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/access_path.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/structural_equal.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/structural_hash.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/structural_equal.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/structural_hash.cc"
)
endif()

Expand Down
21 changes: 0 additions & 21 deletions ffi/include/tvm/ffi/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,27 +56,6 @@
#define TVM_FFI_DLL_EXPORT __attribute__((visibility("default")))
#endif

/*!
* \brief Marks the API as extra c++ api that is defined in cc files.
*
* These APIs are extra features that depend on, but are not required to
* support essential core functionality, such as function calling and object
* access.
*
* They are implemented in cc files to reduce compile-time overhead.
* The input/output only uses POD/Any/ObjectRef for ABI stability.
* However, these extra APIs may have an issue across MSVC/Itanium ABI,
*
* Related features are also available through reflection based function
* that is fully based on C API
*
* The project aims to minimize the number of extra C++ APIs and only
* restrict the use to non-core functionalities.
*/
#ifndef TVM_FFI_EXTRA_CXX_API
#define TVM_FFI_EXTRA_CXX_API TVM_FFI_DLL
#endif

#ifdef __cplusplus
extern "C" {
#endif
Expand Down
48 changes: 48 additions & 0 deletions ffi/include/tvm/ffi/extra/base.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/ffi/extra/base.h
* \brief Base header for Extra API.
*
* The extra APIs contains a minmal set of extra APIs that are not
* required to support essential core functionality.
*/
#ifndef TVM_FFI_EXTRA_BASE_H_
#define TVM_FFI_EXTRA_BASE_H_

#include <tvm/ffi/c_api.h>

/*!
* \brief Marks the API as extra c++ api that is defined in cc files.
*
* They are implemented in cc files to reduce compile-time overhead.
* The input/output only uses POD/Any/ObjectRef for ABI stability.
* However, these extra APIs may have an issue across MSVC/Itanium ABI,
*
* Related features are also available through reflection based function
* that is fully based on C API
*
* The project aims to minimize the number of extra C++ APIs to keep things
* lightweight and restrict the use to non-core functionalities.
*/
#ifndef TVM_FFI_EXTRA_CXX_API
#define TVM_FFI_EXTRA_CXX_API TVM_FFI_DLL
#endif

#endif // TVM_FFI_EXTRA_BASE_H_
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,19 @@
* under the License.
*/
/*!
* \file tvm/ffi/reflection/structural_equal.h
* \file tvm/ffi/extra/structural_equal.h
* \brief Structural equal implementation
*/
#ifndef TVM_FFI_REFLECTION_STRUCTURAL_EQUAL_H_
#define TVM_FFI_REFLECTION_STRUCTURAL_EQUAL_H_
#ifndef TVM_FFI_EXTRA_STRUCTURAL_EQUAL_H_
#define TVM_FFI_EXTRA_STRUCTURAL_EQUAL_H_

#include <tvm/ffi/any.h>
#include <tvm/ffi/extra/base.h>
#include <tvm/ffi/optional.h>
#include <tvm/ffi/reflection/access_path.h>

namespace tvm {
namespace ffi {
namespace reflection {
/*
* \brief Structural equality comparators
*/
Expand Down Expand Up @@ -59,7 +59,7 @@ class StructuralEqual {
* \return If comparison fails, return the first mismatch AccessPath pair,
* otherwise return std::nullopt.
*/
TVM_FFI_EXTRA_CXX_API static Optional<AccessPathPair> GetFirstMismatch(
TVM_FFI_EXTRA_CXX_API static Optional<reflection::AccessPathPair> GetFirstMismatch(
const Any& lhs, const Any& rhs, bool map_free_vars = false,
bool skip_ndarray_content = false);

Expand All @@ -74,7 +74,6 @@ class StructuralEqual {
}
};

} // namespace reflection
} // namespace ffi
} // namespace tvm
#endif // TVM_FFI_REFLECTION_STRUCTURAL_EQUAL_H_
#endif // TVM_FFI_EXTRA_STRUCTURAL_EQUAL_H_
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,17 @@
* under the License.
*/
/*!
* \file tvm/ffi/reflection/structural_hash.h
* \file tvm/ffi/extra/structural_hash.h
* \brief Structural hash
*/
#ifndef TVM_FFI_REFLECTION_STRUCTURAL_HASH_H_
#define TVM_FFI_REFLECTION_STRUCTURAL_HASH_H_
#ifndef TVM_FFI_EXTRA_STRUCTURAL_HASH_H_
#define TVM_FFI_EXTRA_STRUCTURAL_HASH_H_

#include <tvm/ffi/any.h>
#include <tvm/ffi/extra/base.h>

namespace tvm {
namespace ffi {
namespace reflection {

/*
* \brief Structural hash
Expand All @@ -52,7 +52,6 @@ class StructuralHash {
TVM_FFI_INLINE uint64_t operator()(const Any& value) const { return Hash(value); }
};

} // namespace reflection
} // namespace ffi
} // namespace tvm
#endif // TVM_FFI_REFLECTION_STRUCTURAL_HASH_H_
#endif // TVM_FFI_EXTRA_STRUCTURAL_HASH_H_
18 changes: 9 additions & 9 deletions ffi/include/tvm/ffi/reflection/access_path.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ namespace reflection {

enum class AccessKind : int32_t {
kObjectField = 0,
kArrayIndex = 1,
kMapKey = 2,
kArrayItem = 1,
kMapItem = 2,
// the following two are used for error reporting when
// the supposed access field is not available
kArrayIndexMissing = 3,
kMapKeyMissing = 4,
kArrayItemMissing = 3,
kMapItemMissing = 4,
};

/*!
Expand Down Expand Up @@ -86,15 +86,15 @@ class AccessStep : public ObjectRef {
return AccessStep(AccessKind::kObjectField, field_name);
}

static AccessStep ArrayIndex(int64_t index) { return AccessStep(AccessKind::kArrayIndex, index); }
static AccessStep ArrayItem(int64_t index) { return AccessStep(AccessKind::kArrayItem, index); }

static AccessStep ArrayIndexMissing(int64_t index) {
return AccessStep(AccessKind::kArrayIndexMissing, index);
static AccessStep ArrayItemMissing(int64_t index) {
return AccessStep(AccessKind::kArrayItemMissing, index);
}

static AccessStep MapKey(Any key) { return AccessStep(AccessKind::kMapKey, key); }
static AccessStep MapItem(Any key) { return AccessStep(AccessKind::kMapItem, key); }

static AccessStep MapKeyMissing(Any key) { return AccessStep(AccessKind::kMapKeyMissing, key); }
static AccessStep MapItemMissing(Any key) { return AccessStep(AccessKind::kMapItemMissing, key); }

TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AccessStep, ObjectRef, AccessStepObj);
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,15 @@
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/container/ndarray.h>
#include <tvm/ffi/container/shape.h>
#include <tvm/ffi/extra/structural_equal.h>
#include <tvm/ffi/reflection/accessor.h>
#include <tvm/ffi/reflection/structural_equal.h>
#include <tvm/ffi/string.h>

#include <cmath>
#include <unordered_map>

namespace tvm {
namespace ffi {
namespace reflection {

/**
* \brief Internal Handler class for structural equal comparison.
Expand Down Expand Up @@ -135,11 +134,11 @@ class StructEqualHandler {
bool success = true;
if (custom_s_equal[type_info->type_index] == nullptr) {
// We recursively compare the fields the object
ForEachFieldInfoWithEarlyStop(type_info, [&](const TVMFFIFieldInfo* field_info) {
reflection::ForEachFieldInfoWithEarlyStop(type_info, [&](const TVMFFIFieldInfo* field_info) {
// skip fields that are marked as structural eq hash ignore
if (field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashIgnore) return false;
// get the field value from both side
FieldGetter getter(field_info);
reflection::FieldGetter getter(field_info);
Any lhs_value = getter(lhs);
Any rhs_value = getter(rhs);
// field is in def region, enable free var mapping
Expand All @@ -155,9 +154,9 @@ class StructEqualHandler {
// record the first mismatching field if we sub-rountine compare failed
if (mismatch_lhs_reverse_path_ != nullptr) {
mismatch_lhs_reverse_path_->emplace_back(
AccessStep::ObjectField(String(field_info->name)));
reflection::AccessStep::ObjectField(String(field_info->name)));
mismatch_rhs_reverse_path_->emplace_back(
AccessStep::ObjectField(String(field_info->name)));
reflection::AccessStep::ObjectField(String(field_info->name)));
}
// return true to indicate early stop
return true;
Expand Down Expand Up @@ -185,8 +184,10 @@ class StructEqualHandler {
if (!success) {
if (mismatch_lhs_reverse_path_ != nullptr) {
String field_name_str = field_name.cast<String>();
mismatch_lhs_reverse_path_->emplace_back(AccessStep::ObjectField(field_name_str));
mismatch_rhs_reverse_path_->emplace_back(AccessStep::ObjectField(field_name_str));
mismatch_lhs_reverse_path_->emplace_back(
reflection::AccessStep::ObjectField(field_name_str));
mismatch_rhs_reverse_path_->emplace_back(
reflection::AccessStep::ObjectField(field_name_str));
}
}
return success;
Expand Down Expand Up @@ -235,16 +236,16 @@ class StructEqualHandler {
auto it = rhs.find(rhs_key);
if (it == rhs.end()) {
if (mismatch_lhs_reverse_path_ != nullptr) {
mismatch_lhs_reverse_path_->emplace_back(AccessStep::MapKey(kv.first));
mismatch_rhs_reverse_path_->emplace_back(AccessStep::MapKeyMissing(rhs_key));
mismatch_lhs_reverse_path_->emplace_back(reflection::AccessStep::MapItem(kv.first));
mismatch_rhs_reverse_path_->emplace_back(reflection::AccessStep::MapItemMissing(rhs_key));
}
return false;
}
// now recursively compare value
if (!CompareAny(kv.second, (*it).second)) {
if (mismatch_lhs_reverse_path_ != nullptr) {
mismatch_lhs_reverse_path_->emplace_back(AccessStep::MapKey(kv.first));
mismatch_rhs_reverse_path_->emplace_back(AccessStep::MapKey(rhs_key));
mismatch_lhs_reverse_path_->emplace_back(reflection::AccessStep::MapItem(kv.first));
mismatch_rhs_reverse_path_->emplace_back(reflection::AccessStep::MapItem(rhs_key));
}
return false;
}
Expand All @@ -258,8 +259,8 @@ class StructEqualHandler {
auto it = lhs.find(lhs_key);
if (it == lhs.end()) {
if (mismatch_lhs_reverse_path_ != nullptr) {
mismatch_lhs_reverse_path_->emplace_back(AccessStep::MapKeyMissing(lhs_key));
mismatch_rhs_reverse_path_->emplace_back(AccessStep::MapKey(kv.first));
mismatch_lhs_reverse_path_->emplace_back(reflection::AccessStep::MapItemMissing(lhs_key));
mismatch_rhs_reverse_path_->emplace_back(reflection::AccessStep::MapItem(kv.first));
}
return false;
}
Expand All @@ -276,20 +277,22 @@ class StructEqualHandler {
for (size_t i = 0; i < std::min(lhs.size(), rhs.size()); ++i) {
if (!CompareAny(lhs[i], rhs[i])) {
if (mismatch_lhs_reverse_path_ != nullptr) {
mismatch_lhs_reverse_path_->emplace_back(AccessStep::ArrayIndex(i));
mismatch_rhs_reverse_path_->emplace_back(AccessStep::ArrayIndex(i));
mismatch_lhs_reverse_path_->emplace_back(reflection::AccessStep::ArrayItem(i));
mismatch_rhs_reverse_path_->emplace_back(reflection::AccessStep::ArrayItem(i));
}
return false;
}
}
if (lhs.size() == rhs.size()) return true;
if (mismatch_lhs_reverse_path_ != nullptr) {
if (lhs.size() > rhs.size()) {
mismatch_lhs_reverse_path_->emplace_back(AccessStep::ArrayIndex(rhs.size()));
mismatch_rhs_reverse_path_->emplace_back(AccessStep::ArrayIndexMissing(rhs.size()));
mismatch_lhs_reverse_path_->emplace_back(reflection::AccessStep::ArrayItem(rhs.size()));
mismatch_rhs_reverse_path_->emplace_back(
reflection::AccessStep::ArrayItemMissing(rhs.size()));
} else {
mismatch_lhs_reverse_path_->emplace_back(AccessStep::ArrayIndexMissing(lhs.size()));
mismatch_rhs_reverse_path_->emplace_back(AccessStep::ArrayIndex(lhs.size()));
mismatch_lhs_reverse_path_->emplace_back(
reflection::AccessStep::ArrayItemMissing(lhs.size()));
mismatch_rhs_reverse_path_->emplace_back(reflection::AccessStep::ArrayItem(lhs.size()));
}
}
return false;
Expand Down Expand Up @@ -354,8 +357,8 @@ class StructEqualHandler {
// whether we compare ndarray data
bool skip_ndarray_content_{false};
// the root lhs for result printing
std::vector<AccessStep>* mismatch_lhs_reverse_path_ = nullptr;
std::vector<AccessStep>* mismatch_rhs_reverse_path_ = nullptr;
std::vector<reflection::AccessStep>* mismatch_lhs_reverse_path_ = nullptr;
std::vector<reflection::AccessStep>* mismatch_rhs_reverse_path_ = nullptr;
// lazily initialize custom equal function
ffi::Function s_equal_callback_ = nullptr;
// map from lhs to rhs
Expand All @@ -372,32 +375,31 @@ bool StructuralEqual::Equal(const Any& lhs, const Any& rhs, bool map_free_vars,
return handler.CompareAny(lhs, rhs);
}

Optional<AccessPathPair> StructuralEqual::GetFirstMismatch(const Any& lhs, const Any& rhs,
bool map_free_vars,
bool skip_ndarray_content) {
Optional<reflection::AccessPathPair> StructuralEqual::GetFirstMismatch(const Any& lhs,
const Any& rhs,
bool map_free_vars,
bool skip_ndarray_content) {
StructEqualHandler handler;
handler.map_free_vars_ = map_free_vars;
handler.skip_ndarray_content_ = skip_ndarray_content;
std::vector<AccessStep> lhs_reverse_path;
std::vector<AccessStep> rhs_reverse_path;
std::vector<reflection::AccessStep> lhs_reverse_path;
std::vector<reflection::AccessStep> rhs_reverse_path;
handler.mismatch_lhs_reverse_path_ = &lhs_reverse_path;
handler.mismatch_rhs_reverse_path_ = &rhs_reverse_path;
if (handler.CompareAny(lhs, rhs)) {
return std::nullopt;
}
AccessPath lhs_path(lhs_reverse_path.rbegin(), lhs_reverse_path.rend());
AccessPath rhs_path(rhs_reverse_path.rbegin(), rhs_reverse_path.rend());
return AccessPathPair(lhs_path, rhs_path);
reflection::AccessPath lhs_path(lhs_reverse_path.rbegin(), lhs_reverse_path.rend());
reflection::AccessPath rhs_path(rhs_reverse_path.rbegin(), rhs_reverse_path.rend());
return reflection::AccessPathPair(lhs_path, rhs_path);
}

TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("ffi.reflection.GetFirstStructuralMismatch",
StructuralEqual::GetFirstMismatch);
refl::GlobalDef().def("ffi.GetFirstStructuralMismatch", StructuralEqual::GetFirstMismatch);
// ensure the type attribute column is presented in the system even if it is empty.
refl::EnsureTypeAttrColumn("__s_equal__");
});

} // namespace reflection
} // namespace ffi
} // namespace tvm
Loading
Loading