@@ -201,21 +201,20 @@ def __init__(self, start=-1, end=-1):
201201 def visit_call (self , call ):
202202 """ Visit the children. """
203203 # First visit the children.
204- oshape = _get_tensor_shape (call )
205- odtype = _get_tensor_type (call )
206- input_types = [arg .checked_type for arg in call .args ]
207204 args = [self .visit (arg ) for arg in call .args ]
208205
209206 self .counter += 1
210207 if self .counter == self .start :
211208 ret = relay .Call (call .op , args , call .attrs )
212209 ret = relay .annotation .on_device (ret , self .ext_ctx )
213210 return ret
214- elif self .counter == self .end :
211+
212+ if self .counter == self .end :
215213 ret = relay .Call (call .op , args , call .attrs )
216214 ret = relay .annotation .on_device (ret , self .cpu_ctx )
217215 return ret
218- elif self .counter > self .start and self .counter < self .end :
216+
217+ if self .counter > self .start and self .counter < self .end :
219218 ret = relay .Call (call .op , args , call .attrs )
220219
221220 # skip the float op, i.e., float->int cast
@@ -234,11 +233,11 @@ def is_float_op(self, call):
234233 """
235234 args = call .args
236235 odtype = _get_tensor_type (call )
237- op = call .op
238236
239237 if odtype == "float32" :
240238 return True
241- elif op == self .cast :
239+
240+ if call .op == self .cast :
242241 idtype = _get_tensor_type (args [0 ])
243242 if idtype == "float32" :
244243 return True
@@ -566,7 +565,8 @@ def graph_pack(expr,
566565 """
567566 assert isinstance (expr , relay .Function )
568567 assert ((start_name != stop_name ) or (start_name_idx is None != stop_name_idx is None ) or \
569- (not (start_name_idx is None and stop_name_idx is None )) or (start_name_idx < stop_name_idx ))
568+ (not (start_name_idx is None and stop_name_idx is None )) \
569+ or (start_name_idx < stop_name_idx ))
570570 expr = get_subgraph (expr , start_name , stop_name , start_name_idx , stop_name_idx , count_meta )
571571 expr = run_opt_pass (expr , transform .InferType ())
572572 packer = ExprPack (
@@ -589,8 +589,6 @@ def graph_pack(expr,
589589
590590 device_annot = ExprDeviceAnnot (start = start , end = end )
591591 expr = device_annot .visit (expr )
592- ret = run_opt_pass (expr , transform .InferType ())
592+ return run_opt_pass (expr , transform .InferType ())
593593
594- return ret
595- else :
596- return expr
594+ return expr
0 commit comments