Skip to content

Commit 2bcad22

Browse files
committed
[FFI][REFACTOR] Introduce TypeAttr in reflection
This PR introduces TypeAttr to reflection to bring extra optional attribute registration that can be used to extend behaviors such as structural equality. Also renames TypeExtraInfo to TypeMetadata for better clarity.
1 parent 5aa4dfd commit 2bcad22

File tree

15 files changed

+213
-38
lines changed

15 files changed

+213
-38
lines changed

ffi/include/tvm/ffi/c_api.h

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,22 @@ typedef struct {
534534
* \brief Optional meta-data for structural eq/hash.
535535
*/
536536
TVMFFISEqHashKind structural_eq_hash_kind;
537-
} TVMFFITypeExtraInfo;
537+
} TVMFFITypeMetadata;
538+
539+
/*
540+
* \brief Column array that stores extra attributes about types
541+
*
542+
* The attributes stored in column arrays that can be looked up by type index.
543+
*
544+
* \note
545+
* \sa TVMFFIRegisterTypeAttr
546+
*/
547+
typedef struct {
548+
/*! \brief The data of the column. */
549+
const TVMFFIAny* data;
550+
/*! \brief The size of the column. */
551+
size_t size;
552+
} TVMFFITypeAttrColumn;
538553

539554
/*!
540555
* \brief Runtime type information for object type checking.
@@ -567,7 +582,7 @@ typedef struct TVMFFITypeInfo {
567582
/*! \brief The reflection method. */
568583
const TVMFFIMethodInfo* methods;
569584
/*! \brief The extra information of the type. */
570-
const TVMFFITypeExtraInfo* extra_info;
585+
const TVMFFITypeMetadata* metadata;
571586
} TVMFFITypeInfo;
572587

