Skip to content

Commit a25ed43

Browse files
committed
Support export ADT value in Python
1 parent 1fdf111 commit a25ed43

File tree

12 files changed

+77
-80
lines changed

12 files changed

+77
-80
lines changed

include/tvm/relay/interpreter.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -182,17 +182,21 @@ RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value);
182182
class ConstructorValue;
183183

184184
struct ConstructorValueNode : ValueNode {
185-
Constructor constructor;
185+
int tag;
186186

187187
tvm::Array<Value> fields;
188188

189+
/*! \brief Optional field tracking ADT constructor. */
190+
Constructor constructor;
191+
189192
void VisitAttrs(tvm::AttrVisitor* v) final {
190-
v->Visit("constructor", &constructor);
193+
v->Visit("tag", &tag);
191194
v->Visit("fields", &fields);
192195
}
193196

194-
TVM_DLL static ConstructorValue make(Constructor constructor,
195-
tvm::Array<Value> fields);
197+
TVM_DLL static ConstructorValue make(int tag,
198+
tvm::Array<Value> fields,
199+
Constructor construtor = {});
196200

197201
static constexpr const char* _type_key = "relay.ConstructorValue";
198202
TVM_DECLARE_NODE_TYPE_INFO(ConstructorValueNode, ValueNode);

python/tvm/relay/backend/interpreter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import numpy as np
2222

2323
from . import _backend
24-
from .. import _make, ir_pass
24+
from .. import _make, ir_pass, prelude
2525
from ... import register_func, nd
2626
from ..base import NodeBase, register_relay_node
2727
from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const
@@ -73,9 +73,9 @@ class Closure(Value):
7373

7474
@register_relay_node
7575
class ConstructorValue(Value):
76-
def __init__(self, constructor, fields, types):
76+
def __init__(self, tag, fields, constructor, types):
7777
self.__init_handle_by_constructor__(
78-
_make.ConstructorValue, constructor, fields, types)
78+
_make.ConstructorValue, tag, fields, constructor, types)
7979

8080

8181
@register_relay_node

python/tvm/relay/backend/vm.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,10 @@ def _eval_vm(mod, ctx, *args):
8282
main_func = ir_pass.eta_expand(main_func.body, mod)
8383

8484
assert isinstance(main_func, Function)
85-
main_func = optimize(mod[mod.entry_func], mod)
86-
mod[mod.entry_func] = main_func
85+
# optimize all functions in the module
86+
for gvar, _ in mod.functions.items():
87+
func_opt = optimize(mod[gvar], mod)
88+
mod[gvar] = func_opt
8789

8890
args = list(args)
8991
assert isinstance(args, list)

python/tvm/relay/prelude.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from .op.tensor import add, subtract, equal
2222
from .adt import Constructor, TypeData, Clause, Match
2323
from .adt import PatternConstructor, PatternVar, PatternWildcard
24+
from .module import Module
2425

2526
class Prelude:
2627
"""Contains standard definitions."""
@@ -34,6 +35,8 @@ def define_list_adt(self):
3435
self.nil = Constructor("nil", [], self.l)
3536
self.cons = Constructor("cons", [a, self.l(a)], self.l)
3637
self.mod[self.l] = TypeData(self.l, [a], [self.nil, self.cons])
38+
self.tag2constructor[self.nil.tag] = self.nil
39+
self.tag2constructor[self.cons.tag] = self.cons
3740

3841

3942
def define_list_hd(self):
@@ -336,6 +339,8 @@ def define_optional_adt(self):
336339
self.some = Constructor("some", [a], self.optional)
337340
self.none = Constructor("none", [], self.optional)
338341
self.mod[self.optional] = TypeData(self.optional, [a], [self.some, self.none])
342+
self.tag2constructor[self.some.tag] = self.some
343+
self.tag2constructor[self.none.tag] = self.none
339344

340345

341346
def define_list_unfoldr(self):
@@ -414,6 +419,7 @@ def define_tree_adt(self):
414419
a = TypeVar("a")
415420
self.rose = Constructor("rose", [a, self.l(self.tree(a))], self.tree)
416421
self.mod[self.tree] = TypeData(self.tree, [a], [self.rose])
422+
self.tag2constructor[self.rose] = self.rose.tag
417423

418424

419425
def define_tree_map(self):
@@ -503,6 +509,7 @@ def define_iterate(self):
503509

504510
def __init__(self, mod):
505511
self.mod = mod
512+
self.tag2constructor = {}
506513
self.define_list_adt()
507514
self.define_list_hd()
508515
self.define_list_tl()

