Skip to content

Commit dd51f6f

Browse files
committed
[FFI][REFACTOR] Refactor AccessPath to enable full tree repr
This PR refactors AccessPath so it can be used to represent full tree with compact memory. Also fixes a bug in thec cython method export
1 parent ba994d3 commit dd51f6f

File tree

16 files changed

+899
-208
lines changed

16 files changed

+899
-208
lines changed

ffi/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ set(tvm_ffi_objs_sources
5959
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/dtype.cc"
6060
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/testing.cc"
6161
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/container.cc"
62-
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/access_path.cc"
6362
)
6463

6564
if (TVM_FFI_USE_EXTRA_CXX_API)
@@ -69,6 +68,7 @@ if (TVM_FFI_USE_EXTRA_CXX_API)
6968
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/json_parser.cc"
7069
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/json_writer.cc"
7170
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/serialization.cc"
71+
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/reflection_extra.cc"
7272
)
7373
endif()
7474

ffi/include/tvm/ffi/c_api.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -896,7 +896,7 @@ TVM_FFI_DLL const TVMFFITypeInfo* TVMFFIGetTypeInfo(int32_t type_index);
896896
#endif
897897

898898
//---------------------------------------------------------------
899-
// The following API defines static object field accessors
899+
// The following API defines static object attribute accessors
900900
// for language bindings.
901901
//
902902
// They are defined in C++ inline functions for cleaner code.

ffi/include/tvm/ffi/reflection/access_path.h

Lines changed: 281 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,24 +25,29 @@
2525

2626
#include <tvm/ffi/any.h>
2727
#include <tvm/ffi/c_api.h>
28+
#include <tvm/ffi/cast.h>
2829
#include <tvm/ffi/container/array.h>
2930
#include <tvm/ffi/container/tuple.h>
31+
#include <tvm/ffi/error.h>
3032
#include <tvm/ffi/reflection/registry.h>
3133