573588
//------------------------------------------------------------
@@ -738,11 +753,27 @@ TVM_FFI_DLL int TVMFFITypeRegisterMethod(int32_t type_index, const TVMFFIMethodI
738753
/*!
739754
* \brief Register type creator information for runtime reflection.
740755
* \param type_index The type index
741-
* \param extra_info The extra information to be registered.
756+
* \param metadata The extra information to be registered.
742757
* \return 0 when success, nonzero when failure happens
743758
*/
744-
TVM_FFI_DLL int TVMFFITypeRegisterExtraInfo(int32_t type_index,
745-
const TVMFFITypeExtraInfo* extra_info);
759+
TVM_FFI_DLL int TVMFFITypeRegisterMetadata(int32_t type_index, const TVMFFITypeMetadata* metadata);
760+
761+
/*!
762+
* \brief Register extra type attributes that can be looked up during runtime.
763+
* \param type_index The type index
764+
* \param attr_value The attribute value to be registered.
765+
* \return 0 when success, nonzero when failure happens
766+
*/
767+
TVM_FFI_DLL int TVMFFITypeRegisterAttr(int32_t type_index, const TVMFFIByteArray* attr_name,
768+
const TVMFFIAny* attr_value);
769+
770+
/*!
771+
* \brief Get the type attribute column by name.
772+
* \param attr_name The name of the attribute.
773+
* \return The pointer to the type attribute column.
774+
* \return NULL if the attribute was not registered in the system
775+
*/
776+
TVM_FFI_DLL const TVMFFITypeAttrColumn* TVMFFIGetTypeAttrColumn(const TVMFFIByteArray* attr_name);
746777

747778
//------------------------------------------------------------
748779
// Section: DLPack support APIs

ffi/include/tvm/ffi/reflection/accessor.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,29 @@ class FieldSetter {
104104
const TVMFFIFieldInfo* field_info_;
105105
};
106106

107+
class TypeAttrColumn {
108+
public:
109+
TypeAttrColumn(std::string_view attr_name) {
110+
TVMFFIByteArray attr_name_array = {attr_name.data(), attr_name.size()};
111+
column_ = TVMFFIGetTypeAttrColumn(&attr_name_array);
112+
if (column_ == nullptr) {
113+
TVM_FFI_THROW(RuntimeError) << "Cannot find type attribute " << attr_name;
114+
}
115+
}
116+
117+
AnyView operator[](int32_t type_index) const {
118+
size_t tindex = static_cast<size_t>(type_index);
119+
if (tindex >= column_->size) {
120+
return AnyView();
121+
}
122+
const AnyView* any_view_data = reinterpret_cast<const AnyView*>(column_->data);
123+
return any_view_data[tindex];
124+
}
125+
126+
private:
127+
const TVMFFITypeAttrColumn* column_;
128+
};
129+
107130
/*!
108131
* \brief helper function to get reflection method info by type key and method name
109132
*

ffi/include/tvm/ffi/reflection/registry.h

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ class ReflectionDefBase {
150150
}
151151

152152
template <typename T>
153-
TVM_FFI_INLINE static void ApplyExtraInfoTrait(TVMFFITypeExtraInfo* info, const T& value) {
153+
TVM_FFI_INLINE static void ApplyExtraInfoTrait(TVMFFITypeMetadata* info, const T& value) {
154154
if constexpr (std::is_same_v<std::decay_t<T>, char*>) {
155155
info->doc = TVMFFIByteArray{value, std::char_traits<char>::length(value)};
156156
}
@@ -381,7 +381,7 @@ class ObjectDef : public ReflectionDefBase {
381381
private:
382382
template <typename... ExtraArgs>
383383
void RegisterExtraInfo(ExtraArgs&&... extra_args) {
384-
TVMFFITypeExtraInfo info;
384+
TVMFFITypeMetadata info;
385385
info.total_size = sizeof(Class);
386386
info.structural_eq_hash_kind = Class::_type_s_eq_hash_kind;
387387
info.creator = nullptr;
@@ -391,7 +391,7 @@ class ObjectDef : public ReflectionDefBase {
391391
}
392392
// apply extra info traits
393393
((ApplyExtraInfoTrait(&info, std::forward<ExtraArgs>(extra_args)), ...));
394-
TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterExtraInfo(type_index_, &info));
394+
TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMetadata(type_index_, &info));
395395
}
396396

397397
template <typename T, typename BaseClass, typename... ExtraArgs>
@@ -446,6 +446,68 @@ class ObjectDef : public ReflectionDefBase {
446446
const char* type_key_;
447447
};
448448

449+
template <typename Class, typename = std::enable_if_t<std::is_base_of_v<Object, Class>>>
450+
class TypeAttrDef : public ReflectionDefBase {
451+
public:
452+
template <typename... ExtraArgs>
453+
explicit TypeAttrDef(ExtraArgs&&... extra_args)
454+
: type_index_(Class::RuntimeTypeIndex()), type_key_(Class::_type_key) {}
455+
456+
/*
457+
* \brief Define a function-valued type attribute.
458+
*
459+
* \tparam Func The function type.
460+
*
461+
* \param name The name of the function.
462+
* \param func The function to be registered.
463+
*
464+
* \return The TypeAttrDef object.
465+
*/
466+
template <typename Func>
467+
TypeAttrDef& def(const char* name, Func&& func) {
468+
TVMFFIByteArray name_array = {name, std::char_traits<char>::length(name)};
469+
ffi::Function ffi_func =
470+
GetMethod<Class>(std::string(type_key_) + "." + name, std::forward<Func>(func));
471+
TVMFFIAny value_any = AnyView(ffi_func).CopyToTVMFFIAny();
472+
TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(type_index_, &name_array, &value_any));
473+
return *this;
474+
}
475+
476+
/*
477+
* \brief Define a constant-valued type attribute.
478+
*
479+
* \tparam T The type of the value.
480+
*
481+
* \param name The name of the attribute.
482+
* \param value The value of the attribute.
483+
*
484+
* \return The TypeAttrDef object.
485+
*/
486+
template <typename T>
487+
TypeAttrDef& attr(const char* name, T value) {
488+
TVMFFIByteArray name_array = {name, std::char_traits<char>::length(name)};
489+
TVMFFIAny value_any = AnyView(value).CopyToTVMFFIAny();
490+
TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(type_index_, &name_array, &value_any));
491+
return *this;
492+
}
493+
494+
private:
495+
int32_t type_index_;
496+
const char* type_key_;
497+
};
498+
499+
/*!
500+
* \brief Ensure the type attribute column is presented in the system.
501+
*
502+
* \param name The name of the type attribute.
503+
*/
504+
inline void EnsureTypeAttrColumn(std::string_view name) {
505+
TVMFFIByteArray name_array = {name.data(), name.size()};
506+
AnyView any_view(nullptr);
507+
TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(kTVMFFINone, &name_array,
508+
reinterpret_cast<const TVMFFIAny*>(&any_view)));
509+
}
510+
449511
} // namespace reflection
450512
} // namespace ffi
451513
} // namespace tvm

