Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 36 additions & 5 deletions ffi/include/tvm/ffi/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,22 @@ typedef struct {
* \brief Optional meta-data for structural eq/hash.
*/
TVMFFISEqHashKind structural_eq_hash_kind;
} TVMFFITypeExtraInfo;
} TVMFFITypeMetadata;

/*
* \brief Column array that stores extra attributes about types
*
* The attributes stored in column arrays that can be looked up by type index.
*
* \note
* \sa TVMFFIRegisterTypeAttr
*/
typedef struct {
/*! \brief The data of the column. */
const TVMFFIAny* data;
/*! \brief The size of the column. */
size_t size;
} TVMFFITypeAttrColumn;

/*!
* \brief Runtime type information for object type checking.
Expand Down Expand Up @@ -567,7 +582,7 @@ typedef struct TVMFFITypeInfo {
/*! \brief The reflection method. */
const TVMFFIMethodInfo* methods;
/*! \brief The extra information of the type. */
const TVMFFITypeExtraInfo* extra_info;
const TVMFFITypeMetadata* metadata;
} TVMFFITypeInfo;

//------------------------------------------------------------
Expand Down Expand Up @@ -738,11 +753,27 @@ TVM_FFI_DLL int TVMFFITypeRegisterMethod(int32_t type_index, const TVMFFIMethodI
/*!
* \brief Register type creator information for runtime reflection.
* \param type_index The type index
* \param extra_info The extra information to be registered.
* \param metadata The extra information to be registered.
* \return 0 when success, nonzero when failure happens
*/
TVM_FFI_DLL int TVMFFITypeRegisterExtraInfo(int32_t type_index,
const TVMFFITypeExtraInfo* extra_info);
TVM_FFI_DLL int TVMFFITypeRegisterMetadata(int32_t type_index, const TVMFFITypeMetadata* metadata);

/*!
* \brief Register extra type attributes that can be looked up during runtime.
* \param type_index The type index
* \param attr_value The attribute value to be registered.
* \return 0 when success, nonzero when failure happens
*/
TVM_FFI_DLL int TVMFFITypeRegisterAttr(int32_t type_index, const TVMFFIByteArray* attr_name,
const TVMFFIAny* attr_value);

/*!
* \brief Get the type attribute column by name.
* \param attr_name The name of the attribute.
* \return The pointer to the type attribute column.
* \return NULL if the attribute was not registered in the system
*/
TVM_FFI_DLL const TVMFFITypeAttrColumn* TVMFFIGetTypeAttrColumn(const TVMFFIByteArray* attr_name);

//------------------------------------------------------------
// Section: DLPack support APIs
Expand Down
23 changes: 23 additions & 0 deletions ffi/include/tvm/ffi/reflection/accessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,29 @@ class FieldSetter {
const TVMFFIFieldInfo* field_info_;
};

class TypeAttrColumn {
public:
explicit TypeAttrColumn(std::string_view attr_name) {
TVMFFIByteArray attr_name_array = {attr_name.data(), attr_name.size()};
column_ = TVMFFIGetTypeAttrColumn(&attr_name_array);
if (column_ == nullptr) {
TVM_FFI_THROW(RuntimeError) << "Cannot find type attribute " << attr_name;
}
}

AnyView operator[](int32_t type_index) const {
size_t tindex = static_cast<size_t>(type_index);
if (tindex >= column_->size) {
return AnyView();
}
const AnyView* any_view_data = reinterpret_cast<const AnyView*>(column_->data);
return any_view_data[tindex];
}

private:
const TVMFFITypeAttrColumn* column_;
};

/*!
* \brief helper function to get reflection method info by type key and method name
*
Expand Down
68 changes: 65 additions & 3 deletions ffi/include/tvm/ffi/reflection/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ class ReflectionDefBase {
}

template <typename T>
TVM_FFI_INLINE static void ApplyExtraInfoTrait(TVMFFITypeExtraInfo* info, const T& value) {
TVM_FFI_INLINE static void ApplyExtraInfoTrait(TVMFFITypeMetadata* info, const T& value) {
if constexpr (std::is_same_v<std::decay_t<T>, char*>) {
info->doc = TVMFFIByteArray{value, std::char_traits<char>::length(value)};
}
Expand Down Expand Up @@ -381,7 +381,7 @@ class ObjectDef : public ReflectionDefBase {
private:
template <typename... ExtraArgs>
void RegisterExtraInfo(ExtraArgs&&... extra_args) {
TVMFFITypeExtraInfo info;
TVMFFITypeMetadata info;
info.total_size = sizeof(Class);
info.structural_eq_hash_kind = Class::_type_s_eq_hash_kind;
info.creator = nullptr;
Expand All @@ -391,7 +391,7 @@ class ObjectDef : public ReflectionDefBase {
}
// apply extra info traits
((ApplyExtraInfoTrait(&info, std::forward<ExtraArgs>(extra_args)), ...));
TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterExtraInfo(type_index_, &info));
TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMetadata(type_index_, &info));
}

template <typename T, typename BaseClass, typename... ExtraArgs>
Expand Down Expand Up @@ -446,6 +446,68 @@ class ObjectDef : public ReflectionDefBase {
const char* type_key_;
};

template <typename Class, typename = std::enable_if_t<std::is_base_of_v<Object, Class>>>
class TypeAttrDef : public ReflectionDefBase {
public:
template <typename... ExtraArgs>
explicit TypeAttrDef(ExtraArgs&&... extra_args)
: type_index_(Class::RuntimeTypeIndex()), type_key_(Class::_type_key) {}

/*
* \brief Define a function-valued type attribute.
*
* \tparam Func The function type.
*
* \param name The name of the function.
* \param func The function to be registered.
*
* \return The TypeAttrDef object.
*/
template <typename Func>
TypeAttrDef& def(const char* name, Func&& func) {
TVMFFIByteArray name_array = {name, std::char_traits<char>::length(name)};
ffi::Function ffi_func =
GetMethod<Class>(std::string(type_key_) + "." + name, std::forward<Func>(func));
TVMFFIAny value_any = AnyView(ffi_func).CopyToTVMFFIAny();
TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(type_index_, &name_array, &value_any));
return *this;
}

