@@ -80,6 +80,8 @@ struct VMCompilerContext {
8080  ConstTensorShapeMap const_tensor_shape_map;
8181  //  List of lowered functions
8282  std::vector<LoweredFunc> lowered_funcs;
83+   //  The functions that have been lowered.
84+   std::unordered_map<LoweredFunc, size_t , NodeHash, NodeEqual> seen_funcs;
8385};
8486
8587//  Compute the constant pool, i.e a mapping from Constant node to constant index.
@@ -184,9 +186,6 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
184186  size_t  registers_num;
185187  CompileEngine engine;
186188
187-   /* ! \brief The functions that have been lowered. */ 
188-   std::unordered_map<LoweredFunc, size_t , NodeHash, NodeEqual> seen_funcs;
189- 
190189  /* ! \brief Global shared meta data */ 
191190  VMCompilerContext* context;
192191
@@ -260,7 +259,7 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
260259
261260  void  VisitExpr_ (const  MatchNode* match_node) {
262261    auto  match = GetRef<Match>(match_node);
263-     LOG (FATAL) << " translation of match nodes to the VM is" 
262+     LOG (FATAL) << " translation of match nodes to the VM is  " 
264263               << " currently unsupported" 
265264  }
266265
@@ -280,7 +279,8 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
280279  }
281280
282281  void  VisitExpr_ (const  GlobalVarNode* gvar) {
283-     LOG (FATAL) << " Global variables should only appear in the call position" 
282+     //  TODO(wweic): Support Load GlobalVar into a register
283+     LOG (WARNING) << " Loading GlobalVar into register is not yet supported" 
284284  }
285285
286286  void  VisitExpr_ (const  IfNode* if_node) {
@@ -405,12 +405,12 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
405405    //  TODO(jroesch): support lowered funcs for multiple targets
406406    CHECK_EQ (cfunc->funcs .size (), 1 );
407407    auto  op_index = -1 ;
408-     if  (seen_funcs.find (cfunc->funcs [0 ]) == seen_funcs.end ()) {
408+     if  (this -> context -> seen_funcs .find (cfunc->funcs [0 ]) == this -> context -> seen_funcs .end ()) {
409409      op_index = this ->context ->lowered_funcs .size ();
410410      this ->context ->lowered_funcs .push_back (cfunc->funcs [0 ]);
411-       seen_funcs[cfunc->funcs [0 ]] = op_index;
411+       this -> context -> seen_funcs [cfunc->funcs [0 ]] = op_index;
412412    } else  {
413-       op_index = seen_funcs[cfunc->funcs [0 ]];
413+       op_index = this -> context -> seen_funcs [cfunc->funcs [0 ]];
414414    }
415415
416416    Emit (Instruction::InvokePacked (op_index, arity, return_val_count, unpacked_arg_regs));
@@ -429,7 +429,6 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
429429    std::vector<Index> args_registers;
430430
431431    for  (auto  arg : call_node->args ) {
432-       CHECK (arg.as <VarNode>()) << " found: " AsText (arg, false ) << std::endl << arg;
433432      this ->VisitExpr (arg);
434433      args_registers.push_back (last_register);
435434    }
@@ -449,18 +448,14 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
449448      auto  func = this ->context ->module ->Lookup (global);
450449      if  (IsClosure (func)) {
451450        auto  arity = func->params .size ();
452-         std::vector<Index> free_var_registers;
453-         for  (size_t  i = 0 ; i < arity; ++i) {
454-           free_var_registers.push_back (var_register_map.at (func->params [i]));
455-         }
456-         Emit (Instruction::AllocClosure (it->second , arity, free_var_registers, NewRegister ()));
451+         Emit (Instruction::AllocClosure (it->second , arity, args_registers, NewRegister ()));
457452      } else  {
458453        Emit (Instruction::Invoke (it->second , args_registers, NewRegister ()));
459454      }
460455    } else  if  (auto  constructor_node = op.as <ConstructorNode>()) {
461456      auto  constructor = GetRef<Constructor>(constructor_node);
462-       auto   tag =  GetConstructorTag (constructor); 
463-       Emit ( Instruction::AllocDatatype (tag, call_node-> args . size (), args_registers,  NewRegister ()));
457+       Emit ( Instruction::AllocDatatype (constructor-> tag , call_node-> args . size (), args_registers, 
458+                                        NewRegister ()));
464459    } else  if  (auto  var_node = op.as <VarNode>()) {
465460      VisitExpr (GetRef<Var>(var_node));
466461      Emit (Instruction::InvokeClosure (last_register, args_registers, NewRegister ()));
@@ -469,18 +464,6 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
469464    }
470465  }
471466
472-   size_t  GetConstructorTag (tvm::relay::Constructor constructor) {
473-     auto  it = this ->context ->tag_map .find (constructor);
474-     if  (it != this ->context ->tag_map .end ()) {
475-       return  it->second ;
476-     } else  {
477-       auto  tag = this ->context ->tag_map .size ();
478-       this ->context ->tag_map [constructor] = tag;
479-       this ->context ->tag_index_map [tag] = constructor;
480-       return  tag;
481-     }
482-   }
483- 
484467  void  VisitExpr_ (const  FunctionNode* func_node) {
485468    if  (!func_node->IsPrimitive ()) {
486469      LOG (FATAL) << " local functions should have been removed by lambda lifting:" 
@@ -549,7 +532,7 @@ void PopulatePackedFuncMap(const std::vector<LoweredFunc>& lowered_funcs,
549532}
550533
551534VMFunction CompileFunc (VMCompilerContext* context, const  GlobalVar& var, const  Function& func) {
552-   DLOG (INFO) << " CompileFunc: " AsText (func, false ) << std::endl;
535+   DLOG (INFO) << " CompileFunc: " var <<  std::endl << AsText (func, false ) << std::endl;
553536  size_t  params = func->params .size ();
554537  VMCompiler compiler (context);
555538  compiler.Compile (func);
0 commit comments