ffi/src/ffi/object.cc

Lines changed: 63 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ class TypeTable {
5959
/*! \brief type methods informaton */
6060
std::vector<TVMFFIMethodInfo> type_methods_data;
6161
/*! \brief extra information */
62-
TVMFFITypeExtraInfo extra_info_data;
62+
TVMFFITypeMetadata metadata_data;
6363
// NOTE: the indices in [index, index + num_reserved_slots) are
6464
// reserved for the child-class of this type.
6565
/*! \brief Total number of slots reserved for the type and its children. */
@@ -100,10 +100,14 @@ class TypeTable {
100100
this->num_methods = 0;
101101
this->fields = nullptr;
102102
this->methods = nullptr;
103-
this->extra_info = nullptr;
103+
this->metadata = nullptr;
104104
}
105105
};
106106

107+
struct TypeAttrColumnData : public TVMFFITypeAttrColumn {
108+
std::vector<Any> data_;
109+
};
110+
107111
int32_t GetOrAllocTypeIndex(String type_key, int32_t static_type_index, int32_t type_depth,
108112
int32_t num_child_slots, bool child_slots_can_overflow,
109113
int32_t parent_type_index) {
@@ -219,19 +223,49 @@ class TypeTable {
219223
entry->num_methods = static_cast<int32_t>(entry->type_methods_data.size());
220224
}
221225

222-
void RegisterTypeExtraInfo(int32_t type_index, const TVMFFITypeExtraInfo* extra_info) {
226+
void RegisterTypeMetadata(int32_t type_index, const TVMFFITypeMetadata* metadata) {
223227
Entry* entry = GetTypeEntry(type_index);
224-
if (entry->extra_info != nullptr) {
228+
if (entry->metadata != nullptr) {
225229
TVM_FFI_LOG_AND_THROW(RuntimeError)
226230
<< "Overriding " << ToStringView(entry->type_key) << ", possible causes:\n"
227231
<< "- two ObjectDef<T>() calls for the same T \n"
228232
<< "- when we forget to assign _type_key to ObjectRef<Y> that inherits from T\n"
229233
<< "- another type with the same key is already registered\n"
230234
<< "Cross check the reflection registration.";
231235
}
232-
entry->extra_info_data = *extra_info;
233-
entry->extra_info_data.doc = this->CopyString(extra_info->doc);
234-
entry->extra_info = &(entry->extra_info_data);
236+
entry->metadata_data = *metadata;
237+
entry->metadata_data.doc = this->CopyString(metadata->doc);
238+
entry->metadata = &(entry->metadata_data);
239+
}
240+
241+
void RegisterTypeAttr(int32_t type_index, const TVMFFIByteArray* name, const TVMFFIAny* value) {
242+
AnyView value_view = AnyView::CopyFromTVMFFIAny(*value);
243+
String name_str(*name);
244+
size_t column_index = 0;
245+
auto it = type_attr_name_to_column_index_.find(name_str);
246+
if (it == type_attr_name_to_column_index_.end()) {
247+
column_index = type_attr_columns_.size();
248+
type_attr_columns_.emplace_back(std::make_unique<TypeAttrColumnData>());
249+
type_attr_name_to_column_index_.Set(name_str, column_index);
250+
}
251+
TypeAttrColumnData* column = type_attr_columns_[column_index].get();
252+
if (column->data_.size() < static_cast<size_t>(type_index + 1)) {
253+
column->data_.resize(type_index + 1, Any(nullptr));
254+
column->data = reinterpret_cast<const TVMFFIAny*>(column->data_.data());
255+
column->size = column->data_.size();
256+
}
257+
if (type_index == kTVMFFINone) return;
258+
if (column->data_[type_index] != nullptr) {
259+
TVM_FFI_THROW(RuntimeError) << "Type attribute `" << name_str << "` is already set for type `"
260+
<< TypeIndexToTypeKey(type_index) << "`";
261+
}
262+
column->data_[type_index] = value_view;
263+
}
264+
const TVMFFITypeAttrColumn* GetTypeAttrColumn(const TVMFFIByteArray* name) {
265+
String name_str(*name);
266+
auto it = type_attr_name_to_column_index_.find(name_str);
267+
if (it == type_attr_name_to_column_index_.end()) return nullptr;
268+
return type_attr_columns_[(*it).second].get();
235269
}
236270

237271
void Dump(int min_children_count) {
@@ -284,11 +318,11 @@ class TypeTable {
284318
this->GetOrAllocTypeIndex(Object::_type_key, Object::_type_index, Object::_type_depth,
285319
Object::_type_child_slots, Object::_type_child_slots_can_overflow,
286320
-1);
287-
TVMFFITypeExtraInfo info;
321+
TVMFFITypeMetadata info;
288322
info.total_size = sizeof(Object);
289323
info.creator = nullptr;
290324
info.doc = TVMFFIByteArray{nullptr, 0};
291-
RegisterTypeExtraInfo(Object::_type_index, &info);
325+
RegisterTypeMetadata(Object::_type_index, &info);
292326
// reserve the static types
293327
ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFINone, TypeIndex::kTVMFFINone);
294328
ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIInt, TypeIndex::kTVMFFIInt);
@@ -328,6 +362,9 @@ class TypeTable {
328362
std::vector<std::unique_ptr<Entry>> type_table_;
329363
Map<String, int64_t> type_key2index_;
330364
std::vector<Any> any_pool_;
365+
// type attribute columns
366+
std::vector<std::unique_ptr<TypeAttrColumnData>> type_attr_columns_;
367+
Map<String, int64_t> type_attr_name_to_column_index_;
331368
};
332369

333370
void MakeObjectFromPackedArgs(ffi::PackedArgs args, Any* ret) {
@@ -343,12 +380,12 @@ void MakeObjectFromPackedArgs(ffi::PackedArgs args, Any* ret) {
343380
TVM_FFI_ICHECK(args.size() % 2 == 1);
344381
const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index);
345382

346-
if (type_info->extra_info == nullptr || type_info->extra_info->creator == nullptr) {
383+
if (type_info->metadata == nullptr || type_info->metadata->creator == nullptr) {
347384
TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index)
348385
<< "` does not support reflection creation";
349386
}
350387
TVMFFIObjectHandle handle;
351-
TVM_FFI_CHECK_SAFE_CALL(type_info->extra_info->creator(&handle));
388+
TVM_FFI_CHECK_SAFE_CALL(type_info->metadata->creator(&handle));
352389
ObjectPtr<Object> ptr =
353390
details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle));
354391

@@ -437,12 +474,25 @@ int TVMFFITypeRegisterMethod(int32_t type_index, const TVMFFIMethodInfo* info) {
437474
TVM_FFI_SAFE_CALL_END();
438475
}
439476

440-
int TVMFFITypeRegisterExtraInfo(int32_t type_index, const TVMFFITypeExtraInfo* extra_info) {
477+
int TVMFFITypeRegisterMetadata(int32_t type_index, const TVMFFITypeMetadata* metadata) {
441478
TVM_FFI_SAFE_CALL_BEGIN();
442-
tvm::ffi::TypeTable::Global()->RegisterTypeExtraInfo(type_index, extra_info);
479+
tvm::ffi::TypeTable::Global()->RegisterTypeMetadata(type_index, metadata);
443480
TVM_FFI_SAFE_CALL_END();
444481
}
445482

483+
int TVMFFITypeRegisterAttr(int32_t type_index, const TVMFFIByteArray* name,
484+
const TVMFFIAny* value) {
485+
TVM_FFI_SAFE_CALL_BEGIN();
486+
tvm::ffi::TypeTable::Global()->RegisterTypeAttr(type_index, name, value);
487+
TVM_FFI_SAFE_CALL_END();
488+
}
489+
490+
const TVMFFITypeAttrColumn* TVMFFIGetTypeAttrColumn(const TVMFFIByteArray* name) {
491+
TVM_FFI_LOG_EXCEPTION_CALL_BEGIN();
492+
return tvm::ffi::TypeTable::Global()->GetTypeAttrColumn(name);
493+
TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIGetTypeAttrColumn);
494+
}
495+
446496
int32_t TVMFFITypeGetOrAllocIndex(const TVMFFIByteArray* type_key, int32_t static_type_index,
447497
int32_t type_depth, int32_t num_child_slots,
448498
int32_t child_slots_can_overflow, int32_t parent_type_index) {

ffi/src/ffi/reflection/structural_equal.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,10 @@ class StructEqualHandler {
8989
bool CompareObject(ObjectRef lhs, ObjectRef rhs) {
9090
// NOTE: invariant: lhs and rhs are already the same type
9191
const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(lhs->type_index());
92-
if (type_info->extra_info == nullptr) {
92+
if (type_info->metadata == nullptr) {
9393
return lhs.same_as(rhs);
9494
}
95-
auto structural_eq_hash_kind = type_info->extra_info->structural_eq_hash_kind;
95+
auto structural_eq_hash_kind = type_info->metadata->structural_eq_hash_kind;
9696

9797
if (structural_eq_hash_kind == kTVMFFISEqHashKindUnsupported ||
9898
structural_eq_hash_kind == kTVMFFISEqHashKindUniqueInstance) {

ffi/src/ffi/reflection/structural_hash.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,11 @@ class StructuralHashHandler {
8282
uint64_t HashObject(ObjectRef obj) {
8383
// NOTE: invariant: lhs and rhs are already the same type
8484
const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(obj->type_index());
85-
if (type_info->extra_info == nullptr) {
85+
if (type_info->metadata == nullptr) {
8686
// Fallback to pointer hash
8787
return std::hash<const Object*>()(obj.get());
8888
}
89-
auto structural_eq_hash_kind = type_info->extra_info->structural_eq_hash_kind;
89+
auto structural_eq_hash_kind = type_info->metadata->structural_eq_hash_kind;
9090
if (structural_eq_hash_kind == kTVMFFISEqHashKindUnsupported) {
9191
// Fallback to pointer hash
9292
return std::hash<const Object*>()(obj.get());

ffi/tests/cpp/test_reflection_accessor.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,11 @@ TEST(Reflection, ForEachFieldInfo) {
143143
EXPECT_EQ(field_name_to_offset["z"], 16 + sizeof(TVMFFIObject));
144144
}
145145

146+
TEST(Reflection, TypeAttrColumn) {
147+
reflection::TypeAttrColumn size_attr("test.size");
148+
EXPECT_EQ(size_attr[TIntObj::_type_index].cast<int>(), sizeof(TIntObj));
149+
}
150+
146151
TVM_FFI_STATIC_INIT_BLOCK({
147152
namespace refl = tvm::ffi::reflection;
148153
refl::GlobalDef().def_method("testing.Int_GetValue", &TIntObj::GetValue);

ffi/tests/cpp/testing_object.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,10 @@ inline void TIntObj::RegisterReflection() {
8383
refl::ObjectDef<TIntObj>()
8484
.def_ro("value", &TIntObj::value)
8585
.def_static("static_add", &TInt::StaticAdd, "static add method");
86+
// define extra type attributes
87+
refl::TypeAttrDef<TIntObj>()
88+
.def("test.GetValue", &TIntObj::GetValue)
89+
.attr("test.size", sizeof(TIntObj));
8690
}
8791

8892
class TFloatObj : public TNumberObj {

0 commit comments

Comments
 (0)