python/tvm/relay/testing/nat.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -151,25 +151,25 @@ def add_nat_definitions(prelude):
151151
# helper functions for working with nats
152152

153153

154-
def count(n):
154+
def count(prelude, n):
155155
"""Takes a ConstructorValue corresponding to a nat ADT
156156
and converts it into a Python integer. This is an example of
157157
using an ADT value in Python.
158158
"""
159159
assert isinstance(n, ConstructorValue)
160-
if n.constructor.name_hint == 'z':
160+
if n.tag == prelude.z.tag:
161161
return 0
162-
assert n.constructor.name_hint == 's'
163-
return 1 + count(n.fields[0])
162+
assert n.tag == prelude.s.tag
163+
return 1 + count(prelude, n.fields[0])
164164

165165

166166
def make_nat_value(prelude, n):
167167
"""The inverse of count(): Given a non-negative Python integer,
168168
constructs a ConstructorValue representing that value as a nat.
169169
"""
170170
if n == 0:
171-
return ConstructorValue(prelude.z, [], [])
172-
return ConstructorValue(prelude.s, [make_nat_value(prelude, n - 1)], [])
171+
return ConstructorValue(prelude.z.tag, [], None, [])
172+
return ConstructorValue(prelude.s.tag, [make_nat_value(prelude, n - 1)], None, [])
173173

174174

175175
def make_nat_expr(prelude, n):

src/relay/backend/interpreter.cc

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -103,11 +103,13 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
103103
p->stream << "RefValueNode(" << node->value << ")";
104104
});
105105

