Skip to content

Commit 3b1f3c4

Browse files
authored
Merge pull request apache#5 from gbonik/traced_object
TracedObject wrapper that tracks an ObjectPath
2 parents 98b81d7 + 8739c4f commit 3b1f3c4

File tree

2 files changed

+527
-0
lines changed

2 files changed

+527
-0
lines changed

src/script/printer/traced_object.h

Lines changed: 351 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,351 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
#ifndef TVM_NODE_TRACED_OBJECT_H_
21+
#define TVM_NODE_TRACED_OBJECT_H_
22+
23+
#include <tvm/node/object_path.h>
24+
#include <tvm/node/reflection.h>
25+
#include <tvm/runtime/object.h>
26+
27+
#include <string>
28+
#include <utility>
29+
30+
namespace tvm {
31+
32+
template <typename RefT>
33+
class TracedObject;
34+
template <typename K, typename V>
35+
class TracedMap;
36+
template <typename T>
37+
class TracedArray;
38+
template <typename T>
39+
class TracedOptional;
40+
template <typename T>
41+
class TracedBasicValue;
42+
43+
namespace detail {
44+
45+
template <typename T, bool IsObject = std::is_base_of<ObjectRef, T>::value>
46+
struct TracedObjectWrapperSelector;
47+
48+
template <typename T>
49+
struct TracedObjectWrapperSelector<T, false> {
50+
using Type = TracedBasicValue<T>;
51+
};
52+
53+
template <typename T>
54+
struct TracedObjectWrapperSelector<T, true> {
55+
using Type = TracedObject<T>;
56+
};
57+
58+
template <typename K, typename V>
59+
struct TracedObjectWrapperSelector<Map<K, V>, true> {
60+
using Type = TracedMap<K, V>;
61+
};
62+
63+
template <typename T>
64+
struct TracedObjectWrapperSelector<Array<T>, true> {
65+
using Type = TracedArray<T>;
66+
};
67+
68+
template <typename T>
69+
struct TracedObjectWrapperSelector<Optional<T>, true> {
70+
using Type = TracedOptional<T>;
71+
};
72+
73+
} // namespace detail
74+
75+
template <typename RefT>
76+
class TracedObject {
77+
using ObjectType = typename RefT::ContainerType;
78+
79+
public:
80+
explicit TracedObject(const RefT& object_ref, ObjectPath path)
81+
: ref_(object_ref), path_(std::move(path)) {}
82+
83+
template <typename DerivedRef>
84+
TracedObject(const TracedObject<DerivedRef>& derived)
85+
: ref_(derived.Get()), path_(derived.GetPath()) {}
86+
87+
template <typename T, typename BaseType>
88+
typename detail::TracedObjectWrapperSelector<T>::Type GetAttr(T BaseType::*member_ptr) const {
89+
using WrapperType = typename detail::TracedObjectWrapperSelector<T>::Type;
90+
const ObjectType* node = static_cast<const ObjectType*>(ref_.get());
91+
const T& attr = node->*member_ptr;
92+
const char* attr_key = ICHECK_NOTNULL(GetAttrKeyByAddress(node, &attr));
93+
return WrapperType(attr, path_->Attr(attr_key));
94+
}
95+
96+
const RefT& Get() const { return ref_; }
97+
98+
template <typename RefU>
99+
bool IsInstance() const {
100+
return ref_->template IsInstance<typename RefU::ContainerType>();
101+
}
102+
103+
bool defined() const { return ref_.defined(); }
104+
105+
template <typename U>
106+
TracedObject<U> Downcast() const {
107+
return TracedObject<U>(tvm::runtime::Downcast<U>(ref_), path_);
108+
}
109+
110+
template <typename RefU>
111+
TracedOptional<RefU> TryDowncast() const {
112+
if (ref_->template IsInstance<typename RefU::ContainerType>()) {
113+
return Downcast<RefU>();
114+
} else {
115+
return TracedOptional<RefU>(NullOpt, path_);
116+
}
117+
}
118+
119+
const ObjectPath& GetPath() const { return path_; }
120+
121+
private:
122+
RefT ref_;
123+
ObjectPath path_;
124+
};
125+
126+
template <typename K, typename V>
127+
class TracedMapIterator {
128+
public:
129+
using WrappedV = typename detail::TracedObjectWrapperSelector<V>::Type;
130+
using MapIter = typename Map<K, V>::iterator;
131+
132+
using iterator_category = std::bidirectional_iterator_tag;
133+
using difference_type = ptrdiff_t;
134+
using value_type = const std::pair<K, WrappedV>;
135+
using pointer = value_type*;
136+
using reference = value_type;
137+
138+
explicit TracedMapIterator(MapIter iter, ObjectPath map_path)
139+
: iter_(iter), map_path_(std::move(map_path)) {}
140+
141+
bool operator==(const TracedMapIterator& other) const { return iter_ == other.iter_; }
142+
143+
bool operator!=(const TracedMapIterator& other) const { return iter_ != other.iter_; }
144+
145+
pointer operator->() const = delete;
146+
147+
reference operator*() const {
148+
auto kv = *iter_;
149+
return std::make_pair(kv.first, WrappedV(kv.second, map_path_->MapValue(kv.first)));
150+
}
151+
152+
TracedMapIterator& operator++() {
153+
++iter_;
154+
return *this;
155+
}
156+
157+
TracedMapIterator operator++(int) {
158+
TracedMapIterator copy = *this;
159+
++(*this);
160+
return copy;
161+
}
162+
163+
private:
164+
MapIter iter_;
165+
ObjectPath map_path_;
166+
};
167+
168+
template <typename K, typename V>
169+
class TracedMap {
170+
public:
171+
using WrappedV = typename detail::TracedObjectWrapperSelector<V>::Type;
172+
173+
using iterator = TracedMapIterator<K, V>;
174+
175+
explicit TracedMap(Map<K, V> map, ObjectPath path)
176+
: map_(std::move(map)), path_(std::move(path)) {}
177+
178+
WrappedV at(const K& key) const {
179+
auto it = map_.find(key);
180+
ICHECK(it != map_.end()) << "No such key in Map";
181+
auto kv = *it;
182+
return WrappedV(kv.second, path_->MapValue(kv.first));
183+
}
184+
185+
const Map<K, V>& Get() const { return map_; }
186+
187+
const ObjectPath& GetPath() const { return path_; }
188+
189+
iterator begin() const { return iterator(map_.begin(), path_); }
190+
191+
iterator end() const { return iterator(map_.end(), path_); }
192+
193+
bool empty() const { return map_.empty(); }
194+
195+
private:
196+
Map<K, V> map_;
197+
ObjectPath path_;
198+
};
199+
200+
template <typename T>
201+
class TracedArrayIterator {
202+
public:
203+
using WrappedT = typename detail::TracedObjectWrapperSelector<T>::Type;
204+
205+
using difference_type = ptrdiff_t;
206+
using value_type = WrappedT;
207+
using pointer = WrappedT*;
208+
using reference = WrappedT&;
209+
using iterator_category = std::random_access_iterator_tag;
210+
211+
explicit TracedArrayIterator(Array<T> array, size_t index, ObjectPath array_path)
212+
: array_(array), index_(index), array_path_(array_path) {}
213+
214+
TracedArrayIterator& operator++() {
215+
++index_;
216+
return *this;
217+
}
218+
TracedArrayIterator& operator--() {
219+
--index_;
220+
return *this;
221+
}
222+
TracedArrayIterator operator++(int) {
223+
TracedArrayIterator copy = *this;
224+
++index_;
225+
return copy;
226+
}
227+
TracedArrayIterator operator--(int) {
228+
TracedArrayIterator copy = *this;
229+
--index_;
230+
return copy;
231+
}
232+
233+
TracedArrayIterator operator+(difference_type offset) const {
234+
return TracedArrayIterator(array_, index_ + offset, array_path_);
235+
}
236+
237+
TracedArrayIterator operator-(difference_type offset) const {
238+
return TracedArrayIterator(array_, index_ - offset, array_path_);
239+
}
240+
241+
difference_type operator-(const TracedArrayIterator& rhs) const { return index_ - rhs.index_; }
242+
243+
bool operator==(TracedArrayIterator other) const {
244+
return array_.get() == other.array_.get() && index_ == other.index_;
245+
}
246+
bool operator!=(TracedArrayIterator other) const { return !(*this == other); }
247+
value_type operator*() const { return WrappedT(array_[index_], array_path_->ArrayIndex(index_)); }
248+
249+
bool empty() const { return array_.empty(); }
250+
251+
private:
252+
Array<T> array_;
253+
size_t index_;
254+
ObjectPath array_path_;
255+
};
256+
257+
template <typename T>
258+
class TracedArray {
259+
public:
260+
using WrappedT = typename detail::TracedObjectWrapperSelector<T>::Type;
261+
262+
using iterator = TracedArrayIterator<T>;
263+
264+
explicit TracedArray(Array<T> array, ObjectPath path)
265+
: array_(std::move(array)), path_(std::move(path)) {}
266+
267+
const Array<T>& Get() const { return array_; }
268+
269+
const ObjectPath& GetPath() const { return path_; }
270+
271+
WrappedT operator[](size_t index) const {
272+
return WrappedT(array_[index], path_->ArrayIndex(index));
273+
}
274+
275+
iterator begin() const { return iterator(array_, 0, path_); }
276+
277+
iterator end() const { return iterator(array_, array_.size(), path_); }
278+
279+
bool empty() const { return array_.empty(); }
280+
281+
size_t size() const { return array_.size(); }
282+
283+
private:
284+
Array<T> array_;
285+
ObjectPath path_;
286+
};
287+
288+
template <typename T>
289+
class TracedOptional {
290+
public:
291+
using WrappedT = typename detail::TracedObjectWrapperSelector<T>::Type;
292+
293+
TracedOptional(const WrappedT& value) // NOLINT(runtime/explicit)
294+
: optional_(value.Get().defined() ? value.Get() : Optional<T>(NullOpt)),
295+
path_(value.GetPath()) {}
296+
297+
explicit TracedOptional(Optional<T> optional, ObjectPath path)
298+
: optional_(std::move(optional)), path_(std::move(path)) {}
299+
300+
const Optional<T>& Get() const { return optional_; }
301+
302+
const ObjectPath& GetPath() const { return path_; }
303+
304+
bool defined() const { return optional_.defined(); }
305+
306+
WrappedT value() const { return WrappedT(optional_.value(), path_); }
307+
308+
explicit operator bool() const { return optional_.defined(); }
309+
310+
private:
311+
Optional<T> optional_;
312+
ObjectPath path_;
313+
};
314+
315+
template <typename T>
316+
class TracedBasicValue {
317+
public:
318+
explicit TracedBasicValue(const T& value, ObjectPath path)
319+
: value_(value), path_(std::move(path)) {}
320+
321+
const T& Get() const { return value_; }
322+
323+
const ObjectPath& GetPath() const { return path_; }
324+
325+
template <typename F>
326+
typename detail::TracedObjectWrapperSelector<typename std::result_of<F(const T&)>::type>::Type
327+
ApplyFunc(F&& f) const {
328+
return MakeTraced(f(value_), path_);
329+
}
330+
331+
private:
332+
T value_;
333+
ObjectPath path_;
334+
};
335+
336+
template <typename RefT>
337+
typename detail::TracedObjectWrapperSelector<RefT>::Type MakeTraced(const RefT& object) {
338+
using WrappedT = typename detail::TracedObjectWrapperSelector<RefT>::Type;
339+
return WrappedT(object, ObjectPath::Root());
340+
}
341+
342+
template <typename RefT>
343+
typename detail::TracedObjectWrapperSelector<RefT>::Type MakeTraced(const RefT& object,
344+
ObjectPath path) {
345+
using WrappedT = typename detail::TracedObjectWrapperSelector<RefT>::Type;
346+
return WrappedT(object, std::move(path));
347+
}
348+
349+
} // namespace tvm
350+
351+
#endif // TVM_NODE_TRACED_OBJECT_H_

0 commit comments

Comments
 (0)