Skip to content

Commit 878da97

Browse files
committed
Support export ADT value in Python
1 parent c4245e3 commit 878da97

File tree

12 files changed

+66
-80
lines changed

12 files changed

+66
-80
lines changed

include/tvm/relay/interpreter.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -182,17 +182,21 @@ RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value);
182182
class ConstructorValue;
183183

184184
struct ConstructorValueNode : ValueNode {
185-
Constructor constructor;
185+
int tag;
186186

187187
tvm::Array<Value> fields;
188188

189+
/*! \brief Optional field tracking ADT constructor. */
190+
Constructor constructor;
191+
189192
void VisitAttrs(tvm::AttrVisitor* v) final {
190-
v->Visit("constructor", &constructor);
193+
v->Visit("tag", &tag);
191194
v->Visit("fields", &fields);
192195
}
193196

194-
TVM_DLL static ConstructorValue make(Constructor constructor,
195-
tvm::Array<Value> fields);
197+
TVM_DLL static ConstructorValue make(int tag,
198+
tvm::Array<Value> fields,
199+
Constructor construtor = {});
196200

197201
static constexpr const char* _type_key = "relay.ConstructorValue";
198202
TVM_DECLARE_NODE_TYPE_INFO(ConstructorValueNode, ValueNode);

python/tvm/relay/backend/interpreter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import numpy as np
2222

2323
from . import _backend
24-
from .. import _make, ir_pass
24+
from .. import _make, ir_pass, prelude
2525
from ... import register_func, nd
2626
from ..base import NodeBase, register_relay_node
2727
from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const
@@ -73,9 +73,9 @@ class Closure(Value):
7373

7474
@register_relay_node
7575
class ConstructorValue(Value):
76-
def __init__(self, constructor, fields, types):
76+
def __init__(self, tag, fields, constructor, types):
7777
self.__init_handle_by_constructor__(
78-
_make.ConstructorValue, constructor, fields, types)
78+
_make.ConstructorValue, tag, fields, constructor, types)
7979

8080

8181
@register_relay_node