106-
ConstructorValue ConstructorValueNode::make(Constructor constructor,
107-
tvm::Array<Value> fields) {
106+
ConstructorValue ConstructorValueNode::make(int tag,
107+
tvm::Array<Value> fields,
108+
Constructor constructor) {
108109
NodePtr<ConstructorValueNode> n = make_node<ConstructorValueNode>();
109-
n->constructor = constructor;
110+
n->tag = tag;
110111
n->fields = fields;
112+
n->constructor = constructor;
111113
return ConstructorValue(n);
112114
}
113115

@@ -117,7 +119,7 @@ TVM_REGISTER_API("relay._make.ConstructorValue")
117119
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
118120
.set_dispatch<ConstructorValueNode>([](const ConstructorValueNode* node,
119121
tvm::IRPrinter* p) {
120-
p->stream << "ConstructorValueNode(" << node->constructor
122+
p->stream << "ConstructorValueNode(" << node->tag << ","
121123
<< node->fields << ")";
122124
});
123125

@@ -448,7 +450,7 @@ class Interpreter :
448450
"fusing and lowering";
449451
}
450452
if (auto con = call->op.as<ConstructorNode>()) {
451-
return ConstructorValueNode::make(GetRef<Constructor>(con), args);
453+
return ConstructorValueNode::make(con->tag, args);
452454
}
453455
// Now we just evaluate and expect to find a closure.
454456
Value fn_val = Eval(call->op);
@@ -544,9 +546,8 @@ class Interpreter :
544546
const ConstructorValueNode* cvn = v.as<ConstructorValueNode>();
545547
CHECK(cvn) << "need to be a constructor for match";
546548
CHECK_NE(op->constructor->tag, -1);
547-
CHECK_NE(cvn->constructor->tag, -1);
548-
if (op->constructor->tag == cvn->constructor->tag) {
549-
// todo(M.K.): should use ptr equality but it is broken
549+
CHECK_NE(cvn->tag, -1);
550+
if (op->constructor->tag == cvn->tag) {
550551
CHECK_EQ(op->patterns.size(), cvn->fields.size());
551552
for (size_t i = 0; i < op->patterns.size(); ++i) {
552553
if (!VisitPattern(op->patterns[i], cvn->fields[i])) {

src/relay/backend/vm/compiler.cc

Lines changed: 12 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ struct VMCompilerContext {
7373
ConstTensorShapeMap const_tensor_shape_map;
7474
// List of lowered functions
7575
std::vector<LoweredFunc> lowered_funcs;
76+
// The functions that have been lowered.
77+
std::unordered_map<LoweredFunc, size_t, NodeHash, NodeEqual> seen_funcs;
7678
};
7779

7880
// Compute the constant pool, i.e a mapping from Constant node to constant index.
@@ -177,9 +179,6 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
177179
size_t registers_num;
178180
CompileEngine engine;
179181

180-
/*! \brief The functions that have been lowered. */
181-
std::unordered_map<LoweredFunc, size_t, NodeHash, NodeEqual> seen_funcs;
182-
183182
/*! \brief Global shared meta data */
184183
VMCompilerContext* context;
185184

@@ -253,7 +252,7 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
253252

254253
void VisitExpr_(const MatchNode* match_node) {
255254
auto match = GetRef<Match>(match_node);
256-
LOG(FATAL) << "translation of match nodes to the VM is"
255+
LOG(FATAL) << "translation of match nodes to the VM is "
257256
<< "currently unsupported" << std::endl;
258257
}
259258

@@ -273,7 +272,8 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
273272
}
274273

275274
void VisitExpr_(const GlobalVarNode* gvar) {
276-
LOG(FATAL) << "Global variables should only appear in the call position";
275+
// TODO(wweic): Support Load GlobalVar into a register
276+
LOG(WARNING) << "Loading GlobalVar into register is not yet supported";
277277
}
278278

279279
void VisitExpr_(const IfNode* if_node) {
@@ -370,12 +370,12 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
370370
// TODO(jroesch): support lowered funcs for multiple targets
371371
CHECK_EQ(cfunc->funcs.size(), 1);
372372
auto op_index = -1;
373-
if (seen_funcs.find(cfunc->funcs[0]) == seen_funcs.end()) {
373+
if (this->context->seen_funcs.find(cfunc->funcs[0]) == this->context->seen_funcs.end()) {
374374
op_index = this->context->lowered_funcs.size();
375375
this->context->lowered_funcs.push_back(cfunc->funcs[0]);
376-
seen_funcs[cfunc->funcs[0]] = op_index;
376+
this->context->seen_funcs[cfunc->funcs[0]] = op_index;
377377
} else {
378-
op_index = seen_funcs[cfunc->funcs[0]];
378+
op_index = this->context->seen_funcs[cfunc->funcs[0]];
379379
}
380380

381381
// If Tensor, 1
@@ -396,7 +396,6 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
396396
std::vector<Index> args_registers;
397397

398398
for (auto arg : call_node->args) {
399-
CHECK(arg.as<VarNode>()) << "found: " << AsText(arg, false) << std::endl << arg;
400399
this->VisitExpr(arg);
401400
args_registers.push_back(last_register);
402401
}
@@ -416,18 +415,14 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
416415
auto func = this->context->module->Lookup(global);
417416
if (IsClosure(func)) {
418417
auto arity = func->params.size();
419-
std::vector<Index> free_var_registers;
420-
for (size_t i = 0; i < arity; ++i) {
421-
free_var_registers.push_back(var_register_map.at(func->params[i]));
422-
}
423-
Emit(Instruction::AllocClosure(it->second, arity, free_var_registers, NewRegister()));
418+
Emit(Instruction::AllocClosure(it->second, arity, args_registers, NewRegister()));
424419
} else {
425420
Emit(Instruction::Invoke(it->second, args_registers, NewRegister()));
426421
}
427422
} else if (auto constructor_node = op.as<ConstructorNode>()) {
428423
auto constructor = GetRef<Constructor>(constructor_node);
429-
auto tag = GetConstructorTag(constructor);
430-
Emit(Instruction::AllocDatatype(tag, call_node->args.size(), args_registers, NewRegister()));
424+
Emit(Instruction::AllocDatatype(constructor->tag, call_node->args.size(), args_registers,
425+
NewRegister()));
431426
} else if (auto var_node = op.as<VarNode>()) {
432427
VisitExpr(GetRef<Var>(var_node));
433428
Emit(Instruction::InvokeClosure(last_register, args_registers, NewRegister()));
@@ -436,18 +431,6 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
436431
}
437432
}
438433

439-
size_t GetConstructorTag(tvm::relay::Constructor constructor) {
440-
auto it = this->context->tag_map.find(constructor);
441-
if (it != this->context->tag_map.end()) {
442-
return it->second;
443-
} else {
444-
auto tag = this->context->tag_map.size();
445-
this->context->tag_map[constructor] = tag;
446-
this->context->tag_index_map[tag] = constructor;
447-
return tag;
448-
}
449-
}
450-
451434
void VisitExpr_(const FunctionNode* func_node) {
452435
if (!func_node->IsPrimitive()) {
453436
LOG(FATAL) << "local functions should have been removed by lambda lifting:" << std::endl
@@ -516,7 +499,7 @@ void PopulatePackedFuncMap(const std::vector<LoweredFunc>& lowered_funcs,
516499
}
517500

518501
VMFunction CompileFunc(VMCompilerContext* context, const GlobalVar& var, const Function& func) {
519-
DLOG(INFO) << "CompileFunc: " << std::endl << AsText(func, false) << std::endl;
502+
DLOG(INFO) << "CompileFunc: " << var << std::endl << AsText(func, false) << std::endl;
520503
size_t params = func->params.size();
521504
VMCompiler compiler(context);
522505
compiler.Compile(func);

src/relay/backend/vm/vm.cc

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -63,24 +63,21 @@ Object EvaluateModule(const Module& module, const std::vector<TVMContext> ctxs,
6363
return res;
6464
}
6565

66-
Value VMToValue(const relay::Module& module, const relay::Type& type, Object obj) {
67-
CHECK(module.defined() && type.defined());
66+
Value VMToValue(const relay::Module& module, Object obj) {
67+
CHECK(module.defined());
6868
switch (obj->tag) {
6969
case ObjectTag::kTensor: {
70-
CHECK(type.as<TensorTypeNode>()) << "VM internal error: return value must be a tensor";
7170
return TensorValueNode::make(ToNDArray(obj));
7271
}
7372
case ObjectTag::kDatatype: {
74-
// const auto* tuple_type
75-
// const auto& data_type = obj.AsDatatype();
73+
const auto& data_type = obj.AsDatatype();
7674

77-
// tvm::Array<Value> fields;
78-
// for (size_t i = 0; i < data_type->fields.size(); ++i) {
79-
// fields.push_back(VMToValue(tag_index_map, data_type->fields[i]));
80-
// }
75+
tvm::Array<Value> fields;
76+
for (size_t i = 0; i < data_type->fields.size(); ++i) {
77+
fields.push_back(VMToValue(module, data_type->fields[i]));
78+
}
8179

82-
// return ConstructorValueNode::make(tag_index_map.at(data_type->tag), fields);
83-
LOG(FATAL) << "fix me";
80+
return ConstructorValueNode::make(data_type->tag, fields);
8481
}
8582
default:
8683
LOG(FATAL) << "unsupported return value of type: " << obj->tag;
@@ -141,8 +138,6 @@ TVM_REGISTER_API("relay._vm._evaluate_vm").set_body([](TVMArgs args, TVMRetValue
141138
LOG(FATAL) << "expected function or module";
142139
}
143140

144-
auto return_type = module->Lookup(module->entry_func)->ret_type;
145-
146141
std::vector<Object> vm_args;
147142
for (auto i = 3; i < args.size(); i++) {
148143
Object obj = args[i];
@@ -151,7 +146,7 @@ TVM_REGISTER_API("relay._vm._evaluate_vm").set_body([](TVMArgs args, TVMRetValue
151146

152147
auto result = EvaluateModule(module, {ctx}, vm_args);
153148
DLOG(INFO) << "Evaluate VM returning: result=" << result->tag;
154-
*ret = VMToValue(module, return_type, result);
149+
*ret = VMToValue(module, result);
155150
});
156151

157152
} // namespace vm

tests/python/relay/test_adt.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,15 @@
2121
from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue
2222
from tvm.relay import testing, create_executor
2323
from tvm.relay.prelude import Prelude
24-
from tvm.relay.testing import add_nat_definitions, count, make_nat_value, make_nat_expr
24+
from tvm.relay.testing import add_nat_definitions, count as count_, make_nat_value, make_nat_expr
2525

2626
mod = relay.Module()
2727
p = Prelude(mod)
2828
add_nat_definitions(p)
2929

30+
def count(e):
31+
return count_(p, e)
32+
3033
ctx = tvm.context("llvm", 0)
3134
intrp = create_executor(mod=mod, ctx=ctx, target="llvm")
3235

@@ -91,18 +94,18 @@ def to_list(l):
9194
val = l
9295
ret = []
9396
while True:
94-
if val.constructor.name_hint == 'cons':
97+
if val.tag == p.cons.tag:
9598
ret.append(val.fields[0])
9699
val = val.fields[1]
97100
else:
98-
assert val.constructor.name_hint == 'nil'
101+
assert val.tag == p.nil.tag
99102
break
100103
return ret
101104

102105
def tree_to_dict(t):
103106
assert isinstance(t, ConstructorValue)
104107
ret = {}
105-
assert t.constructor.name_hint == 'rose'
108+
assert t.tag == p.rose.tag
106109
ret['member'] = t.fields[0]
107110
ret['children'] = []
108111
for subtree in to_list(t.fields[1]):

0 commit comments

Comments
 (0)