/*
* \brief Define a constant-valued type attribute.
*
* \tparam T The type of the value.
*
* \param name The name of the attribute.
* \param value The value of the attribute.
*
* \return The TypeAttrDef object.
*/
template <typename T>
TypeAttrDef& attr(const char* name, T value) {
TVMFFIByteArray name_array = {name, std::char_traits<char>::length(name)};
TVMFFIAny value_any = AnyView(value).CopyToTVMFFIAny();
TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(type_index_, &name_array, &value_any));
return *this;
}

private:
int32_t type_index_;
const char* type_key_;
};

/*!
* \brief Ensure the type attribute column is presented in the system.
*
* \param name The name of the type attribute.
*/
inline void EnsureTypeAttrColumn(std::string_view name) {
TVMFFIByteArray name_array = {name.data(), name.size()};
AnyView any_view(nullptr);
TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(kTVMFFINone, &name_array,
reinterpret_cast<const TVMFFIAny*>(&any_view)));
}

} // namespace reflection
} // namespace ffi
} // namespace tvm
Expand Down
76 changes: 63 additions & 13 deletions ffi/src/ffi/object.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class TypeTable {
/*! \brief type methods informaton */
std::vector<TVMFFIMethodInfo> type_methods_data;
/*! \brief extra information */
TVMFFITypeExtraInfo extra_info_data;
TVMFFITypeMetadata metadata_data;
// NOTE: the indices in [index, index + num_reserved_slots) are
// reserved for the child-class of this type.
/*! \brief Total number of slots reserved for the type and its children. */
Expand Down Expand Up @@ -100,10 +100,14 @@ class TypeTable {
this->num_methods = 0;
this->fields = nullptr;
this->methods = nullptr;
this->extra_info = nullptr;
this->metadata = nullptr;
}
};

struct TypeAttrColumnData : public TVMFFITypeAttrColumn {
std::vector<Any> data_;
};

