@@ -93,7 +93,7 @@ def _weight_shape_match_transpose(data, dshape, channels, cfactor_out):
9393 if pad_width != 0 :
9494 pad_width = cfactor_out - pad_width
9595 data = op .nn .pad (data , [[0 , 0 ], [0 , pad_width ], [0 , 0 ], [0 , 0 ]])
96- dshape = tuple ([ dshape [0 ]] + [dshape [1 ] + pad_width , dshape [2 ], dshape [3 ]])
96+ dshape = tuple([dshape[0]] + [dshape[1] + pad_width, dshape[2], dshape[3]])
9797
9898 if channels_pad != 0 :
9999 channels = channels + (cfactor_out - channels_pad )
@@ -174,104 +174,6 @@ def _operator_idx_inc(expr, count_meta, operator_current_idx):
174174 operator_current_idx = operator_current_idx + 1
175175 return operator_current_idx
176176
177-
178- class ExprDeviceAnnot (ExprMutator ):
179- """Visitor to perform graph annotation on an AST.
180-
181- Parameters
182- ----------
183- start: int
184- the start location to mark run on vta (inclusive)
185- end: int
186- the end location to mark run on vta (exclusive)
187-
188- Returns
189- ---------
190- None
191- """
192- def __init__ (self , start = - 1 , end = - 1 ):
193- self .ext_ctx = tvm .context ("ext_dev" )
194- self .cpu_ctx = tvm .context ("cpu" )
195- self .cast = op .op .get ("cast" )
196- self .counter = - 1
197- self .start = start
198- self .end = end
199- super ().__init__ ()
200-
201- def visit_call (self , call ):
202- """ Visit the children. """
203- # First visit the children.
204- oshape = _get_tensor_shape (call )
205- odtype = _get_tensor_type (call )
206- input_types = [arg .checked_type for arg in call .args ]
207- args = [self .visit (arg ) for arg in call .args ]
208-
209- self .counter += 1
210- if self .counter == self .start :
211- ret = relay .Call (call .op , args , call .attrs )
212- ret = relay .annotation .on_device (ret , self .ext_ctx )
213- return ret
214- elif self .counter == self .end :
215- ret = relay .Call (call .op , args , call .attrs )
216- ret = relay .annotation .on_device (ret , self .cpu_ctx )
217- return ret
218- elif self .counter > self .start and self .counter < self .end :
219- ret = relay .Call (call .op , args , call .attrs )
220-
221- # skip the float op, i.e., float->int cast
222- if self .is_float_op (call ):
223- return ret
224-
225- return relay .annotation .on_device (ret , self .ext_ctx )
226-
227- return relay .Call (self .visit (call .op ), args , call .attrs )
228-
229- def is_float_op (self , call ):
230- """check if this op belongs to a float op
231- in general, float op's odtype is float;
232- a special case is float->int cast, which follow this op sequence:
233- multiply(float) -> round(float) -> clip(float) -> cast(int);
234- """
235- args = call .args
236- odtype = _get_tensor_type (call )
237- op = call .op
238-
239- if odtype == "float32" :
240- return True
241- elif op == self .cast :
242- idtype = _get_tensor_type (args [0 ])
243- if idtype == "float32" :
244- return True
245-
246- return False
247-
248-
249- class ExprLocater (ExprMutator ):
250- """Visitor to locate op on an AST.
251- """
252- def __init__ (self ):
253- self .counter = - 1
254- self .op2nodes = {}
255- super ().__init__ ()
256-
257- def visit_call (self , call ):
258- """ Visit the children. """
259- # First visit the children.
260- args = [self .visit (arg ) for arg in call .args ]
261-
262- odtype = _get_tensor_type (call )
263- self .counter += 1
264- if (call .op , odtype ) in self .op2nodes :
265- self .op2nodes [(call .op , odtype )].append (self .counter )
266- else :
267- self .op2nodes [(call .op , odtype )] = [self .counter ]
268-
269- return relay .Call (
270- self .visit (call .op ),
271- args ,
272- call .attrs )
273-
274-
275177class ExprPack (ExprMutator ):
276178 """Visitor to perform graph packing on an AST.
277179 """
@@ -415,7 +317,7 @@ def visit_call(self, call):
415317 elif self .start_pack and call .op == op .op .get ('cast' ) and \
416318 input_types [0 ].dtype == 'int32' :
417319 cast = relay .Call (op .op .get ('cast' ), [args [0 ]], call .attrs )
418- return cast
320+ return relay.Call(op.op.get('copy'), [cast])
419321 elif call .op == self .pad :
420322 pad_width = call .attrs .pad_width
421323 if len (pad_width ) == 6 :
@@ -510,10 +412,7 @@ def graph_pack(expr,
510412 stop_name = "nn.global_avg_pool2d" ,
511413 start_name_idx = None ,
512414 stop_name_idx = None ,
513- count_meta = False ,
514- device_annot = False ,
515- annot_start_name = "nn.conv2d" ,
516- annot_end_name = "annotation.stop_fusion" ):
415+ count_meta = False ):
517416 """Pack the graph into batch&channel packed format.
518417
519418 Parameters
@@ -550,47 +449,18 @@ def graph_pack(expr,
550449 'expr.astext(show_meta_data=False)'. When count_meta is True, the operator increase
551450 logic would count the meta.
552451
553- device_annot: boolean, optional
554- if we want to annoate the device_type
555-
556- annot_start_name: str, optional
557- device annotation start node, from which we mark the nodes as `ext_dev`
558-
559- annot_end_name: str, optional
560- device annotation end node, after which we mark the nodes as 'cpu'
561-
562452 Returns
563453 -------
564454 expr : Expr
565455 The transformed expression.
566456 """
567457 assert isinstance (expr , relay .Function )
568- assert ((start_name != stop_name ) or (start_name_idx is None != stop_name_idx is None ) or \
569- (not (start_name_idx is None and stop_name_idx is None )) or (start_name_idx < stop_name_idx ))
458+ assert (start_name != stop_name) or (start_name_idx is not None and stop_name_idx is not None and start_name_idx < stop_name_idx)
570459 expr = get_subgraph (expr , start_name , stop_name , start_name_idx , stop_name_idx , count_meta )
571460 expr = run_opt_pass (expr , transform .InferType ())
572461 packer = ExprPack (
573462 bfactor , cfactor ,
574463 weight_bits )
575464 expr = packer .visit (expr )
576465 assert not packer .start_pack
577- expr = run_opt_pass (expr , transform .InferType ())
578-
579- if device_annot :
580- expr_locator = ExprLocater ()
581- expr_locator .visit (expr )
582-
583- annot_start = op .op .get (annot_start_name )
584- start = expr_locator .op2nodes [(annot_start , "int32" )][0 ]
585-
586- annot_end = op .op .get (annot_end_name )
587- # we mark the next op to the last stop_fusion on cpu device
588- end = expr_locator .op2nodes [(annot_end , "int8" )][- 1 ] + 1
589-
590- device_annot = ExprDeviceAnnot (start = start , end = end )
591- expr = device_annot .visit (expr )
592- ret = run_opt_pass (expr , transform .InferType ())
593-
594- return ret
595- else :
596- return expr
466+ return run_opt_pass (expr , transform .InferType ())
0 commit comments