 check the attributes of the op and decide if it should be offloaded to DNNL.
 """
 import logging
-import math

 import tvm.ir
 from tvm import relay
@@ -46,10 +45,9 @@


 from ... import _ffi_api
-from ...dataflow_pattern import wildcard, is_op, is_constant, rewrite, DFPatternCallback, is_expr
+from ...dataflow_pattern import wildcard, is_op, rewrite, DFPatternCallback
 from .register import register_pattern_table

-import re

 logger = logging.getLogger("DNNL")

@@ -549,18 +547,29 @@ def visit_call(self, call):

 class DenseReshapeBiasGeluRewrite(DFPatternCallback):
     """
-    A callback to reorder reshape operators when the patten is as below:
-    1 %76 = nn.dense(%75, meta[relay.Constant][18] /* ty=Tensor[(512, 64), float32] */, units=None, out_dtype="float32") /* ty=Tensor[(3136, 512), float32] */;
+    A callback to reorder reshape operators when the patterns are as below:
+
+    Pattern #1:
+    1 %62 = nn.dense(%61, meta[relay.Constant][13] /* ty=Tensor[(64, 64), float32] */,
+        units=None, out_dtype="float32") /* ty=Tensor[(3136, 64), float32] */;
+    2 %63 = reshape(%62, newshape=[1, 3136, 64]) /* ty=Tensor[(1, 3136, 64), float32] */;
+    3 %64 = add(meta[relay.Constant][4] /* ty=Tensor[(64), float32] */, %63)
+        /* ty=Tensor[(1, 3136, 64), float32] */;
+
+    Pattern #2:
+    1 %76 = nn.dense(%75, meta[relay.Constant][18] /* ty=Tensor[(512, 64), float32] */,
+        units=None, out_dtype="float32") /* ty=Tensor[(3136, 512), float32] */;
     2 %77 = reshape(%76, newshape=[1, 3136, 512]) /* ty=Tensor[(1, 3136, 512), float32] */;
-    3 %78 = add(meta[relay.Constant][15] /* ty=Tensor[(512), float32] */, %77) /* ty=Tensor[(1, 3136, 512), float32] */;
+    3 %78 = add(meta[relay.Constant][15] /* ty=Tensor[(512), float32] */, %77)
+        /* ty=Tensor[(1, 3136, 512), float32] */;
     4 %79 = divide(%78, 1.41421f /* ty=float32 */) /* ty=Tensor[(1, 3136, 512), float32] */;
     5 %80 = erf(%79) /* ty=Tensor[(1, 3136, 512), float32] */;
     6 %81 = add(%80, 1f /* ty=float32 */) /* ty=Tensor[(1, 3136, 512), float32] */;
     7 %82 = multiply(%78, %81) /* ty=Tensor[(1, 3136, 512), float32] */;
     8 %83 = multiply(%82, 0.5f /* ty=float32 */) /* ty=Tensor[(1, 3136, 512), float32] */;
     """

-    def __init__(self):
+    def __init__(self, has_gelu=True):
         super(DenseReshapeBiasGeluRewrite, self).__init__()
         self.data = wildcard()
         self.weight = wildcard()
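Note: the divide/erf/add/multiply chain in the docstring above is the erf form of GELU; the constants that the pattern's const1/const2/const3 wildcards bind to are sqrt(2) (1.41421f), 1f, and 0.5f:

    Gelu(x) = 0.5 * x * (1.0 + erf(x / sqrt(2)))

So has_gelu=True matches dense -> reshape -> bias-add followed by this GELU tail, while has_gelu=False (see the following hunks) stops the pattern at the bias-add, absorbing the former DenseReshapeBiasRewrite class.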
@@ -570,29 +579,30 @@ def __init__(self):
         self.const3 = wildcard()

         self.attr_map = {}
+        self.has_gelu = has_gelu

         den = is_op("nn.dense")(self.data, self.weight)
         re_den = is_op("reshape")(den)
         added = is_op("add")(self.bias, re_den)
-        divisor = is_op("divide")(added, self.const1)
-        val_erf = is_op("erf")(divisor)
-        added_erf = is_op("add")(val_erf, self.const2)
-        mul1 = is_op("multiply")(added, added_erf)
-        mul2 = is_op("multiply")(mul1, self.const3)
-        self.pattern = mul2
+        if self.has_gelu:
+            divisor = is_op("divide")(added, self.const1)
+            val_erf = is_op("erf")(divisor)
+            added_erf = is_op("add")(val_erf, self.const2)
+            mul1 = is_op("multiply")(added, added_erf)
+            mul2 = is_op("multiply")(mul1, self.const3)
+            self.pattern = mul2
+        else:
+            self.pattern = added

     def get_attr(self, pre):
+        """Recursively retrieve attributes from reshape operator."""
+
         def visit_func(expr):
             if isinstance(expr, _expr.Call) and expr.op == relay.op.get("reshape"):
                 new_attrs = {}
                 for k in expr.attrs.keys():
                     new_attrs[k] = expr.attrs[k]
                 self.attr_map["reshape"] = new_attrs
-            elif isinstance(expr, _expr.Call) and expr.op == relay.op.get("nn.dense"):
-                new_attrs = {}
-                for k in expr.attrs.keys():
-                    new_attrs[k] = expr.attrs[k]
-                self.attr_map["nn.dense"] = new_attrs

         _analysis.post_order_visit(pre, visit_func)

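To illustrate the shape this constructor matches in the has_gelu=False case, here is a small standalone sketch (not part of the commit) using TVM's public dataflow-pattern API, with tensor shapes taken from Pattern #1 in the docstring:

    from tvm import relay
    from tvm.relay.dataflow_pattern import is_op, wildcard

    # Graph shaped like Pattern #1: dense -> reshape -> add(bias).
    x = relay.var("x", shape=(3136, 64), dtype="float32")
    w = relay.var("w", shape=(64, 64), dtype="float32")
    b = relay.var("b", shape=(64,), dtype="float32")
    expr = relay.add(b, relay.reshape(relay.nn.dense(x, w), newshape=[1, 3136, 64]))

    # The has_gelu=False pattern built above: add(bias, reshape(dense(data, weight))).
    pat = is_op("add")(wildcard(), is_op("reshape")(is_op("nn.dense")(wildcard(), wildcard())))
    assert pat.match(expr)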
@@ -602,12 +612,16 @@ def callback(self, pre, post, node_map):
         data = node_map[self.data][0]
         weight = node_map[self.weight][0]
         bias = node_map[self.bias][0]
+
+        den = relay.op.nn.dense(data, weight)
+        added = relay.op.add(bias, den)
+        if not self.has_gelu:
+            return relay.op.reshape(added, self.attr_map["reshape"]["newshape"])
+
         const1 = node_map[self.const1][0]
         const2 = node_map[self.const2][0]
         const3 = node_map[self.const3][0]
-
-        den = relay.op.nn.dense(data, weight)
-        added = relay.op.add(bias, den)
+
         divisor = relay.op.divide(added, const1)
         val_erf = relay.op.erf(divisor)
         added_erf = relay.op.add(val_erf, const2)
@@ -624,57 +638,9 @@ def rewrite_dense_bias_gelu_reshape_last(mod):
     return mod


-class DenseReshapeBiasRewrite(DFPatternCallback):
-    """
-    A callback to reorder reshape operators when the patten is as below:
-    1 %62 = nn.dense(%61, meta[relay.Constant][13] /* ty=Tensor[(64, 64), float32] */, units=None, out_dtype="float32") /* ty=Tensor[(3136, 64), float32] */;
-    2 %63 = reshape(%62, newshape=[1, 3136, 64]) /* ty=Tensor[(1, 3136, 64), float32] */;
-    3 %64 = add(meta[relay.Constant][4] /* ty=Tensor[(64), float32] */, %63) /* ty=Tensor[(1, 3136, 64), float32] */;
-    """
-
-    def __init__(self):
-        super(DenseReshapeBiasRewrite, self).__init__()
-        self.data = wildcard()
-        self.weight = wildcard()
-        self.bias = wildcard()
-
-        self.attr_map = {}
-
-        den = is_op("nn.dense")(self.data, self.weight)
-        re_den = is_op("reshape")(den)
-        added = is_op("add")(self.bias, re_den)
-        self.pattern = added
-
-    def get_attr(self, pre):
-        def visit_func(expr):
-            if isinstance(expr, _expr.Call) and expr.op == relay.op.get("reshape"):
-                new_attrs = {}
-                for k in expr.attrs.keys():
-                    new_attrs[k] = expr.attrs[k]
-                self.attr_map["reshape"] = new_attrs
-            elif isinstance(expr, _expr.Call) and expr.op == relay.op.get("nn.dense"):
-                new_attrs = {}
-                for k in expr.attrs.keys():
-                    new_attrs[k] = expr.attrs[k]
-                self.attr_map["nn.dense"] = new_attrs
-
-        _analysis.post_order_visit(pre, visit_func)
-
-    def callback(self, pre, post, node_map):
-        self.get_attr(pre)
-
-        data = node_map[self.data][0]
-        weight = node_map[self.weight][0]
-        bias = node_map[self.bias][0]
-
-        den = relay.op.nn.dense(data, weight)
-        added = relay.op.add(bias, den)
-        return relay.op.reshape(added, self.attr_map["reshape"]["newshape"])
-
-
 def rewrite_dense_bias_reshape_last(mod):
     """Rewrite the input graph to reorder reshape operators so that
     we can perform dense_bias fusion and then offload them to byoc part.
     """
-    mod["main"] = rewrite(DenseReshapeBiasRewrite(), mod["main"])
+    mod["main"] = rewrite(DenseReshapeBiasGeluRewrite(has_gelu=False), mod["main"])
     return mod
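For reference, a minimal end-to-end sketch (not part of the commit) of how the unified callback might be exercised, assuming a TVM build that contains this change so the helper is importable from tvm.relay.op.contrib.dnnl:

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.relay.op.contrib import dnnl

    # Build a typed module matching Pattern #1: dense -> reshape -> add(bias).
    data = relay.var("data", shape=(3136, 64), dtype="float32")
    weight = relay.const(np.random.rand(64, 64).astype("float32"))
    bias = relay.const(np.random.rand(64).astype("float32"))
    out = relay.add(bias, relay.reshape(relay.nn.dense(data, weight), newshape=[1, 3136, 64]))

    mod = tvm.IRModule.from_expr(relay.Function([data], out))
    mod = relay.transform.InferType()(mod)  # the rewrite expects typed input

    # Reorder the reshape past the bias-add so dense + add become adjacent and
    # can later be fused and offloaded to DNNL:
    # reshape(add(bias, dense(...)), newshape=[1, 3136, 64]).
    mod = dnnl.rewrite_dense_bias_reshape_last(mod)
    print(mod["main"])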