@@ -274,8 +274,8 @@ class Conv(OneFlowOpConverter):
274274 @classmethod
275275 def _impl_v1 (cls , inputs , attrs , params ):
276276 # The kernel is imported from model_dir_path, without the ".weight" logo, etc.
277- # The data is obtained through the graph, its op contains "-input_ "
278- in_names = ["-input_ " ]
277+ # The data is obtained through the graph, its op contains "_input. "
278+ in_names = ["_input. " ]
279279 kernel_names = [".weight" ]
280280 for i in inputs :
281281 IN_NAMES = any (x in str (i ) for x in in_names )
@@ -335,7 +335,7 @@ class ConvTranspose(OneFlowOpConverter):
335335
336336 @classmethod
337337 def _impl_v1 (cls , inputs , attrs , params ):
338- in_names = ["-input_ " ]
338+ in_names = ["_input. " ]
339339 kernel_names = [".weight" ]
340340 for i in inputs :
341341 IN_NAMES = any (x in str (i ) for x in in_names )
@@ -470,7 +470,7 @@ def _impl_v1(cls, inputs, attrs, params):
470470 # sort the inputs
471471 sorted_inputs = copy .deepcopy (inputs )
472472 for i in inputs :
473- IN_NAMES = "-input_ " in str (i )
473+ IN_NAMES = "_input. " in str (i )
474474 if IN_NAMES :
475475 sorted_inputs [0 ] = i
476476 elif "weight" in str (i ) and not IN_NAMES :
@@ -521,7 +521,7 @@ def _impl_v1(cls, inputs, attrs, params):
521521 assert len (inputs ) == 2 , "Gemm op take 2 inputs, {} given" .format (len (inputs ))
522522 # Similar to 'class Conv'
523523 true_names = ["weight" ]
524- false_names = ["-input_ " ]
524+ false_names = ["_input. " ]
525525 for i in inputs :
526526 T_NAMES = any (x in str (i ) for x in true_names )
527527 F_NAMES = any (x in str (i ) for x in false_names )
@@ -607,7 +607,7 @@ def _impl_v1(cls, inputs, attrs, params):
607607 axis = int (attrs .get ("axis" , 0 ))
608608
609609 true_names = ["weight" , "bias" ]
610- false_names = ["-input_ " ]
610+ false_names = ["_input. " ]
611611
612612 for i in inputs :
613613 T_NAMES = any (x in str (i ) for x in true_names )
@@ -665,7 +665,7 @@ def _impl_v1(cls, inputs, attrs, params):
665665
666666 for i in inputs :
667667 T_NAMES = any ([x in str (i ) for x in beta_names ])
668- if T_NAMES and "-input_ " not in str (i ):
668+ if T_NAMES and "_input. " not in str (i ):
669669 input_b = i
670670 else :
671671 input_a = i
@@ -923,7 +923,7 @@ class PReLU(OneFlowOpConverter):
923923 def _impl_v1 (cls , inputs , attrs , params ):
924924 assert len (inputs ) == 2 , "PReLU need 2 inputs, but {} given" .format (len (inputs ))
925925 for i in inputs :
926- if "-input_ " in str (i ):
926+ if "_input. " in str (i ):
927927 prelu_a = i
928928 else :
929929 prelu_b = i
@@ -1376,7 +1376,7 @@ def deal_with_input_convert(
13761376 if node_input not in _nodes :
13771377 if (
13781378 node_path not in _input_path_2_name
1379- or "-input_ " in node_input
1379+ or "_input. " in node_input
13801380 or "FreeEagerTensor" in node_input
13811381 ):
13821382 _nodes [node_input ] = new_var (
@@ -1430,8 +1430,8 @@ class OneflowGraph(object):
14301430 node name:
14311431 1. param: m.layer4.1.bn1.weight / ...
14321432 2. buffer: m.layer4.1.bn1.running_mean / ...
1433- 3. node inputs: m.layer4.1.bn1-input_0
1434- 4. node outputs: m.layer4.1.bn1-output_0
1433+ 3. node inputs: m.layer4.1.bn1_input.0
1434+ 4. node outputs: m.layer4.1.bn1_output.0
14351435 """
14361436
14371437 def __init__ (self , shape , dtype , nodes , model_dir_path ):
@@ -1521,16 +1521,19 @@ def _parse_input(self, node, model_dir_path):
15211521
15221522 def _parse_output (self , op_name , outputs , cnt_init = 0 ):
15231523 """
1524- o: m.classifier.1-output_xxx
1524+ o: m.classifier.1_output.xxx
15251525 new_o: m.classifier.1-conv2d_0
1526- "_"+new_o is in self._shape
1526+ "_"+new_o_xxx is in self._shape
15271527 """
15281528 for o in outputs :
1529- if "-output_" not in o :
1530- new_o = o .replace ("-" + op_name , "-output" )
1531- new_o = new_o .replace ("_" + new_o .split ("_" )[- 1 ], "_0" )
1532- self ._shape [o ] = self ._shape ["_" + new_o ]
1533- self ._dtype [o ] = self ._dtype ["_" + new_o ]
1529+ if "_output." not in o :
1530+ new_o = o .replace ("-" + op_name , "_output" )
1531+ new_o = new_o .replace ("-" + new_o .split ("-" )[- 1 ], ".0" )
1532+ for k in self ._shape .keys ():
1533+ if new_o in k :
1534+ self ._shape [o ] = self ._shape [k ]
1535+ self ._dtype [o ] = self ._dtype [k ]
1536+ break
15341537 elif len (outputs ) > 1 :
15351538 outputs .remove (o )
15361539 if op_name .lower () == "dropout" :
@@ -1557,15 +1560,15 @@ def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=Non
15571560 If freeze_params is True,
15581561 the computational graph input is the input of the first layer of the network,
15591562 which cannot be specified by the user, e.g.
1560- Default input is: %v_ResNetGraph_0-input_0 : Tensor[(1, 3, 224, 224), float32]
1561- User-defined input is: %_0-input_0 : Tensor[(1, 3, 640, 480), float32]
1563+ Default input is: %v_ResNetGraph_0_input.0 : Tensor[(1, 3, 224, 224), float32]
1564+ User-defined input is: %_0_input.0 : Tensor[(1, 3, 640, 480), float32]
15621565 If freeze_params is on, then conv1-in will be the graph input, not Input_0
15631566 user_input: dict
15641567 User-defined input information for the graph
15651568 {
15661569 node1_name:
15671570 {
1568- 'name': node1_name, # str, like "%v_ResNetGraph_0-input_0 "
1571+ 'name': node1_name, # str, like "%v_ResNetGraph_0_input.0 "
15691572 'shape': node1_shape, # tuple
15701573 'dtype': node1_dtype # str, like "float32"
15711574 }
@@ -1584,9 +1587,9 @@ def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=Non
15841587 # step 1: get the graph input
15851588 if not freeze_params :
15861589 for node_init_name in user_input :
1587- if "-input_ " not in node_init_name :
1590+ if "_input. " not in node_init_name :
15881591 raise KeyError (
1589- "user_input['name'] should contain '-input_ ' "
1592+ "user_input['name'] should contain '_input. ' "
15901593 + "to let program know that this is input node"
15911594 )
15921595 self ._nodes [node_init_name ] = new_var (
@@ -1693,19 +1696,12 @@ def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=Non
16931696 nodes = {v : k for k , v in self ._nodes .items ()}
16941697 free_vars = [nodes [var ] for var in free_vars ]
16951698
1696- # step 6: make sure the '-input_0 ' is the first in self._inputs
1699+ # step 6: make sure the '_input.0 ' is the first in self._inputs
16971700 for free_var in free_vars :
16981701 if free_var not in self ._inputs :
16991702 self ._inputs [free_var ] = self ._nodes [free_var ]
17001703
17011704 input_names = list (self ._inputs .keys ())
1702- for i , _ in enumerate (input_names ):
1703- if i != 0 and "-input_0" in input_names [i ]:
1704- str_buffer = copy .deepcopy (input_names [i ])
1705- del input_names [i ]
1706- input_names .insert (0 , str_buffer )
1707- break
1708-
17091705 for input_name in input_names :
17101706 if input_name in self ._inputs :
17111707 self ._sort_inputs [input_name ] = self ._inputs [input_name ]
0 commit comments