3234
namespace tvm {
3335
namespace ffi {
3436
namespace reflection {
3537

3638
enum class AccessKind : int32_t {
37-
kObjectField = 0,
39+
kAttr = 0,
3840
kArrayItem = 1,
3941
kMapItem = 2,
4042
// the following two are used for error reporting when
4143
// the supposed access field is not available
42-
kArrayItemMissing = 3,
43-
kMapItemMissing = 4,
44+
kAttrMissing = 3,
45+
kArrayItemMissing = 4,
46+
kMapItemMissing = 5,
4447
};
4548

49+
class AccessStep;
50+
4651
/*!
4752
* \brief Represent a single step in object field, map key, array index access.
4853
*/
@@ -59,16 +64,18 @@ class AccessStepObj : public Object {
5964
*/
6065
Any key;
6166

67+
// default constructor to enable auto-serialization
68+
AccessStepObj() = default;
6269
AccessStepObj(AccessKind kind, Any key) : kind(kind), key(key) {}
6370

64-
static void RegisterReflection() {
65-
namespace refl = tvm::ffi::reflection;
66-
refl::ObjectDef<AccessStepObj>()
67-
.def_ro("kind", &AccessStepObj::kind)
68-
.def_ro("key", &AccessStepObj::key);
69-
}
71+
/*!
72+
* \brief Deep check if two steps are equal.
73+
* \param other The other step to compare with.
74+
* \return True if the two steps are equal, false otherwise.
75+
*/
76+
inline bool StepEqual(const AccessStep& other) const;
7077

71-
static constexpr const char* _type_key = "tvm.ffi.reflection.AccessStep";
78+
static constexpr const char* _type_key = "ffi.reflection.AccessStep";
7279
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode;
7380
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(AccessStepObj, Object);
7481
};
@@ -82,8 +89,10 @@ class AccessStep : public ObjectRef {
8289
public:
8390
AccessStep(AccessKind kind, Any key) : ObjectRef(make_object<AccessStepObj>(kind, key)) {}
8491

85-
static AccessStep ObjectField(String field_name) {
86-
return AccessStep(AccessKind::kObjectField, field_name);
92+
static AccessStep Attr(String field_name) { return AccessStep(AccessKind::kAttr, field_name); }
93+
94+
static AccessStep AttrMissing(String field_name) {
95+
return AccessStep(AccessKind::kAttrMissing, field_name);
8796
}
8897

8998
static AccessStep ArrayItem(int64_t index) { return AccessStep(AccessKind::kArrayItem, index); }
@@ -94,15 +103,273 @@ class AccessStep : public ObjectRef {
94103

95104
static AccessStep MapItem(Any key) { return AccessStep(AccessKind::kMapItem, key); }
96105

97-
static AccessStep MapItemMissing(Any key) { return AccessStep(AccessKind::kMapItemMissing, key); }
106+
static AccessStep MapItemMissing(Any key = nullptr) {
107+
return AccessStep(AccessKind::kMapItemMissing, key);
108+
}
98109

99110
TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AccessStep, ObjectRef, AccessStepObj);
100111
};
101112

102-
using AccessPath = Array<AccessStep>;
113+
inline bool AccessStepObj::StepEqual(const AccessStep& other) const {
114+
return this->kind == other->kind && AnyEqual()(this->key, other->key);
115+
}
116+
117+
// forward declaration
118+
class AccessPath;
119+
120+
/*!
121+
* \brief ObjectRef class of AccessPathObj.
122+
*
123+
* \sa AccessPathObj
124+
*/
125+
class AccessPathObj : public Object {
126+
public:
127+
/*!
128+
* \brief The parent of the access path.
129+
*
130+
* This parent-pointing tree structure is more space efficient when
131+
* representing multiple paths that share a common prefix.
132+
*
133+
* \note Empty for root.
134+
*/
135+
Optional<ObjectRef> parent;
136+
/*!
137+
* \brief The current of the access path.
138+
* \note Empty for root.
139+
*/
140+
Optional<AccessStep> step;
141+
/*!
142+
* \brief The current depth of the access path, 0 for root
143+
*/
144+
int32_t depth;
145+
146+
// default constructor to enable auto-serialization
147+
AccessPathObj() = default;
148+
/*!
149+
* \brief Constructor for the access path.
150+
* \param parent The parent of the access path.
151+
* \param step The current step of the access path.
152+
* \param depth The current depth of the access path.
153+
*/
154+
AccessPathObj(Optional<ObjectRef> parent, Optional<AccessStep> step, int32_t depth)
155+
: parent(parent), step(step), depth(depth) {}
156+
157+
/*!
158+
* \brief Get the parent of the access path.
159+
* \return The parent of the access path.
160+
*/
161+
inline Optional<AccessPath> GetParent() const;
162+
163+
/*!
164+
* \brief Extend the access path with a new step.
165+
* \param step The step to extend the access path with.
166+
* \return The extended access path.
167+
*/
168+
inline AccessPath Extend(AccessStep step) const;
169+
170+
/*!
171+
* \brief Extend the access path with an object attribute access.
172+
* \param field_name The name of the field to access.
173+
* \return The extended access path.
174+
*/
175+
inline AccessPath Attr(String field_name) const;
176+
177+
/*!
178+
* \brief Extend the access path with an object attribute missing access.
179+
* \param field_name The name of the field to access.
180+
* \return The extended access path.
181+
*/
182+
inline AccessPath AttrMissing(String field_name) const;
183+
184+
/*!
185+
* \brief Extend the access path with an array item access.
186+
* \param index The index of the array item to access.
187+
* \return The extended access path.
188+
*/
189+
inline AccessPath ArrayItem(int64_t index) const;
190+
191+
/*!
192+
* \brief Extend the access path with an array item missing access.
193+
* \param index The index of the array item to access.
194+
* \return The extended access path.
195+
*/
196+
inline AccessPath ArrayItemMissing(int64_t index) const;
197+
198+
/*!
199+
* \brief Extend the access path with a map item access.
200+
* \param key The key of the map item to access.
201+
* \return The extended access path.
202+
*/
203+
inline AccessPath MapItem(Any key) const;
204+
205+
/*!
206+
* \brief Extend the access path with a map item missing access.
207+
* \param key The key of the map item to access.
208+
* \return The extended access path.
209+
*/
210+
inline AccessPath MapItemMissing(Any key) const;
211+
212+
/*!
213+
* \brief Get the array of steps that corresponds to the access path.
214+
* \return The array of steps that corresponds to the access path.
215+
*/
216+
inline Array<AccessStep> ToSteps() const;
217+
218+
/*!
219+
* \brief Check if two paths are equal by deep comparing the steps.
220+
* \param other The other path to compare with.
221+
* \return True if the two paths are equal, false otherwise.
222+
*/
223+
inline bool PathEqual(const AccessPath& other) const;
224+
225+
/*!
226+
* \brief Check if this path is a prefix of another path.
227+
* \param other The other path to compare with.
228+
* \return True if this path is a prefix of the other path, false otherwise.
229+
*/
230+
inline bool IsPrefixOf(const AccessPath& other) const;
231+
232+
static constexpr const char* _type_key = "ffi.reflection.AccessPath";
233+
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode;
234+
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(AccessPathObj, Object);
235+
236+
private:
237+
static bool PathEqual(const AccessPathObj* lhs, const AccessPathObj* rhs) {
238+
// fast path for same pointer
239+
if (lhs == rhs) return true;
240+
if (lhs->depth != rhs->depth) return false;
241+
// do deep equality checks
242+
while (lhs->parent.has_value()) {
243+
TVM_FFI_ICHECK(rhs->parent.has_value());
244+
TVM_FFI_ICHECK(lhs->step.has_value());
245+
TVM_FFI_ICHECK(rhs->step.has_value());
246+
if (!(*lhs->step)->StepEqual(*(rhs->step))) {
247+
return false;
248+
}
249+
lhs = static_cast<const AccessPathObj*>(lhs->parent.get());
250+
rhs = static_cast<const AccessPathObj*>(rhs->parent.get());
251+
// fast path for same pointer
252+
if (lhs == rhs) return true;
253+
TVM_FFI_ICHECK(lhs != nullptr);
254+
TVM_FFI_ICHECK(rhs != nullptr);
255+
}
256+
return true;
257+
}
258+
};
259+
260+
/*!
261+
* \brief ObjectRef class of AccessPath.
262+
*
263+
* \sa AccessPathObj
264+
*/
265+
class AccessPath : public ObjectRef {
266+
public:
267+
/*!
268+
* \brief Create an access path from an iterator range of steps.
269+
* \param begin The beginning of the iterator range.
270+
* \param end The end of the iterator range.
271+
* \return The access path.
272+
*/
273+
template <typename Iter>
274+
static AccessPath FromSteps(Iter begin, Iter end) {
275+
AccessPath path = AccessPath::Root();
276+
for (Iter it = begin; it != end; ++it) {
277+
path = path->Extend(*it);
278+
}
279+
return path;
280+
}
281+
/*!
282+
* \brief Create an access path from an array of steps.
283+
* \param steps The array of steps.
284+
* \return The access path.
285+
*/
286+
static AccessPath FromSteps(Array<AccessStep> steps) {
287+
AccessPath path = AccessPath::Root();
288+
for (AccessStep step : steps) {
289+
path = path->Extend(step);
290+
}
291+
return path;
292+
}
293+
294+
/*!
295+
* \brief Create a root access path.
296+
* \return The root access path.
297+
*/
298+
static AccessPath Root() {
299+
return AccessPath(make_object<AccessPathObj>(std::nullopt, std::nullopt, 0));
300+
}
301+
302+
TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AccessPath, ObjectRef, AccessPathObj);
303+
};
304+
103305
using AccessPathPair = Tuple<AccessPath, AccessPath>;
104306

307+
inline Optional<AccessPath> AccessPathObj::GetParent() const {
308+
if (auto opt_parent = this->parent.as<AccessPath>()) {
309+
return opt_parent;
310+
}
311+
return std::nullopt;
312+
}
313+
314+
inline AccessPath AccessPathObj::Extend(AccessStep step) const {
315+
return AccessPath(make_object<AccessPathObj>(GetRef<AccessPath>(this), step, this->depth + 1));
316+
}
317+
318+
inline AccessPath AccessPathObj::Attr(String field_name) const {
319+
return this->Extend(AccessStep::Attr(field_name));
320+
}
321+
322+
inline AccessPath AccessPathObj::AttrMissing(String field_name) const {
323+
return this->Extend(AccessStep::AttrMissing(field_name));
324+
}
325+
326+
inline AccessPath AccessPathObj::ArrayItem(int64_t index) const {
327+
return this->Extend(AccessStep::ArrayItem(index));
328+
}
329+
330+
inline AccessPath AccessPathObj::ArrayItemMissing(int64_t index) const {
331+
return this->Extend(AccessStep::ArrayItemMissing(index));
332+
}
333+
334+
inline AccessPath AccessPathObj::MapItem(Any key) const {
335+
return this->Extend(AccessStep::MapItem(key));
336+
}
337+
338+
inline AccessPath AccessPathObj::MapItemMissing(Any key) const {
339+
return this->Extend(AccessStep::MapItemMissing(key));
340+
}
341+
342+
inline Array<AccessStep> AccessPathObj::ToSteps() const {
343+
std::vector<AccessStep> reverse_steps;
344+
reverse_steps.reserve(this->depth);
345+
const AccessPathObj* current = this;
346+
while (current->parent.has_value()) {
347+
TVM_FFI_ICHECK(current->step.has_value());
348+
reverse_steps.push_back(*(current->step));
349+
current = static_cast<const AccessPathObj*>(current->parent.get());
350+
TVM_FFI_ICHECK(current != nullptr);
351+
}
352+
return Array<AccessStep>(reverse_steps.rbegin(), reverse_steps.rend());
353+
}
354+
355+
inline bool AccessPathObj::PathEqual(const AccessPath& other) const {
356+
return PathEqual(this, other.get());
357+
}
358+
359+
inline bool AccessPathObj::IsPrefixOf(const AccessPath& other) const {
360+
if (this->depth > other->depth) {
361+
return false;
362+
}
363+
const AccessPathObj* rhs_path = other.get();
364+
while (rhs_path->depth > this->depth) {
365+
TVM_FFI_ICHECK(rhs_path->parent.has_value());
366+
rhs_path = static_cast<const AccessPathObj*>(rhs_path->parent.get());
367+
}
368+
return PathEqual(this, rhs_path);
369+
}
370+
105371
} // namespace reflection
106372
} // namespace ffi
107373
} // namespace tvm
374+
108375
#endif // TVM_FFI_REFLECTION_ACCESS_PATH_H_

0 commit comments

Comments
 (0)