Skip to content

Commit 02c0823

Browse files
committed
Refactor code to use new names
1 parent 2041efa commit 02c0823

File tree

13 files changed

+230
-175
lines changed

13 files changed

+230
-175
lines changed

3rdparty/HalideIR

include/tvm/relay/vm/vm.h

Lines changed: 6 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -11,102 +11,13 @@
1111
#include <tvm/relay/expr_functor.h>
1212
#include <tvm/relay/logging.h>
1313
#include <tvm/runtime/memory_manager.h>
14+
#include <tvm/runtime/object.h>
1415

1516
namespace tvm {
1617
namespace relay {
1718
namespace vm {
1819

19-
using runtime::NDArray;
20-
21-
enum struct VMObjectTag {
22-
kTensor,
23-
kClosure,
24-
kDatatype,
25-
kExternalFunc,
26-
};
27-
28-
inline std::string VMObjectTagString(VMObjectTag tag) {
29-
switch (tag) {
30-
case VMObjectTag::kClosure:
31-
return "Closure";
32-
case VMObjectTag::kDatatype:
33-
return "Datatype";
34-
case VMObjectTag::kTensor:
35-
return "Tensor";
36-
case VMObjectTag::kExternalFunc:
37-
return "ExternalFunction";
38-
default:
39-
LOG(FATAL) << "Object tag is not supported.";
40-
return "";
41-
}
42-
}
43-
44-
// TODO(@jroesch): Use intrusive pointer.
45-
struct VMObjectCell {
46-
VMObjectTag tag;
47-
VMObjectCell(VMObjectTag tag) : tag(tag) {}
48-
VMObjectCell() {}
49-
virtual ~VMObjectCell() {}
50-
};
51-
52-
struct VMTensorCell : public VMObjectCell {
53-
tvm::runtime::NDArray data;
54-
VMTensorCell(const tvm::runtime::NDArray& data)
55-
: VMObjectCell(VMObjectTag::kTensor), data(data) {}
56-
};
57-
58-
struct VMObject {
59-
std::shared_ptr<VMObjectCell> ptr;
60-
VMObject(std::shared_ptr<VMObjectCell> ptr) : ptr(ptr) {}
61-
VMObject() : ptr() {}
62-
VMObject(const VMObject& obj) : ptr(obj.ptr) {}
63-
VMObjectCell* operator->() {
64-
return this->ptr.operator->();
65-
}
66-
};
67-
68-
struct VMDatatypeCell : public VMObjectCell {
69-
size_t tag;
70-
std::vector<VMObject> fields;
71-
72-
VMDatatypeCell(size_t tag, const std::vector<VMObject>& fields)
73-
: VMObjectCell(VMObjectTag::kDatatype), tag(tag), fields(fields) {}
74-
};
75-
76-
struct VMClosureCell : public VMObjectCell {
77-
size_t func_index;
78-
std::vector<VMObject> free_vars;
79-
80-
VMClosureCell(size_t func_index, const std::vector<VMObject>& free_vars)
81-
: VMObjectCell(VMObjectTag::kClosure), func_index(func_index), free_vars(free_vars) {}
82-
};
83-
84-
85-
inline VMObject VMTensor(const tvm::runtime::NDArray& data) {
86-
auto ptr = std::make_shared<VMTensorCell>(data);
87-
return std::dynamic_pointer_cast<VMObjectCell>(ptr);
88-
}
89-
90-
inline VMObject VMDatatype(size_t tag, const std::vector<VMObject>& fields) {
91-
auto ptr = std::make_shared<VMDatatypeCell>(tag, fields);
92-
return std::dynamic_pointer_cast<VMObjectCell>(ptr);
93-
}
94-
95-
inline VMObject VMTuple(const std::vector<VMObject>& fields) {
96-
return VMDatatype(0, fields);
97-
}
98-
99-
inline VMObject VMClosure(size_t func_index, std::vector<VMObject> free_vars) {
100-
auto ptr = std::make_shared<VMClosureCell>(func_index, free_vars);
101-
return std::dynamic_pointer_cast<VMObjectCell>(ptr);
102-
}
103-
104-
inline NDArray ToNDArray(const VMObject& obj) {
105-
CHECK(obj.ptr.get());
106-
CHECK(obj.ptr->tag == VMObjectTag::kTensor) << "Expect Tensor, Got " << VMObjectTagString(obj.ptr->tag);
107-
std::shared_ptr<VMTensorCell> o = std::dynamic_pointer_cast<VMTensorCell>(obj.ptr);
108-
return o->data;
109-
}
20+
using namespace tvm::runtime;
11021

11122
enum struct Opcode {
11223
Push,
@@ -235,8 +146,8 @@ struct VirtualMachine {
235146
std::vector<PackedFunc> packed_funcs;
236147
std::vector<VMFunction> functions;
237148
std::vector<VMFrame> frames;
238-
std::vector<VMObject> stack;
239-
std::vector<VMObject> constants;
149+
std::vector<Object> stack;
150+
std::vector<Object> constants;
240151

241152
// Frame State
242153
size_t func_index;
@@ -255,8 +166,8 @@ struct VirtualMachine {
255166
void InvokeGlobal(const VMFunction& func);
256167
void Run();
257168

258-
VMObject Invoke(const VMFunction& func, const std::vector<VMObject>& args);
259-
VMObject Invoke(const GlobalVar& global, const std::vector<VMObject>& args);
169+
Object Invoke(const VMFunction& func, const std::vector<Object>& args);
170+
Object Invoke(const GlobalVar& global, const std::vector<Object>& args);
260171

261172
// Ignore the method that dumps register info at compile-time if debugging
262173
// mode is not enabled.

include/tvm/runtime/c_runtime_api.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ typedef enum {
8585
kStr = 11U,
8686
kBytes = 12U,
8787
kNDArrayContainer = 13U,
88-
kVMObject = 14U,
88+
kObject = 14U,
8989
// Extension codes for other frameworks to integrate TVM PackedFunc.
9090
// To make sure each framework's id do not conflict, use first and
9191
// last sections to mark ranges.

include/tvm/runtime/ndarray.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
/*!
22
* Copyright (c) 2017 by Contributors
33
* \file tvm/runtime/ndarray.h
4-
* \brief Abstract device memory management API
4+
* \brief A device-indpendent managed NDArray abstraction.
55
*/
66
#ifndef TVM_RUNTIME_NDARRAY_H_
77
#define TVM_RUNTIME_NDARRAY_H_

include/tvm/runtime/object.h

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/*!
2+
* Copyright (c) 2019 by Contributors
3+
* \file tvm/runtime/object.h
4+
* \brief A managed object in the TVM runtime.
5+
*/
6+
#ifndef TVM_RUNTIME_OBJECT_H_
7+
#define TVM_RUNTIME_OBJECT_H_
8+
9+
#include <tvm/runtime/ndarray.h>
10+
11+
namespace tvm {
12+
namespace runtime {
13+
14+
enum struct ObjectTag {
15+
kTensor,
16+
kClosure,
17+
kDatatype,
18+
kExternalFunc
19+
};
20+
21+
std::ostream& operator<<(std::ostream& os, const ObjectTag&);
22+
23+
// TODO(@jroesch): Use intrusive pointer.
24+
struct ObjectCell {
25+
ObjectTag tag;
26+
ObjectCell(ObjectTag tag) : tag(tag) {}
27+
ObjectCell() {}
28+
virtual ~ObjectCell() {}
29+
};
30+
31+
/*!
32+
* \brief A managed object in the TVM runtime.
33+
*
34+
* For example a tuple, list, closure, and so on.
35+
*
36+
* Maintains a reference count for the object.
37+
*/
38+
class Object {
39+
public:
40+
std::shared_ptr<ObjectCell> ptr;
41+
Object(std::shared_ptr<ObjectCell> ptr) : ptr(ptr) {}
42+
Object() : ptr() {}
43+
Object(const Object& obj) : ptr(obj.ptr) {}
44+
ObjectCell* operator->() {
45+
return this->ptr.operator->();
46+
}
47+
};
48+
49+
struct TensorCell : public ObjectCell {
50+
NDArray data;
51+
TensorCell(const NDArray& data)
52+
: ObjectCell(ObjectTag::kTensor), data(data) {}
53+
};
54+
55+
struct DatatypeCell : public ObjectCell {
56+
size_t tag;
57+
std::vector<Object> fields;
58+
59+
DatatypeCell(size_t tag, const std::vector<Object>& fields)
60+
: ObjectCell(ObjectTag::kDatatype), tag(tag), fields(fields) {}
61+
};
62+
63+
struct ClosureCell : public ObjectCell {
64+
size_t func_index;
65+
std::vector<Object> free_vars;
66+
67+
ClosureCell(size_t func_index, const std::vector<Object>& free_vars)
68+
: ObjectCell(ObjectTag::kClosure), func_index(func_index), free_vars(free_vars) {}
69+
};
70+
71+
Object TensorObj(const NDArray& data);
72+
Object DatatypeObj(size_t tag, const std::vector<Object>& fields);
73+
Object TupleObj(const std::vector<Object>& fields);
74+
Object ClosureObj(size_t func_index, std::vector<Object> free_vars);
75+
NDArray ToNDArray(const Object& obj);
76+
77+
} // namespace runtime
78+
} // namespace tvm
79+
#endif // TVM_RUNTIME_OBJECT_H_

include/tvm/runtime/packed_func.h

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "c_runtime_api.h"
2222
#include "module.h"
2323
#include "ndarray.h"
24+
#include "object.h"
2425
#include "node_base.h"
2526

2627
namespace HalideIR {
@@ -40,12 +41,6 @@ namespace tvm {
4041
// forward declarations
4142
class Integer;
4243

43-
namespace relay {
44-
namespace vm {
45-
struct VMObject;
46-
}
47-
}
48-
4944
namespace runtime {
5045
// forward declarations
5146
class TVMArgs;
@@ -589,7 +584,7 @@ class TVMArgValue : public TVMPODValue_ {
589584
inline operator tvm::Integer() const;
590585
// get internal node ptr, if it is node
591586
inline NodePtr<Node>& node_sptr();
592-
operator relay::vm::VMObject() const;
587+
operator runtime::Object() const;
593588
};
594589

595590
/*!
@@ -724,7 +719,7 @@ class TVMRetValue : public TVMPODValue_ {
724719
return *this;
725720
}
726721

727-
TVMRetValue& operator=(relay::vm::VMObject other);
722+
TVMRetValue& operator=(runtime::Object other);
728723

729724
TVMRetValue& operator=(PackedFunc f) {
730725
this->SwitchToClass(kFuncHandle, f);
@@ -821,7 +816,7 @@ class TVMRetValue : public TVMPODValue_ {
821816
kNodeHandle, *other.template ptr<NodePtr<Node> >());
822817
break;
823818
}
824-
case kVMObject: {
819+
case kObject: {
825820
throw dmlc::Error("here");
826821
}
827822
default: {
@@ -871,7 +866,7 @@ class TVMRetValue : public TVMPODValue_ {
871866
static_cast<NDArray::Container*>(value_.v_handle)->DecRef();
872867
break;
873868
}
874-
// case kModuleHandle: delete ptr<relay::vm::VMObject>(); break;
869+
// case kModuleHandle: delete ptr<runtime::Object>(); break;
875870
}
876871
if (type_code_ > kExtBegin) {
877872
#if TVM_RUNTIME_HEADER_ONLY
@@ -901,7 +896,7 @@ inline const char* TypeCode2Str(int type_code) {
901896
case kFuncHandle: return "FunctionHandle";
902897
case kModuleHandle: return "ModuleHandle";
903898
case kNDArrayContainer: return "NDArrayContainer";
904-
case kVMObject: return "VMObject";
899+
case kObject: return "Object";
905900
default: LOG(FATAL) << "unknown type_code="
906901
<< static_cast<int>(type_code); return "";
907902
}

python/tvm/relay/prelude.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,7 @@ def __init__(self, mod):
473473
self.define_list_map()
474474
self.define_list_foldl()
475475
self.define_list_foldr()
476-
# self.define_list_concat()
476+
self.define_list_concat()
477477
self.define_list_filter()
478478
self.define_list_zip()
479479
self.define_list_rev()
@@ -489,9 +489,10 @@ def __init__(self, mod):
489489
self.define_nat_add()
490490
self.define_list_length()
491491
self.define_list_nth()
492+
self.define_list_update()
492493
self.define_list_sum()
493-
self.define_tree_adt()
494494

495+
self.define_tree_adt()
495496
self.define_tree_map()
496497
self.define_tree_size()
497498

src/api/dsl_api.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
#include <dmlc/thread_local.h>
99
#include <tvm/api_registry.h>
1010
#include <tvm/attrs.h>
11-
#include <tvm/relay/vm/vm.h>
1211
#include <vector>
1312
#include <string>
1413
#include <exception>
@@ -74,7 +73,7 @@ struct APIAttrGetter : public AttrVisitor {
7473
found_ref_object = true;
7574
}
7675
}
77-
void Visit(const char* key, relay::vm::VMObject* value) final {
76+
void Visit(const char* key, runtime::Object* value) final {
7877
if (skey == key) {
7978
*ret = value[0];
8079
found_ref_object = true;
@@ -115,7 +114,7 @@ struct APIAttrDir : public AttrVisitor {
115114
void Visit(const char* key, runtime::NDArray* value) final {
116115
names->push_back(key);
117116
}
118-
void Visit(const char* key, relay::vm::VMObject* value) final {
117+
void Visit(const char* key, runtime::Object* value) final {
119118
names->push_back(key);
120119
}
121120
};

0 commit comments

Comments
 (0)