@@ -140,6 +140,9 @@ void ExprVisitor::VisitVarBinding(const VarBinding& binding) {
140140
141141void ExprVisitor::VisitMatchShape (const MatchShape& binding) {
142142 this ->VisitExpr (binding->value );
143+ // TODO(ziheng): should we change pattern from
144+ // Array<PrimExpr> to ShapeExpr?
145+ this ->VisitExpr (ShapeExpr (binding->pattern ));
143146}
144147
145148void ExprVisitor::VisitBindingBlock (const BindingBlock& block) {
@@ -321,50 +324,53 @@ Expr ExprMutator::VisitExpr_(const SeqExprNode* op) {
321324
322325Type ExprMutator::VisitType (const Type& t) { return t; }
323326
324- void ExprMutator::VisitBinding (const Binding& binding) {
327+ void ExprMutator::VisitBinding (const Binding& binding, IRBuilder& builder ) {
325328 Binding new_binding;
326329 if (binding.as <VarBindingNode>()) {
327- this ->VisitVarBinding (Downcast<VarBinding>(binding), this -> irbuilder_ );
330+ this ->VisitVarBinding (Downcast<VarBinding>(binding), builder );
328331 } else if (binding.as <MatchShapeNode>()) {
329- this ->VisitMatchShape (Downcast<MatchShape>(binding), this -> irbuilder_ );
332+ this ->VisitMatchShape (Downcast<MatchShape>(binding), builder );
330333 } else {
331334 LOG (FATAL) << " Wrong type." ;
332335 }
333336}
334337
335- Var ExprMutator::VisitVarBinding (const VarBinding& binding, IRBuilder& ir_builder ) {
338+ Var ExprMutator::VisitVarBinding (const VarBinding& binding, IRBuilder& builder ) {
336339 Expr new_value = this ->Mutate (binding->value );
337340 if (!binding->var .as <DataflowVarNode>()) {
338- return ir_builder ->EmitOutput (new_value);
341+ return builder ->EmitOutput (new_value);
339342 } else {
340- return ir_builder ->Emit (Downcast<Call>(new_value));
343+ return builder ->Emit (Downcast<Call>(new_value));
341344 }
342345}
343346
344- void ExprMutator::VisitMatchShape (const MatchShape& binding, IRBuilder& ir_builder ) {
347+ void ExprMutator::VisitMatchShape (const MatchShape& binding, IRBuilder& builder ) {
345348 this ->Mutate (binding->value );
349+ this ->Mutate (ShapeExpr (binding->pattern ));
346350}
347351
348352BindingBlock ExprMutator::VisitBindingBlock (const BindingBlock& block) {
349353 if (block.as <DataflowBlockNode>()) {
350354 return this ->VisitDataflowBlock (Downcast<DataflowBlock>(block));
351355 } else {
352- // TODO
353- return block;
356+ this ->builder_ = IRBuilderNode::Create ();
357+ for (auto binding : block->bindings ) {
358+ this ->VisitBinding (binding, this ->builder_ );
359+ }
360+ auto blocks = this ->builder_ ->GetBlocks ();
361+ return blocks.back ();
354362 }
355363}
356364
357365BindingBlock ExprMutator::VisitDataflowBlock (const DataflowBlock& block) {
358- this ->irbuilder_ = LazyIRBuilderNode::Create (block);
366+ this ->builder_ = LazyIRBuilderNode::Create (block);
359367 {
360- With<DataflowScope> scope (this ->irbuilder_ );
368+ With<DataflowScope> scope (this ->builder_ );
361369 for (auto binding : block->bindings ) {
362- if (binding.as <VarBindingNode>()) {
363- this ->VisitVarBinding (Downcast<VarBinding>(binding), this ->irbuilder_ );
364- }
370+ this ->VisitBinding (binding, this ->builder_ );
365371 }
366372 }
367- return this ->irbuilder_ ->GetBlocks ().back ();
373+ return this ->builder_ ->GetBlocks ().back ();
368374}
369375
370376Expr ExprMutator::VisitExpr (const Expr& expr) {
@@ -377,27 +383,27 @@ Expr ExprMutator::VisitExpr(const Expr& expr) {
377383// DataflowMutator
378384
379385BindingBlock DataflowMutator::VisitDataflowBlock (const DataflowBlock& block) {
380- this ->irbuilder_ = LazyIRBuilderNode::Create (block);
386+ this ->builder_ = LazyIRBuilderNode::Create (block);
381387 {
382- With<DataflowScope> scope (this ->irbuilder_ );
388+ With<DataflowScope> scope (this ->builder_ );
383389 for (auto binding : block->bindings ) {
384390 if (auto * var_binding = binding.as <VarBindingNode>()) {
385- Var var = this ->VisitVarBinding (Downcast<VarBinding>(binding), this ->irbuilder_ );
391+ Var var = this ->VisitVarBinding (Downcast<VarBinding>(binding), this ->builder_ );
386392 this ->pre_post_var_map_ [var_binding->var ] = var;
387393 }
388394 }
389395 }
390- return this ->irbuilder_ ->GetBlocks ().back ();
396+ return this ->builder_ ->GetBlocks ().back ();
391397}
392398
393- Var DataflowMutator::VisitVarBinding (const VarBinding& binding, IRBuilder& ir_builder ) {
399+ Var DataflowMutator::VisitVarBinding (const VarBinding& binding, IRBuilder& builder ) {
394400 Expr new_value = this ->Mutate (binding->value );
395401 Var new_var;
396402 if (new_value.as <CallNode>()) {
397- new_var = ir_builder ->Emit (Downcast<Call>(new_value));
403+ new_var = builder ->Emit (Downcast<Call>(new_value));
398404 }
399405 if (!binding->var .as <DataflowVarNode>()) {
400- new_var = ir_builder ->EmitOutput (new_value);
406+ new_var = builder ->EmitOutput (new_value);
401407 }
402408 pre_post_var_map_[binding->var ] = new_var;
403409 return new_var;
@@ -406,9 +412,9 @@ Var DataflowMutator::VisitVarBinding(const VarBinding& binding, IRBuilder& ir_bu
406412Expr DataflowMutator::LookupVar (Var var) {
407413 auto it = pre_post_var_map_.find (var);
408414 if (it != pre_post_var_map_.end ()) {
409- return irbuilder_ ->LookupVar (it->first );
415+ return builder_ ->LookupVar (it->first );
410416 } else {
411- return irbuilder_ ->LookupVar (var);
417+ return builder_ ->LookupVar (var);
412418 }
413419}
414420} // namespace relax
0 commit comments