Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 84 additions & 30 deletions ffi/include/tvm/ffi/any.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class AnyView {
void reset() {
data_.type_index = TypeIndex::kTVMFFINone;
// invariance: always set the union padding part to 0
data_.zero_padding = 0;
data_.v_int64 = 0;
}
/*!
Expand All @@ -72,6 +73,7 @@ class AnyView {
// default constructors
AnyView() {
data_.type_index = TypeIndex::kTVMFFINone;
data_.zero_padding = 0;
data_.v_int64 = 0;
}
~AnyView() = default;
Expand All @@ -80,6 +82,7 @@ class AnyView {
AnyView& operator=(const AnyView&) = default;
AnyView(AnyView&& other) : data_(other.data_) {
other.data_.type_index = TypeIndex::kTVMFFINone;
other.data_.zero_padding = 0;
other.data_.v_int64 = 0;
}
TVM_FFI_INLINE AnyView& operator=(AnyView&& other) {
Expand Down Expand Up @@ -198,22 +201,19 @@ TVM_FFI_INLINE void InplaceConvertAnyViewToAny(TVMFFIAny* data,
if (data->type_index == TypeIndex::kTVMFFIRawStr) {
// convert raw string to owned string object
String temp(data->v_c_str);
data->type_index = TypeIndex::kTVMFFIStr;
data->v_obj = details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(temp));
TypeTraits<String>::MoveToAny(std::move(temp), data);
} else if (data->type_index == TypeIndex::kTVMFFIByteArrayPtr) {
// convert byte array to owned bytes object
Bytes temp(*static_cast<TVMFFIByteArray*>(data->v_ptr));
data->type_index = TypeIndex::kTVMFFIBytes;
data->v_obj = details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(temp));
TypeTraits<Bytes>::MoveToAny(std::move(temp), data);
} else if (data->type_index == TypeIndex::kTVMFFIObjectRValueRef) {
// convert rvalue ref to owned object
Object** obj_addr = static_cast<Object**>(data->v_ptr);
TVM_FFI_ICHECK(obj_addr[0] != nullptr) << "RValueRef already moved";
ObjectRef temp(details::ObjectUnsafe::ObjectPtrFromOwned<Object>(obj_addr[0]));
// set the rvalue ref to nullptr to avoid double move
obj_addr[0] = nullptr;
data->type_index = temp->type_index();
data->v_obj = details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(temp));
TypeTraits<ObjectRef>::MoveToAny(std::move(temp), data);
}
}
}
Expand All @@ -239,6 +239,7 @@ class Any {
details::ObjectUnsafe::DecRefObjectHandle(data_.v_obj);
}
data_.type_index = TVMFFITypeIndex::kTVMFFINone;
data_.zero_padding = 0;
data_.v_int64 = 0;
}
/*!
Expand All @@ -251,6 +252,7 @@ class Any {
// default constructors
Any() {
data_.type_index = TypeIndex::kTVMFFINone;
data_.zero_padding = 0;
data_.v_int64 = 0;
}
~Any() { this->reset(); }
Expand All @@ -262,6 +264,7 @@ class Any {
}
Any(Any&& other) : data_(other.data_) {
other.data_.type_index = TypeIndex::kTVMFFINone;
other.data_.zero_padding = 0;
other.data_.v_int64 = 0;
}
TVM_FFI_INLINE Any& operator=(const Any& other) {
Expand Down Expand Up @@ -408,7 +411,8 @@ class Any {
* \return True if the two Any are same type and value, false otherwise.
*/
TVM_FFI_INLINE bool same_as(const Any& other) const noexcept {
return data_.type_index == other.data_.type_index && data_.v_int64 == other.data_.v_int64;
return data_.type_index == other.data_.type_index &&
data_.zero_padding == other.data_.zero_padding && data_.v_int64 == other.data_.v_int64;
}

/*
Expand Down Expand Up @@ -485,6 +489,7 @@ struct AnyUnsafe : public ObjectUnsafe {
TVM_FFI_INLINE static TVMFFIAny MoveAnyToTVMFFIAny(Any&& any) {
TVMFFIAny result = any.data_;
any.data_.type_index = TypeIndex::kTVMFFINone;
any.data_.zero_padding = 0;
any.data_.v_int64 = 0;
return result;
}
Expand All @@ -493,6 +498,7 @@ struct AnyUnsafe : public ObjectUnsafe {
Any any;
any.data_ = data;
data.type_index = TypeIndex::kTVMFFINone;
data.zero_padding = 0;
data.v_int64 = 0;
return any;
}
Expand Down Expand Up @@ -543,17 +549,24 @@ struct AnyHash {
* \return Hash code of a, string hash for strings and pointer address otherwise.
*/
uint64_t operator()(const Any& src) const {
uint64_t val_hash = [&]() -> uint64_t {
if (src.data_.type_index == TypeIndex::kTVMFFIStr ||
src.data_.type_index == TypeIndex::kTVMFFIBytes) {
const details::BytesObjBase* src_str =
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(src);
return details::StableHashBytes(src_str->data, src_str->size);
} else {
return src.data_.v_uint64;
}
}();
return details::StableHashCombine(src.data_.type_index, val_hash);
if (src.data_.type_index == TypeIndex::kTVMFFISmallStr) {
// for small string, we use the same type key hash as normal string
// so heap allocated string and on stack string will have the same hash
return details::StableHashCombine(TypeIndex::kTVMFFIStr,
details::StableHashSmallStrBytes(&src.data_));
} else if (src.data_.type_index == TypeIndex::kTVMFFISmallBytes) {
// use byte the same type key as bytes
return details::StableHashCombine(TypeIndex::kTVMFFIBytes,
details::StableHashSmallStrBytes(&src.data_));
} else if (src.data_.type_index == TypeIndex::kTVMFFIStr ||
src.data_.type_index == TypeIndex::kTVMFFIBytes) {
const details::BytesObjBase* src_str =
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(src);
return details::StableHashCombine(src.data_.type_index,
details::StableHashBytes(src_str->data, src_str->size));
} else {
return details::StableHashCombine(src.data_.type_index, src.data_.v_uint64);
}
}
};

