2525
2626#include < tvm/ffi/any.h>
2727#include < tvm/ffi/c_api.h>
28+ #include < tvm/ffi/cast.h>
2829#include < tvm/ffi/container/array.h>
2930#include < tvm/ffi/container/tuple.h>
31+ #include < tvm/ffi/error.h>
3032#include < tvm/ffi/reflection/registry.h>
3133
3234namespace tvm {
3335namespace ffi {
3436namespace reflection {
3537
3638enum class AccessKind : int32_t {
37- kObjectField = 0 ,
39+ kAttr = 0 ,
3840 kArrayItem = 1 ,
3941 kMapItem = 2 ,
4042 // the following two are used for error reporting when
4143 // the supposed access field is not available
42- kArrayItemMissing = 3 ,
43- kMapItemMissing = 4 ,
44+ kAttrMissing = 3 ,
45+ kArrayItemMissing = 4 ,
46+ kMapItemMissing = 5 ,
4447};
4548
49+ class AccessStep ;
50+
4651/* !
4752 * \brief Represent a single step in object field, map key, array index access.
4853 */
@@ -59,16 +64,18 @@ class AccessStepObj : public Object {
5964 */
6065 Any key;
6166
67+ // default constructor to enable auto-serialization
68+ AccessStepObj () = default ;
6269 AccessStepObj (AccessKind kind, Any key) : kind(kind), key(key) {}
6370
64- static void RegisterReflection () {
65- namespace refl = tvm::ffi::reflection;
66- refl::ObjectDef<AccessStepObj>()
67- . def_ro ( " kind " , &AccessStepObj::kind)
68- . def_ro ( " key " , &AccessStepObj::key);
69- }
71+ /* !
72+ * \brief Deep check if two steps are equal.
73+ * \param other The other step to compare with.
74+ * \return True if the two steps are equal, false otherwise.
75+ */
76+ inline bool StepEqual ( const AccessStep& other) const ;
7077
71- static constexpr const char * _type_key = " tvm. ffi.reflection.AccessStep" ;
78+ static constexpr const char * _type_key = " ffi.reflection.AccessStep" ;
7279 static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode ;
7380 TVM_FFI_DECLARE_FINAL_OBJECT_INFO (AccessStepObj, Object);
7481};
@@ -82,8 +89,10 @@ class AccessStep : public ObjectRef {
8289 public:
8390 AccessStep (AccessKind kind, Any key) : ObjectRef(make_object<AccessStepObj>(kind, key)) {}
8491
85- static AccessStep ObjectField (String field_name) {
86- return AccessStep (AccessKind::kObjectField , field_name);
92+ static AccessStep Attr (String field_name) { return AccessStep (AccessKind::kAttr , field_name); }
93+
94+ static AccessStep AttrMissing (String field_name) {
95+ return AccessStep (AccessKind::kAttrMissing , field_name);
8796 }
8897
8998 static AccessStep ArrayItem (int64_t index) { return AccessStep (AccessKind::kArrayItem , index); }
@@ -94,15 +103,273 @@ class AccessStep : public ObjectRef {
94103
95104 static AccessStep MapItem (Any key) { return AccessStep (AccessKind::kMapItem , key); }
96105
97- static AccessStep MapItemMissing (Any key) { return AccessStep (AccessKind::kMapItemMissing , key); }
106+ static AccessStep MapItemMissing (Any key = nullptr ) {
107+ return AccessStep (AccessKind::kMapItemMissing , key);
108+ }
98109
99110 TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS (AccessStep, ObjectRef, AccessStepObj);
100111};
101112
102- using AccessPath = Array<AccessStep>;
113+ inline bool AccessStepObj::StepEqual (const AccessStep& other) const {
114+ return this ->kind == other->kind && AnyEqual ()(this ->key , other->key );
115+ }
116+
117+ // forward declaration
118+ class AccessPath ;
119+
120+ /* !
121+ * \brief ObjectRef class of AccessPathObj.
122+ *
123+ * \sa AccessPathObj
124+ */
125+ class AccessPathObj : public Object {
126+ public:
127+ /* !
128+ * \brief The parent of the access path.
129+ *
130+ * This parent-pointing tree structure is more space efficient when
131+ * representing multiple paths that share a common prefix.
132+ *
133+ * \note Empty for root.
134+ */
135+ Optional<ObjectRef> parent;
136+ /* !
137+ * \brief The current of the access path.
138+ * \note Empty for root.
139+ */
140+ Optional<AccessStep> step;
141+ /* !
142+ * \brief The current depth of the access path, 0 for root
143+ */
144+ int32_t depth;
145+
146+ // default constructor to enable auto-serialization
147+ AccessPathObj () = default ;
148+ /* !
149+ * \brief Constructor for the access path.
150+ * \param parent The parent of the access path.
151+ * \param step The current step of the access path.
152+ * \param depth The current depth of the access path.
153+ */
154+ AccessPathObj (Optional<ObjectRef> parent, Optional<AccessStep> step, int32_t depth)
155+ : parent(parent), step(step), depth(depth) {}
156+
157+ /* !
158+ * \brief Get the parent of the access path.
159+ * \return The parent of the access path.
160+ */
161+ inline Optional<AccessPath> GetParent () const ;
162+
163+ /* !
164+ * \brief Extend the access path with a new step.
165+ * \param step The step to extend the access path with.
166+ * \return The extended access path.
167+ */
168+ inline AccessPath Extend (AccessStep step) const ;
169+
170+ /* !
171+ * \brief Extend the access path with an object attribute access.
172+ * \param field_name The name of the field to access.
173+ * \return The extended access path.
174+ */
175+ inline AccessPath Attr (String field_name) const ;
176+
177+ /* !
178+ * \brief Extend the access path with an object attribute missing access.
179+ * \param field_name The name of the field to access.
180+ * \return The extended access path.
181+ */
182+ inline AccessPath AttrMissing (String field_name) const ;
183+
184+ /* !
185+ * \brief Extend the access path with an array item access.
186+ * \param index The index of the array item to access.
187+ * \return The extended access path.
188+ */
189+ inline AccessPath ArrayItem (int64_t index) const ;
190+
191+ /* !
192+ * \brief Extend the access path with an array item missing access.
193+ * \param index The index of the array item to access.
194+ * \return The extended access path.
195+ */
196+ inline AccessPath ArrayItemMissing (int64_t index) const ;
197+
198+ /* !
199+ * \brief Extend the access path with a map item access.
200+ * \param key The key of the map item to access.
201+ * \return The extended access path.
202+ */
203+ inline AccessPath MapItem (Any key) const ;
204+
205+ /* !
206+ * \brief Extend the access path with a map item missing access.
207+ * \param key The key of the map item to access.
208+ * \return The extended access path.
209+ */
210+ inline AccessPath MapItemMissing (Any key) const ;
211+
212+ /* !
213+ * \brief Get the array of steps that corresponds to the access path.
214+ * \return The array of steps that corresponds to the access path.
215+ */
216+ inline Array<AccessStep> ToSteps () const ;
217+
218+ /* !
219+ * \brief Check if two paths are equal by deep comparing the steps.
220+ * \param other The other path to compare with.
221+ * \return True if the two paths are equal, false otherwise.
222+ */
223+ inline bool PathEqual (const AccessPath& other) const ;
224+
225+ /* !
226+ * \brief Check if this path is a prefix of another path.
227+ * \param other The other path to compare with.
228+ * \return True if this path is a prefix of the other path, false otherwise.
229+ */
230+ inline bool IsPrefixOf (const AccessPath& other) const ;
231+
232+ static constexpr const char * _type_key = " ffi.reflection.AccessPath" ;
233+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode ;
234+ TVM_FFI_DECLARE_FINAL_OBJECT_INFO (AccessPathObj, Object);
235+
236+ private:
237+ static bool PathEqual (const AccessPathObj* lhs, const AccessPathObj* rhs) {
238+ // fast path for same pointer
239+ if (lhs == rhs) return true ;
240+ if (lhs->depth != rhs->depth ) return false ;
241+ // do deep equality checks
242+ while (lhs->parent .has_value ()) {
243+ TVM_FFI_ICHECK (rhs->parent .has_value ());
244+ TVM_FFI_ICHECK (lhs->step .has_value ());
245+ TVM_FFI_ICHECK (rhs->step .has_value ());
246+ if (!(*lhs->step )->StepEqual (*(rhs->step ))) {
247+ return false ;
248+ }
249+ lhs = static_cast <const AccessPathObj*>(lhs->parent .get ());
250+ rhs = static_cast <const AccessPathObj*>(rhs->parent .get ());
251+ // fast path for same pointer
252+ if (lhs == rhs) return true ;
253+ TVM_FFI_ICHECK (lhs != nullptr );
254+ TVM_FFI_ICHECK (rhs != nullptr );
255+ }
256+ return true ;
257+ }
258+ };
259+
260+ /* !
261+ * \brief ObjectRef class of AccessPath.
262+ *
263+ * \sa AccessPathObj
264+ */
265+ class AccessPath : public ObjectRef {
266+ public:
267+ /* !
268+ * \brief Create an access path from an iterator range of steps.
269+ * \param begin The beginning of the iterator range.
270+ * \param end The end of the iterator range.
271+ * \return The access path.
272+ */
273+ template <typename Iter>
274+ static AccessPath FromSteps (Iter begin, Iter end) {
275+ AccessPath path = AccessPath::Root ();
276+ for (Iter it = begin; it != end; ++it) {
277+ path = path->Extend (*it);
278+ }
279+ return path;
280+ }
281+ /* !
282+ * \brief Create an access path from an array of steps.
283+ * \param steps The array of steps.
284+ * \return The access path.
285+ */
286+ static AccessPath FromSteps (Array<AccessStep> steps) {
287+ AccessPath path = AccessPath::Root ();
288+ for (AccessStep step : steps) {
289+ path = path->Extend (step);
290+ }
291+ return path;
292+ }
293+
294+ /* !
295+ * \brief Create a root access path.
296+ * \return The root access path.
297+ */
298+ static AccessPath Root () {
299+ return AccessPath (make_object<AccessPathObj>(std::nullopt , std::nullopt , 0 ));
300+ }
301+
302+ TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS (AccessPath, ObjectRef, AccessPathObj);
303+ };
304+
103305using AccessPathPair = Tuple<AccessPath, AccessPath>;
104306
307+ inline Optional<AccessPath> AccessPathObj::GetParent () const {
308+ if (auto opt_parent = this ->parent .as <AccessPath>()) {
309+ return opt_parent;
310+ }
311+ return std::nullopt ;
312+ }
313+
314+ inline AccessPath AccessPathObj::Extend (AccessStep step) const {
315+ return AccessPath (make_object<AccessPathObj>(GetRef<AccessPath>(this ), step, this ->depth + 1 ));
316+ }
317+
318+ inline AccessPath AccessPathObj::Attr (String field_name) const {
319+ return this ->Extend (AccessStep::Attr (field_name));
320+ }
321+
322+ inline AccessPath AccessPathObj::AttrMissing (String field_name) const {
323+ return this ->Extend (AccessStep::AttrMissing (field_name));
324+ }
325+
326+ inline AccessPath AccessPathObj::ArrayItem (int64_t index) const {
327+ return this ->Extend (AccessStep::ArrayItem (index));
328+ }
329+
330+ inline AccessPath AccessPathObj::ArrayItemMissing (int64_t index) const {
331+ return this ->Extend (AccessStep::ArrayItemMissing (index));
332+ }
333+
334+ inline AccessPath AccessPathObj::MapItem (Any key) const {
335+ return this ->Extend (AccessStep::MapItem (key));
336+ }
337+
338+ inline AccessPath AccessPathObj::MapItemMissing (Any key) const {
339+ return this ->Extend (AccessStep::MapItemMissing (key));
340+ }
341+
342+ inline Array<AccessStep> AccessPathObj::ToSteps () const {
343+ std::vector<AccessStep> reverse_steps;
344+ reverse_steps.reserve (this ->depth );
345+ const AccessPathObj* current = this ;
346+ while (current->parent .has_value ()) {
347+ TVM_FFI_ICHECK (current->step .has_value ());
348+ reverse_steps.push_back (*(current->step ));
349+ current = static_cast <const AccessPathObj*>(current->parent .get ());
350+ TVM_FFI_ICHECK (current != nullptr );
351+ }
352+ return Array<AccessStep>(reverse_steps.rbegin (), reverse_steps.rend ());
353+ }
354+
355+ inline bool AccessPathObj::PathEqual (const AccessPath& other) const {
356+ return PathEqual (this , other.get ());
357+ }
358+
359+ inline bool AccessPathObj::IsPrefixOf (const AccessPath& other) const {
360+ if (this ->depth > other->depth ) {
361+ return false ;
362+ }
363+ const AccessPathObj* rhs_path = other.get ();
364+ while (rhs_path->depth > this ->depth ) {
365+ TVM_FFI_ICHECK (rhs_path->parent .has_value ());
366+ rhs_path = static_cast <const AccessPathObj*>(rhs_path->parent .get ());
367+ }
368+ return PathEqual (this , rhs_path);
369+ }
370+
105371} // namespace reflection
106372} // namespace ffi
107373} // namespace tvm
374+
108375#endif // TVM_FFI_REFLECTION_ACCESS_PATH_H_
0 commit comments