@@ -73,6 +73,8 @@ struct VMCompilerContext {
7373 ConstTensorShapeMap const_tensor_shape_map;
7474 // List of lowered functions
7575 std::vector<LoweredFunc> lowered_funcs;
76+ // The functions that have been lowered.
77+ std::unordered_map<LoweredFunc, size_t , NodeHash, NodeEqual> seen_funcs;
7678};
7779
7880// Compute the constant pool, i.e a mapping from Constant node to constant index.
@@ -177,9 +179,6 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
177179 size_t registers_num;
178180 CompileEngine engine;
179181
180- /* ! \brief The functions that have been lowered. */
181- std::unordered_map<LoweredFunc, size_t , NodeHash, NodeEqual> seen_funcs;
182-
183182 /* ! \brief Global shared meta data */
184183 VMCompilerContext* context;
185184
@@ -253,7 +252,7 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
253252
254253 void VisitExpr_ (const MatchNode* match_node) {
255254 auto match = GetRef<Match>(match_node);
256- LOG (FATAL) << " translation of match nodes to the VM is"
255+ LOG (FATAL) << " translation of match nodes to the VM is "
257256 << " currently unsupported" << std::endl;
258257 }
259258
@@ -273,7 +272,8 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
273272 }
274273
275274 void VisitExpr_ (const GlobalVarNode* gvar) {
276- LOG (FATAL) << " Global variables should only appear in the call position" ;
275+ // TODO(wweic): Support Load GlobalVar into a register
276+ LOG (WARNING) << " Loading GlobalVar into register is not yet supported" ;
277277 }
278278
279279 void VisitExpr_ (const IfNode* if_node) {
@@ -370,12 +370,12 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
370370 // TODO(jroesch): support lowered funcs for multiple targets
371371 CHECK_EQ (cfunc->funcs .size (), 1 );
372372 auto op_index = -1 ;
373- if (seen_funcs.find (cfunc->funcs [0 ]) == seen_funcs.end ()) {
373+ if (this -> context -> seen_funcs .find (cfunc->funcs [0 ]) == this -> context -> seen_funcs .end ()) {
374374 op_index = this ->context ->lowered_funcs .size ();
375375 this ->context ->lowered_funcs .push_back (cfunc->funcs [0 ]);
376- seen_funcs[cfunc->funcs [0 ]] = op_index;
376+ this -> context -> seen_funcs [cfunc->funcs [0 ]] = op_index;
377377 } else {
378- op_index = seen_funcs[cfunc->funcs [0 ]];
378+ op_index = this -> context -> seen_funcs [cfunc->funcs [0 ]];
379379 }
380380
381381 // If Tensor, 1
@@ -396,7 +396,6 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
396396 std::vector<Index> args_registers;
397397
398398 for (auto arg : call_node->args ) {
399- CHECK (arg.as <VarNode>()) << " found: " << AsText (arg, false ) << std::endl << arg;
400399 this ->VisitExpr (arg);
401400 args_registers.push_back (last_register);
402401 }
@@ -416,18 +415,14 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
416415 auto func = this ->context ->module ->Lookup (global);
417416 if (IsClosure (func)) {
418417 auto arity = func->params .size ();
419- std::vector<Index> free_var_registers;
420- for (size_t i = 0 ; i < arity; ++i) {
421- free_var_registers.push_back (var_register_map.at (func->params [i]));
422- }
423- Emit (Instruction::AllocClosure (it->second , arity, free_var_registers, NewRegister ()));
418+ Emit (Instruction::AllocClosure (it->second , arity, args_registers, NewRegister ()));
424419 } else {
425420 Emit (Instruction::Invoke (it->second , args_registers, NewRegister ()));
426421 }
427422 } else if (auto constructor_node = op.as <ConstructorNode>()) {
428423 auto constructor = GetRef<Constructor>(constructor_node);
429- auto tag = GetConstructorTag (constructor);
430- Emit ( Instruction::AllocDatatype (tag, call_node-> args . size (), args_registers, NewRegister ()));
424+ Emit ( Instruction::AllocDatatype (constructor-> tag , call_node-> args . size (), args_registers,
425+ NewRegister ()));
431426 } else if (auto var_node = op.as <VarNode>()) {
432427 VisitExpr (GetRef<Var>(var_node));
433428 Emit (Instruction::InvokeClosure (last_register, args_registers, NewRegister ()));
@@ -436,18 +431,6 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
436431 }
437432 }
438433
439- size_t GetConstructorTag (tvm::relay::Constructor constructor) {
440- auto it = this ->context ->tag_map .find (constructor);
441- if (it != this ->context ->tag_map .end ()) {
442- return it->second ;
443- } else {
444- auto tag = this ->context ->tag_map .size ();
445- this ->context ->tag_map [constructor] = tag;
446- this ->context ->tag_index_map [tag] = constructor;
447- return tag;
448- }
449- }
450-
451434 void VisitExpr_ (const FunctionNode* func_node) {
452435 if (!func_node->IsPrimitive ()) {
453436 LOG (FATAL) << " local functions should have been removed by lambda lifting:" << std::endl
@@ -516,7 +499,7 @@ void PopulatePackedFuncMap(const std::vector<LoweredFunc>& lowered_funcs,
516499}
517500
518501VMFunction CompileFunc (VMCompilerContext* context, const GlobalVar& var, const Function& func) {
519- DLOG (INFO) << " CompileFunc: " << std::endl << AsText (func, false ) << std::endl;
502+ DLOG (INFO) << " CompileFunc: " << var << std::endl << AsText (func, false ) << std::endl;
520503 size_t params = func->params .size ();
521504 VMCompiler compiler (context);
522505 compiler.Compile (func);
0 commit comments