Expand All @@ -566,19 +579,60 @@ struct AnyEqual {
* \return String equality if both are strings, pointer address equality otherwise.
*/
bool operator()(const Any& lhs, const Any& rhs) const {
if (lhs.data_.type_index != rhs.data_.type_index) return false;
// byte equivalence
if (lhs.data_.v_int64 == rhs.data_.v_int64) return true;
// specialy handle string hash
if (lhs.data_.type_index == TypeIndex::kTVMFFIStr ||
lhs.data_.type_index == TypeIndex::kTVMFFIBytes) {
const details::BytesObjBase* lhs_str =
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(lhs);
const details::BytesObjBase* rhs_str =
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(rhs);
return Bytes::memequal(lhs_str->data, rhs_str->data, lhs_str->size, rhs_str->size);
// header with type index
const int64_t* lhs_as_int64 = reinterpret_cast<const int64_t*>(&lhs.data_);
const int64_t* rhs_as_int64 = reinterpret_cast<const int64_t*>(&rhs.data_);
static_assert(sizeof(TVMFFIAny) == 16);
// fast path, check byte equality
if (lhs_as_int64[0] == rhs_as_int64[0] && lhs_as_int64[1] == rhs_as_int64[1]) {
return true;
}
// common false case type index match, in this case we only need to pay attention to string
// equality
if (lhs.data_.type_index == rhs.data_.type_index) {
// specialy handle string hash
if (lhs.data_.type_index == TypeIndex::kTVMFFIStr ||
lhs.data_.type_index == TypeIndex::kTVMFFIBytes) {
const details::BytesObjBase* lhs_str =
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(lhs);
const details::BytesObjBase* rhs_str =
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(rhs);
return Bytes::memequal(lhs_str->data, rhs_str->data, lhs_str->size, rhs_str->size);
}
return false;
} else {
// type_index mismatch, if index is not string, return false
if (lhs.data_.type_index != kTVMFFIStr && lhs.data_.type_index != kTVMFFISmallStr &&
lhs.data_.type_index != kTVMFFISmallBytes && lhs.data_.type_index != kTVMFFIBytes) {
return false;
}
// small string and normal string comparison
if (lhs.data_.type_index == kTVMFFIStr && rhs.data_.type_index == kTVMFFISmallStr) {
const details::BytesObjBase* lhs_str =
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(lhs);
return Bytes::memequal(lhs_str->data, rhs.data_.v_bytes, lhs_str->size,
rhs.data_.small_str_len);
}
if (lhs.data_.type_index == kTVMFFISmallStr && rhs.data_.type_index == kTVMFFIStr) {
const details::BytesObjBase* rhs_str =
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(rhs);
return Bytes::memequal(lhs.data_.v_bytes, rhs_str->data, lhs.data_.small_str_len,
rhs_str->size);
}
if (lhs.data_.type_index == kTVMFFIBytes && rhs.data_.type_index == kTVMFFISmallBytes) {
const details::BytesObjBase* lhs_bytes =
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(lhs);
return Bytes::memequal(lhs_bytes->data, rhs.data_.v_bytes, lhs_bytes->size,
rhs.data_.small_str_len);
}
if (lhs.data_.type_index == kTVMFFISmallBytes && rhs.data_.type_index == kTVMFFIBytes) {
const details::BytesObjBase* rhs_bytes =
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(rhs);
return Bytes::memequal(lhs.data_.v_bytes, rhs_bytes->data, lhs.data_.small_str_len,
rhs_bytes->size);
}
return false;
}
return false;
}
};

Expand Down
17 changes: 16 additions & 1 deletion ffi/include/tvm/ffi/base_details.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,8 @@ TVM_FFI_INLINE uint64_t StableHashCombine(uint64_t key, const T& value) {
* \param size The size of the bytes.
* \return the hash value.
*/
TVM_FFI_INLINE uint64_t StableHashBytes(const char* data, size_t size) {
TVM_FFI_INLINE uint64_t StableHashBytes(const void* data_ptr, size_t size) {
const char* data = reinterpret_cast<const char*>(data_ptr);
const constexpr uint64_t kMultiplier = 1099511628211ULL;
const constexpr uint64_t kMod = 2147483647ULL;
union Union {
Expand Down Expand Up @@ -250,6 +251,20 @@ TVM_FFI_INLINE uint64_t StableHashBytes(const char* data, size_t size) {
return result;
}

/*!
* \brief Same as StableHashBytes, but for small string data.
* \param data The data pointer
* \return the hash value.
*/
TVM_FFI_INLINE uint64_t StableHashSmallStrBytes(const TVMFFIAny* data) {
if constexpr (TVM_FFI_IO_NO_ENDIAN_SWAP) {
// fast path, no endian swap, simply hash as uint64_t
const constexpr uint64_t kMod = 2147483647ULL;
return data->v_uint64 % kMod;
}
return StableHashBytes(reinterpret_cast<const void*>(data), sizeof(data->v_uint64));
}

} // namespace details
} // namespace ffi
} // namespace tvm
Expand Down
48 changes: 34 additions & 14 deletions ffi/include/tvm/ffi/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,7 @@ enum TVMFFITypeIndex : int32_t {
#else
typedef enum {
#endif
// [Section] On-stack POD and special types: [0, kTVMFFIStaticObjectBegin)
// N.B. `kTVMFFIRawStr` is a string backed by a `\0`-terminated char array,
// which is not owned by TVMFFIAny. It is required that the following
// invariant holds:
// - `Any::type_index` is never `kTVMFFIRawStr`
// - `AnyView::type_index` can be `kTVMFFIRawStr`
//

/*
* \brief The root type of all FFI objects.
*
Expand All @@ -80,6 +74,13 @@ typedef enum {
* However, it may appear in field annotations during reflection.
*/
kTVMFFIAny = -1,
// [Section] On-stack POD and special types: [0, kTVMFFIStaticObjectBegin)
// N.B. `kTVMFFIRawStr` is a string backed by a `\0`-terminated char array,
// which is not owned by TVMFFIAny. It is required that the following
// invariant holds:
// - `Any::type_index` is never `kTVMFFIRawStr`
// - `AnyView::type_index` can be `kTVMFFIRawStr`
//
/*! \brief None/nullptr value */
kTVMFFINone = 0,
/*! \brief POD int value */
Expand All @@ -96,12 +97,16 @@ typedef enum {
kTVMFFIDevice = 6,
/*! \brief DLTensor* */
kTVMFFIDLTensorPtr = 7,
/*! \brief const char**/
/*! \brief const char* */
kTVMFFIRawStr = 8,
/*! \brief TVMFFIByteArray* */
kTVMFFIByteArrayPtr = 9,
/*! \brief R-value reference to ObjectRef */
kTVMFFIObjectRValueRef = 10,
/*! \brief Small string on stack */
kTVMFFISmallStr = 11,
/*! \brief Small bytes on stack */
kTVMFFISmallBytes = 12,
/*! \brief Start of statically defined objects. */
kTVMFFIStaticObjectBegin = 64,
/*!
Expand Down Expand Up @@ -183,11 +188,17 @@ typedef struct TVMFFIAny {
* \note The type index of Object and Any are shared in FFI.
*/
int32_t type_index;
/*!
* \brief length for on-stack Any object, such as small-string
* \note This field is reserved for future compact.
*/
int32_t small_len;
union { // 4 bytes
/*! \brief padding, must set to zero for values other than small string. */
uint32_t zero_padding;
/*!
* \brief Length of small string, with a max value of 7.
*
* We keep small str to start at next 4 bytes to ensure alignment
* when accessing the small str content.
*/
uint32_t small_str_len;
};
union { // 8 bytes
int64_t v_int64; // integers
double v_float64; // floating-point numbers
Expand Down Expand Up @@ -823,7 +834,7 @@ TVM_FFI_DLL int TVMFFIDataTypeFromString(const TVMFFIByteArray* str, DLDataType*

* \note The input dtype is a pointer to the DLDataType to avoid ABI compatibility issues.
*/
TVM_FFI_DLL int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIObjectHandle* out);
TVM_FFI_DLL int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIAny* out);

//------------------------------------------------------------
// Section: Backend noexcept functions for internal use
Expand Down Expand Up @@ -903,6 +914,15 @@ inline int32_t TVMFFIObjectGetTypeIndex(TVMFFIObjectHandle obj) {
return static_cast<TVMFFIObject*>(obj)->type_index;
}

/*!
* \brief Get the content of a small string in bytearray format.
* \param obj The object handle.
* \return The content of the small string in bytearray format.
*/
inline TVMFFIByteArray TVMFFISmallBytesGetContentByteArray(const TVMFFIAny* value) {
return TVMFFIByteArray{value->v_bytes, static_cast<size_t>(value->small_str_len)};
}

/*!
* \brief Get the data pointer of a bytearray from a string or bytes object.
* \param obj The object handle.
Expand Down
1 change: 1 addition & 0 deletions ffi/include/tvm/ffi/cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/ffi/dtype.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/object.h>
#include <tvm/ffi/optional.h>

#include <utility>

Expand Down
2 changes: 2 additions & 0 deletions ffi/include/tvm/ffi/container/variant.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,12 @@ class VariantBase<true> : public ObjectRef {
TVMFFIAny any_data;
if (data_ == nullptr) {
any_data.type_index = TypeIndex::kTVMFFINone;
any_data.zero_padding = 0;
any_data.v_int64 = 0;
} else {
TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&any_data);
any_data.type_index = data_->type_index();
any_data.zero_padding = 0;
any_data.v_obj = details::ObjectUnsafe::TVMFFIObjectPtrFromObjectPtr<Object>(data_);
}
return AnyView::CopyFromTVMFFIAny(any_data);
Expand Down
9 changes: 6 additions & 3 deletions ffi/include/tvm/ffi/dtype.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,14 +115,15 @@ inline const char* DLDataTypeCodeAsCStr(DLDataTypeCode type_code) { // NOLINT(*

inline DLDataType StringToDLDataType(const String& str) {
DLDataType out;
TVM_FFI_CHECK_SAFE_CALL(TVMFFIDataTypeFromString(str.get(), &out));
TVMFFIByteArray data{str.data(), str.size()};
TVM_FFI_CHECK_SAFE_CALL(TVMFFIDataTypeFromString(&data, &out));
return out;
}

inline String DLDataTypeToString(DLDataType dtype) {
TVMFFIObjectHandle out;
TVMFFIAny out;
TVM_FFI_CHECK_SAFE_CALL(TVMFFIDataTypeToString(&dtype, &out));
return String(details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(out)));
return TypeTraits<String>::MoveFromAnyAfterCheck(&out);
}

// DLDataType
Expand All @@ -134,13 +135,15 @@ struct TypeTraits<DLDataType> : public TypeTraitsBase {
// clear padding part to ensure the equality check can always check the v_uint64 part
result->v_uint64 = 0;
result->type_index = TypeIndex::kTVMFFIDataType;
result->zero_padding = 0;
result->v_dtype = src;
}

TVM_FFI_INLINE static void MoveToAny(DLDataType src, TVMFFIAny* result) {
// clear padding part to ensure the equality check can always check the v_uint64 part
result->v_uint64 = 0;
result->type_index = TypeIndex::kTVMFFIDataType;
result->zero_padding = 0;
result->v_dtype = src;
}

Expand Down
2 changes: 2 additions & 0 deletions ffi/include/tvm/ffi/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ struct StaticTypeKey {
static constexpr const char* kTVMFFIFunction = "ffi.Function";
static constexpr const char* kTVMFFIArray = "ffi.Array";
static constexpr const char* kTVMFFIMap = "ffi.Map";
static constexpr const char* kTVMFFISmallStr = "ffi.SmallStr";
static constexpr const char* kTVMFFISmallBytes = "ffi.SmallBytes";
};

/*!
Expand Down
Loading
Loading