python/tvm/relay/backend/vm.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,6 @@ def _eval_vm(mod, ctx, *args):
9797
args: List[tvm.NDArray, np.ndarray]
9898
The arguments to evaluate.
9999
"""
100-
101100
mod = optimize(mod)
102101
args = list(args)
103102
assert isinstance(args, list)

python/tvm/relay/prelude.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,6 @@ def load_prelude(self):
491491
def __init__(self, mod):
492492
self.mod = mod
493493
self.load_prelude()
494-
495494
self.define_list_adt()
496495
self.define_list_hd()
497496
self.define_list_tl()

python/tvm/relay/testing/nat.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -151,25 +151,25 @@ def add_nat_definitions(prelude):
151151
# helper functions for working with nats
152152

153153

154-
def count(n):
154+
def count(prelude, n):
155155
"""Takes a ConstructorValue corresponding to a nat ADT
156156
and converts it into a Python integer. This is an example of
157157
using an ADT value in Python.
158158
"""
159159
assert isinstance(n, ConstructorValue)
160-
if n.constructor.name_hint == 'z':
160+
if n.tag == prelude.z.tag:
161161
return 0
162-
assert n.constructor.name_hint == 's'
163-
return 1 + count(n.fields[0])
162+
assert n.tag == prelude.s.tag
163+
return 1 + count(prelude, n.fields[0])
164164

165165

166166
def make_nat_value(prelude, n):
167167
"""The inverse of count(): Given a non-negative Python integer,
168168
constructs a ConstructorValue representing that value as a nat.
169169
"""
170170
if n == 0:
171-
return ConstructorValue(prelude.z, [], [])
172-
return ConstructorValue(prelude.s, [make_nat_value(prelude, n - 1)], [])
171+
return ConstructorValue(prelude.z.tag, [], None, [])
172+
return ConstructorValue(prelude.s.tag, [make_nat_value(prelude, n - 1)], None, [])
173173

174174

175175
def make_nat_expr(prelude, n):

src/relay/backend/interpreter.cc

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -103,11 +103,13 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
103103
p->stream << "RefValueNode(" << node->value << ")";
104104
});
105105

106-
ConstructorValue ConstructorValueNode::make(Constructor constructor,
107-
tvm::Array<Value> fields) {
106+
ConstructorValue ConstructorValueNode::make(int tag,
107+
tvm::Array<Value> fields,
108+
Constructor constructor) {
108109
NodePtr<ConstructorValueNode> n = make_node<ConstructorValueNode>();
109-
n->constructor = constructor;
110+
n->tag = tag;
110111
n->fields = fields;
112+
n->constructor = constructor;
111113
return ConstructorValue(n);
112114
}
113115

@@ -117,7 +119,7 @@ TVM_REGISTER_API("relay._make.ConstructorValue")
117119
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
118120
.set_dispatch<ConstructorValueNode>([](const ConstructorValueNode* node,
119121
tvm::IRPrinter* p) {
120-
p->stream << "ConstructorValueNode(" << node->constructor
122+
p->stream << "ConstructorValueNode(" << node->tag << ","
121123
<< node->fields << ")";
122124
});
123125

@@ -448,7 +450,7 @@ class Interpreter :
448450
"fusing and lowering";
449451
}
450452
if (auto con = call->op.as<ConstructorNode>()) {
451-
return ConstructorValueNode::make(GetRef<Constructor>(con), args);
453+
return ConstructorValueNode::make(con->tag, args);
452454
}
453455
// Now we just evaluate and expect to find a closure.
454456
Value fn_val = Eval(call->op);
@@ -544,9 +546,8 @@ class Interpreter :
544546
const ConstructorValueNode* cvn = v.as<ConstructorValueNode>();
545547
CHECK(cvn) << "need to be a constructor for match";
546548
CHECK_NE(op->constructor->tag, -1);
547-
CHECK_NE(cvn->constructor->tag, -1);
548-
if (op->constructor->tag == cvn->constructor->tag) {
549-
// todo(M.K.): should use ptr equality but it is broken
549+
CHECK_NE(cvn->tag, -1);
550+
if (op->constructor->tag == cvn->tag) {
550551
CHECK_EQ(op->patterns.size(), cvn->fields.size());
551552
for (size_t i = 0; i < op->patterns.size(); ++i) {
552553
if (!VisitPattern(op->patterns[i], cvn->fields[i])) {

src/relay/backend/vm/compiler.cc

Lines changed: 12 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ struct VMCompilerContext {
8080
ConstTensorShapeMap const_tensor_shape_map;
8181
// List of lowered functions
8282
std::vector<LoweredFunc> lowered_funcs;
83+
// The functions that have been lowered.
84+
std::unordered_map<LoweredFunc, size_t, NodeHash, NodeEqual> seen_funcs;
8385
};
8486

8587
// Compute the constant pool, i.e a mapping from Constant node to constant index.
@@ -184,9 +186,6 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
184186
size_t registers_num;
185187
CompileEngine engine;
186188

187-
/*! \brief The functions that have been lowered. */
188-
std::unordered_map<LoweredFunc, size_t, NodeHash, NodeEqual> seen_funcs;
189-
190189
/*! \brief Global shared meta data */
191190
VMCompilerContext* context;
192191

@@ -260,7 +259,7 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
260259

261260
void VisitExpr_(const MatchNode* match_node) {
262261
auto match = GetRef<Match>(match_node);
263-
LOG(FATAL) << "translation of match nodes to the VM is"
262+
LOG(FATAL) << "translation of match nodes to the VM is "
264263
<< "currently unsupported" << std::endl;
265264
}
266265

@@ -280,7 +279,8 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
280279
}
281280

282281
void VisitExpr_(const GlobalVarNode* gvar) {
283-
LOG(FATAL) << "Global variables should only appear in the call position";
282+
// TODO(wweic): Support Load GlobalVar into a register
283+
LOG(WARNING) << "Loading GlobalVar into register is not yet supported";
284284
}
285285

286286
void VisitExpr_(const IfNode* if_node) {
@@ -405,12 +405,12 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
405405
// TODO(jroesch): support lowered funcs for multiple targets
406406
CHECK_EQ(cfunc->funcs.size(), 1);
407407
auto op_index = -1;
408-
if (seen_funcs.find(cfunc->funcs[0]) == seen_funcs.end()) {
408+
if (this->context->seen_funcs.find(cfunc->funcs[0]) == this->context->seen_funcs.end()) {
409409
op_index = this->context->lowered_funcs.size();
410410
this->context->lowered_funcs.push_back(cfunc->funcs[0]);
411-
seen_funcs[cfunc->funcs[0]] = op_index;
411+
this->context->seen_funcs[cfunc->funcs[0]] = op_index;
412412
} else {
413-
op_index = seen_funcs[cfunc->funcs[0]];
413+
op_index = this->context->seen_funcs[cfunc->funcs[0]];
414414
}
415415

416416
Emit(Instruction::InvokePacked(op_index, arity, return_val_count, unpacked_arg_regs));
@@ -429,7 +429,6 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
429429
std::vector<Index> args_registers;
430430

431431
for (auto arg : call_node->args) {
432-
CHECK(arg.as<VarNode>()) << "found: " << AsText(arg, false) << std::endl << arg;
433432
this->VisitExpr(arg);
434433
args_registers.push_back(last_register);
435434
}
@@ -449,18 +448,14 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
449448
auto func = this->context->module->Lookup(global);
450449
if (IsClosure(func)) {
451450
auto arity = func->params.size();
452-
std::vector<Index> free_var_registers;
453-
for (size_t i = 0; i < arity; ++i) {
454-
free_var_registers.push_back(var_register_map.at(func->params[i]));
455-
}
456-
Emit(Instruction::AllocClosure(it->second, arity, free_var_registers, NewRegister()));
451+
Emit(Instruction::AllocClosure(it->second, arity, args_registers, NewRegister()));
457452
} else {
458453
Emit(Instruction::Invoke(it->second, args_registers, NewRegister()));
459454
}
460455
} else if (auto constructor_node = op.as<ConstructorNode>()) {
461456
auto constructor = GetRef<Constructor>(constructor_node);
462-
auto tag = GetConstructorTag(constructor);
463-
Emit(Instruction::AllocDatatype(tag, call_node->args.size(), args_registers, NewRegister()));
457+
Emit(Instruction::AllocDatatype(constructor->tag, call_node->args.size(), args_registers,
458+
NewRegister()));
464459
} else if (auto var_node = op.as<VarNode>()) {
465460
VisitExpr(GetRef<Var>(var_node));
466461
Emit(Instruction::InvokeClosure(last_register, args_registers, NewRegister()));
@@ -469,18 +464,6 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
469464
}
470465
}
471466

472-
size_t GetConstructorTag(tvm::relay::Constructor constructor) {
473-
auto it = this->context->tag_map.find(constructor);
474-
if (it != this->context->tag_map.end()) {
475-
return it->second;
476-
} else {
477-
auto tag = this->context->tag_map.size();
478-
this->context->tag_map[constructor] = tag;
479-
this->context->tag_index_map[tag] = constructor;
480-
return tag;
481-
}
482-
}
483-
484467
void VisitExpr_(const FunctionNode* func_node) {
485468
if (!func_node->IsPrimitive()) {
486469
LOG(FATAL) << "local functions should have been removed by lambda lifting:" << std::endl
@@ -549,7 +532,7 @@ void PopulatePackedFuncMap(const std::vector<LoweredFunc>& lowered_funcs,
549532
}
550533

551534
VMFunction CompileFunc(VMCompilerContext* context, const GlobalVar& var, const Function& func) {
552-
DLOG(INFO) << "CompileFunc: " << std::endl << AsText(func, false) << std::endl;
535+
DLOG(INFO) << "CompileFunc: " << var << std::endl << AsText(func, false) << std::endl;
553536
size_t params = func->params.size();
554537
VMCompiler compiler(context);
555538
compiler.Compile(func);

src/relay/backend/vm/vm.cc

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -63,24 +63,21 @@ Object EvaluateModule(const Module& module, const std::vector<TVMContext> ctxs,
6363
return res;
6464
}
6565

66-
Value VMToValue(const relay::Module& module, const relay::Type& type, Object obj) {
67-
CHECK(module.defined() && type.defined());
66+
Value VMToValue(const relay::Module& module, Object obj) {
67+
CHECK(module.defined());
6868
switch (obj->tag) {
6969
case ObjectTag::kTensor: {
70-
CHECK(type.as<TensorTypeNode>()) << "VM internal error: return value must be a tensor";
7170
return TensorValueNode::make(ToNDArray(obj));
7271
}
7372
case ObjectTag::kDatatype: {
74-
// const auto* tuple_type
75-
// const auto& data_type = obj.AsDatatype();
73+
const auto& data_type = obj.AsDatatype();
7674

77-
// tvm::Array<Value> fields;
78-
// for (size_t i = 0; i < data_type->fields.size(); ++i) {
79-
// fields.push_back(VMToValue(tag_index_map, data_type->fields[i]));
80-
// }
75+
tvm::Array<Value> fields;
76+
for (size_t i = 0; i < data_type->fields.size(); ++i) {
77+
fields.push_back(VMToValue(module, data_type->fields[i]));
78+
}
8179

82-
// return ConstructorValueNode::make(tag_index_map.at(data_type->tag), fields);
83-
LOG(FATAL) << "fix me";
80+
return ConstructorValueNode::make(data_type->tag, fields);
8481
}
8582
default:
8683
LOG(FATAL) << "unsupported return value of type: " << obj->tag;
@@ -141,8 +138,6 @@ TVM_REGISTER_API("relay._vm._evaluate_vm").set_body([](TVMArgs args, TVMRetValue
141138
LOG(FATAL) << "expected function or module";
142139
}
143140

144-
auto return_type = module->Lookup(module->entry_func)->ret_type;
145-
146141
std::vector<Object> vm_args;
147142
for (auto i = 3; i < args.size(); i++) {
148143
Object obj = args[i];
@@ -151,7 +146,7 @@ TVM_REGISTER_API("relay._vm._evaluate_vm").set_body([](TVMArgs args, TVMRetValue
151146

152147
auto result = EvaluateModule(module, {ctx}, vm_args);
153148
DLOG(INFO) << "Evaluate VM returning: result=" << result->tag;
154-
*ret = VMToValue(module, return_type, result);
149+
*ret = VMToValue(module, result);
155150
});
156151

157152
} // namespace vm

tests/python/relay/test_adt.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,15 @@
2121
from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue
2222
from tvm.relay import testing, create_executor
2323
from tvm.relay.prelude import Prelude
24-
from tvm.relay.testing import add_nat_definitions, count, make_nat_value, make_nat_expr
24+
from tvm.relay.testing import add_nat_definitions, count as count_, make_nat_value, make_nat_expr
2525

2626
mod = relay.Module()
2727
p = Prelude(mod)
2828
add_nat_definitions(p)
2929

30+
def count(e):
31+
return count_(p, e)
32+
3033
ctx = tvm.context("llvm", 0)
3134
intrp = create_executor(mod=mod, ctx=ctx, target="llvm")
3235

@@ -91,18 +94,18 @@ def to_list(l):
9194
val = l
9295
ret = []
9396
while True:
94-
if val.constructor.name_hint == 'cons':
97+
if val.tag == p.cons.tag:
9598
ret.append(val.fields[0])
9699
val = val.fields[1]
97100
else:
98-
assert val.constructor.name_hint == 'nil'
101+
assert val.tag == p.nil.tag
99102
break
100103
return ret
101104

102105
def tree_to_dict(t):
103106
assert isinstance(t, ConstructorValue)
104107
ret = {}
105-
assert t.constructor.name_hint == 'rose'
108+
assert t.tag == p.rose.tag
106109
ret['member'] = t.fields[0]
107110
ret['children'] = []
108111
for subtree in to_list(t.fields[1]):

tests/python/relay/test_backend_interpreter.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -183,8 +183,8 @@ def test_function_taking_adt_ref_tuple():
183183
prelude = relay.prelude.Prelude(mod)
184184
intrp = create_executor("debug", mod)
185185

186-
nil_value = ConstructorValue(prelude.nil, [], [])
187-
cons_value = ConstructorValue(prelude.cons, [
186+
nil_value = ConstructorValue(prelude.nil.tag, [], [])
187+
cons_value = ConstructorValue(prelude.cons.tag, [
188188
TensorValue(np.random.rand(1, 10).astype('float32')),
189189
nil_value
190190
], [relay.TensorType((1, 10), 'float32')])
@@ -197,16 +197,16 @@ def test_function_taking_adt_ref_tuple():
197197
id_func = intrp.evaluate(prelude.id)
198198

199199
res_nil = id_func(nil_value)
200-
assert res_nil.constructor == nil_value.constructor
200+
assert res_nil.tag == nil_value.tag
201201
assert len(res_nil.fields) == 0
202202

203203
res_cons = id_func(cons_value)
204-
assert res_cons.constructor == cons_value.constructor
204+
assert res_cons.tag == cons_value.tag
205205
assert len(res_cons.fields) == len(cons_value.fields)
206206
tvm.testing.assert_allclose(res_cons.fields[0].asnumpy(),
207207
cons_value.fields[0].asnumpy())
208208
assert isinstance(res_cons.fields[1], ConstructorValue)
209-
assert res_cons.fields[1].constructor == prelude.nil
209+
assert res_cons.fields[1].tag == prelude.nil.tag
210210
assert len(res_cons.fields[1].fields) == 0
211211

212212
res_ref = id_func(ref_value)

0 commit comments

Comments
 (0)