int32_t GetOrAllocTypeIndex(String type_key, int32_t static_type_index, int32_t type_depth,
int32_t num_child_slots, bool child_slots_can_overflow,
int32_t parent_type_index) {
Expand Down Expand Up @@ -219,19 +223,49 @@ class TypeTable {
entry->num_methods = static_cast<int32_t>(entry->type_methods_data.size());
}

void RegisterTypeExtraInfo(int32_t type_index, const TVMFFITypeExtraInfo* extra_info) {
void RegisterTypeMetadata(int32_t type_index, const TVMFFITypeMetadata* metadata) {
Entry* entry = GetTypeEntry(type_index);
if (entry->extra_info != nullptr) {
if (entry->metadata != nullptr) {
TVM_FFI_LOG_AND_THROW(RuntimeError)
<< "Overriding " << ToStringView(entry->type_key) << ", possible causes:\n"
<< "- two ObjectDef<T>() calls for the same T \n"
<< "- when we forget to assign _type_key to ObjectRef<Y> that inherits from T\n"
<< "- another type with the same key is already registered\n"
<< "Cross check the reflection registration.";
}
entry->extra_info_data = *extra_info;
entry->extra_info_data.doc = this->CopyString(extra_info->doc);
entry->extra_info = &(entry->extra_info_data);
entry->metadata_data = *metadata;
entry->metadata_data.doc = this->CopyString(metadata->doc);
entry->metadata = &(entry->metadata_data);
}

void RegisterTypeAttr(int32_t type_index, const TVMFFIByteArray* name, const TVMFFIAny* value) {
AnyView value_view = AnyView::CopyFromTVMFFIAny(*value);
String name_str(*name);
size_t column_index = 0;
auto it = type_attr_name_to_column_index_.find(name_str);
if (it == type_attr_name_to_column_index_.end()) {
column_index = type_attr_columns_.size();
type_attr_columns_.emplace_back(std::make_unique<TypeAttrColumnData>());
type_attr_name_to_column_index_.Set(name_str, column_index);
}
TypeAttrColumnData* column = type_attr_columns_[column_index].get();
if (column->data_.size() < static_cast<size_t>(type_index + 1)) {
column->data_.resize(type_index + 1, Any(nullptr));
column->data = reinterpret_cast<const TVMFFIAny*>(column->data_.data());
column->size = column->data_.size();
}
if (type_index == kTVMFFINone) return;
if (column->data_[type_index] != nullptr) {
TVM_FFI_THROW(RuntimeError) << "Type attribute `" << name_str << "` is already set for type `"
<< TypeIndexToTypeKey(type_index) << "`";
}
column->data_[type_index] = value_view;
}
const TVMFFITypeAttrColumn* GetTypeAttrColumn(const TVMFFIByteArray* name) {
String name_str(*name);
auto it = type_attr_name_to_column_index_.find(name_str);
if (it == type_attr_name_to_column_index_.end()) return nullptr;
return type_attr_columns_[(*it).second].get();
}

void Dump(int min_children_count) {
Expand Down Expand Up @@ -284,11 +318,11 @@ class TypeTable {
this->GetOrAllocTypeIndex(Object::_type_key, Object::_type_index, Object::_type_depth,
Object::_type_child_slots, Object::_type_child_slots_can_overflow,
-1);
TVMFFITypeExtraInfo info;
TVMFFITypeMetadata info;
info.total_size = sizeof(Object);
info.creator = nullptr;
info.doc = TVMFFIByteArray{nullptr, 0};
RegisterTypeExtraInfo(Object::_type_index, &info);
RegisterTypeMetadata(Object::_type_index, &info);
// reserve the static types
ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFINone, TypeIndex::kTVMFFINone);
ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIInt, TypeIndex::kTVMFFIInt);
Expand Down Expand Up @@ -328,6 +362,9 @@ class TypeTable {
std::vector<std::unique_ptr<Entry>> type_table_;
Map<String, int64_t> type_key2index_;
std::vector<Any> any_pool_;
// type attribute columns
std::vector<std::unique_ptr<TypeAttrColumnData>> type_attr_columns_;
Map<String, int64_t> type_attr_name_to_column_index_;
};

void MakeObjectFromPackedArgs(ffi::PackedArgs args, Any* ret) {
Expand All @@ -343,12 +380,12 @@ void MakeObjectFromPackedArgs(ffi::PackedArgs args, Any* ret) {
TVM_FFI_ICHECK(args.size() % 2 == 1);
const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index);

if (type_info->extra_info == nullptr || type_info->extra_info->creator == nullptr) {
if (type_info->metadata == nullptr || type_info->metadata->creator == nullptr) {
TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index)
<< "` does not support reflection creation";
}
TVMFFIObjectHandle handle;
TVM_FFI_CHECK_SAFE_CALL(type_info->extra_info->creator(&handle));
TVM_FFI_CHECK_SAFE_CALL(type_info->metadata->creator(&handle));
ObjectPtr<Object> ptr =
details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle));

Expand Down Expand Up @@ -437,12 +474,25 @@ int TVMFFITypeRegisterMethod(int32_t type_index, const TVMFFIMethodInfo* info) {
TVM_FFI_SAFE_CALL_END();
}

int TVMFFITypeRegisterExtraInfo(int32_t type_index, const TVMFFITypeExtraInfo* extra_info) {
int TVMFFITypeRegisterMetadata(int32_t type_index, const TVMFFITypeMetadata* metadata) {
TVM_FFI_SAFE_CALL_BEGIN();
tvm::ffi::TypeTable::Global()->RegisterTypeExtraInfo(type_index, extra_info);
tvm::ffi::TypeTable::Global()->RegisterTypeMetadata(type_index, metadata);
TVM_FFI_SAFE_CALL_END();
}

int TVMFFITypeRegisterAttr(int32_t type_index, const TVMFFIByteArray* name,
const TVMFFIAny* value) {
TVM_FFI_SAFE_CALL_BEGIN();
tvm::ffi::TypeTable::Global()->RegisterTypeAttr(type_index, name, value);
TVM_FFI_SAFE_CALL_END();
}

const TVMFFITypeAttrColumn* TVMFFIGetTypeAttrColumn(const TVMFFIByteArray* name) {
TVM_FFI_LOG_EXCEPTION_CALL_BEGIN();
return tvm::ffi::TypeTable::Global()->GetTypeAttrColumn(name);
TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIGetTypeAttrColumn);
}

int32_t TVMFFITypeGetOrAllocIndex(const TVMFFIByteArray* type_key, int32_t static_type_index,
int32_t type_depth, int32_t num_child_slots,
int32_t child_slots_can_overflow, int32_t parent_type_index) {
Expand Down
4 changes: 2 additions & 2 deletions ffi/src/ffi/reflection/structural_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,10 @@ class StructEqualHandler {
bool CompareObject(ObjectRef lhs, ObjectRef rhs) {
// NOTE: invariant: lhs and rhs are already the same type
const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(lhs->type_index());
if (type_info->extra_info == nullptr) {
if (type_info->metadata == nullptr) {
return lhs.same_as(rhs);
}
auto structural_eq_hash_kind = type_info->extra_info->structural_eq_hash_kind;
auto structural_eq_hash_kind = type_info->metadata->structural_eq_hash_kind;

if (structural_eq_hash_kind == kTVMFFISEqHashKindUnsupported ||
structural_eq_hash_kind == kTVMFFISEqHashKindUniqueInstance) {
Expand Down
4 changes: 2 additions & 2 deletions ffi/src/ffi/reflection/structural_hash.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,11 @@ class StructuralHashHandler {
uint64_t HashObject(ObjectRef obj) {
// NOTE: invariant: lhs and rhs are already the same type
const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(obj->type_index());
if (type_info->extra_info == nullptr) {
if (type_info->metadata == nullptr) {
// Fallback to pointer hash
return std::hash<const Object*>()(obj.get());
}
auto structural_eq_hash_kind = type_info->extra_info->structural_eq_hash_kind;
auto structural_eq_hash_kind = type_info->metadata->structural_eq_hash_kind;
if (structural_eq_hash_kind == kTVMFFISEqHashKindUnsupported) {
// Fallback to pointer hash
return std::hash<const Object*>()(obj.get());
Expand Down
5 changes: 5 additions & 0 deletions ffi/tests/cpp/test_reflection_accessor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ TEST(Reflection, ForEachFieldInfo) {
EXPECT_EQ(field_name_to_offset["z"], 16 + sizeof(TVMFFIObject));
}

TEST(Reflection, TypeAttrColumn) {
reflection::TypeAttrColumn size_attr("test.size");
EXPECT_EQ(size_attr[TIntObj::_type_index].cast<int>(), sizeof(TIntObj));
}

TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def_method("testing.Int_GetValue", &TIntObj::GetValue);
Expand Down
4 changes: 4 additions & 0 deletions ffi/tests/cpp/testing_object.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ inline void TIntObj::RegisterReflection() {
refl::ObjectDef<TIntObj>()
.def_ro("value", &TIntObj::value)
.def_static("static_add", &TInt::StaticAdd, "static add method");
// define extra type attributes
refl::TypeAttrDef<TIntObj>()
.def("test.GetValue", &TIntObj::GetValue)
.attr("test.size", sizeof(TIntObj));
}

class TFloatObj : public TNumberObj {
Expand Down
Loading
Loading