Skip to content

Commit 72a3f7f

Browse files
committed
add oneflow change
1 parent 64d2f94 commit 72a3f7f

File tree

1 file changed

+27
-31
lines changed

1 file changed

+27
-31
lines changed

python/tvm/relay/frontend/oneflow.py

Lines changed: 27 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -274,8 +274,8 @@ class Conv(OneFlowOpConverter):
274274
@classmethod
275275
def _impl_v1(cls, inputs, attrs, params):
276276
# The kernel is imported from model_dir_path, without the ".weight" logo, etc.
277-
# The data is obtained through the graph, its op contains "-input_"
278-
in_names = ["-input_"]
277+
# The data is obtained through the graph, its op contains "_input."
278+
in_names = ["_input."]
279279
kernel_names = [".weight"]
280280
for i in inputs:
281281
IN_NAMES = any(x in str(i) for x in in_names)
@@ -335,7 +335,7 @@ class ConvTranspose(OneFlowOpConverter):
335335

336336
@classmethod
337337
def _impl_v1(cls, inputs, attrs, params):
338-
in_names = ["-input_"]
338+
in_names = ["_input."]
339339
kernel_names = [".weight"]
340340
for i in inputs:
341341
IN_NAMES = any(x in str(i) for x in in_names)
@@ -470,7 +470,7 @@ def _impl_v1(cls, inputs, attrs, params):
470470
# sort the inputs
471471
sorted_inputs = copy.deepcopy(inputs)
472472
for i in inputs:
473-
IN_NAMES = "-input_" in str(i)
473+
IN_NAMES = "_input." in str(i)
474474
if IN_NAMES:
475475
sorted_inputs[0] = i
476476
elif "weight" in str(i) and not IN_NAMES:
@@ -521,7 +521,7 @@ def _impl_v1(cls, inputs, attrs, params):
521521
assert len(inputs) == 2, "Gemm op take 2 inputs, {} given".format(len(inputs))
522522
# Similar to 'class Conv'
523523
true_names = ["weight"]
524-
false_names = ["-input_"]
524+
false_names = ["_input."]
525525
for i in inputs:
526526
T_NAMES = any(x in str(i) for x in true_names)
527527
F_NAMES = any(x in str(i) for x in false_names)
@@ -607,7 +607,7 @@ def _impl_v1(cls, inputs, attrs, params):
607607
axis = int(attrs.get("axis", 0))
608608

609609
true_names = ["weight", "bias"]
610-
false_names = ["-input_"]
610+
false_names = ["_input."]
611611

612612
for i in inputs:
613613
T_NAMES = any(x in str(i) for x in true_names)
@@ -665,7 +665,7 @@ def _impl_v1(cls, inputs, attrs, params):
665665

666666
for i in inputs:
667667
T_NAMES = any([x in str(i) for x in beta_names])
668-
if T_NAMES and "-input_" not in str(i):
668+
if T_NAMES and "_input." not in str(i):
669669
input_b = i
670670
else:
671671
input_a = i
@@ -923,7 +923,7 @@ class PReLU(OneFlowOpConverter):
923923
def _impl_v1(cls, inputs, attrs, params):
924924
assert len(inputs) == 2, "PReLU need 2 inputs, but {} given".format(len(inputs))
925925
for i in inputs:
926-
if "-input_" in str(i):
926+
if "_input." in str(i):
927927
prelu_a = i
928928
else:
929929
prelu_b = i
@@ -1376,7 +1376,7 @@ def deal_with_input_convert(
13761376
if node_input not in _nodes:
13771377
if (
13781378
node_path not in _input_path_2_name
1379-
or "-input_" in node_input
1379+
or "_input." in node_input
13801380
or "FreeEagerTensor" in node_input
13811381
):
13821382
_nodes[node_input] = new_var(
@@ -1430,8 +1430,8 @@ class OneflowGraph(object):
14301430
node name:
14311431
1. param: m.layer4.1.bn1.weight / ...
14321432
2. buffer: m.layer4.1.bn1.running_mean / ...
1433-
3. node inputs: m.layer4.1.bn1-input_0
1434-
4. node outputs: m.layer4.1.bn1-output_0
1433+
3. node inputs: m.layer4.1.bn1_input.0
1434+
4. node outputs: m.layer4.1.bn1_output.0
14351435
"""
14361436

14371437
def __init__(self, shape, dtype, nodes, model_dir_path):
@@ -1521,16 +1521,19 @@ def _parse_input(self, node, model_dir_path):
15211521

15221522
def _parse_output(self, op_name, outputs, cnt_init=0):
15231523
"""
1524-
o: m.classifier.1-output_xxx
1524+
o: m.classifier.1_output.xxx
15251525
new_o: m.classifier.1-conv2d_0
1526-
"_"+new_o is in self._shape
1526+
"_"+new_o_xxx is in self._shape
15271527
"""
15281528
for o in outputs:
1529-
if "-output_" not in o:
1530-
new_o = o.replace("-" + op_name, "-output")
1531-
new_o = new_o.replace("_" + new_o.split("_")[-1], "_0")
1532-
self._shape[o] = self._shape["_" + new_o]
1533-
self._dtype[o] = self._dtype["_" + new_o]
1529+
if "_output." not in o:
1530+
new_o = o.replace("-" + op_name, "_output")
1531+
new_o = new_o.replace("-" + new_o.split("-")[-1], ".0")
1532+
for k in self._shape.keys():
1533+
if new_o in k:
1534+
self._shape[o] = self._shape[k]
1535+
self._dtype[o] = self._dtype[k]
1536+
break
15341537
elif len(outputs) > 1:
15351538
outputs.remove(o)
15361539
if op_name.lower() == "dropout":
@@ -1557,15 +1560,15 @@ def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=Non
15571560
If freeze_params is True,
15581561
the computational graph input is the input of the first layer of the network,
15591562
which cannot be specified by the user, e.g.
1560-
Default input is: %v_ResNetGraph_0-input_0: Tensor[(1, 3, 224, 224), float32]
1561-
User-defined input is: %_0-input_0: Tensor[(1, 3, 640, 480), float32]
1563+
Default input is: %v_ResNetGraph_0_input.0: Tensor[(1, 3, 224, 224), float32]
1564+
User-defined input is: %_0_input.0: Tensor[(1, 3, 640, 480), float32]
15621565
If freeze_params is on, then conv1-in will be the graph input, not Input_0
15631566
user_input: dict
15641567
User-defined input information for the graph
15651568
{
15661569
node1_name:
15671570
{
1568-
'name': node1_name, # str, like "%v_ResNetGraph_0-input_0"
1571+
'name': node1_name, # str, like "%v_ResNetGraph_0_input.0"
15691572
'shape': node1_shape, # tuple
15701573
'dtype': node1_dtype # str, like "float32"
15711574
}
@@ -1584,9 +1587,9 @@ def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=Non
15841587
# step 1: get the graph input
15851588
if not freeze_params:
15861589
for node_init_name in user_input:
1587-
if "-input_" not in node_init_name:
1590+
if "_input." not in node_init_name:
15881591
raise KeyError(
1589-
"user_input['name'] should contain '-input_' "
1592+
"user_input['name'] should contain '_input.' "
15901593
+ "to let program know that this is input node"
15911594
)
15921595
self._nodes[node_init_name] = new_var(
@@ -1693,19 +1696,12 @@ def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=Non
16931696
nodes = {v: k for k, v in self._nodes.items()}
16941697
free_vars = [nodes[var] for var in free_vars]
16951698

1696-
# step 6: make sure the '-input_0' is the first in self._inputs
1699+
# step 6: make sure the '_input.0' is the first in self._inputs
16971700
for free_var in free_vars:
16981701
if free_var not in self._inputs:
16991702
self._inputs[free_var] = self._nodes[free_var]
17001703

17011704
input_names = list(self._inputs.keys())
1702-
for i, _ in enumerate(input_names):
1703-
if i != 0 and "-input_0" in input_names[i]:
1704-
str_buffer = copy.deepcopy(input_names[i])
1705-
del input_names[i]
1706-
input_names.insert(0, str_buffer)
1707-
break
1708-
17091705
for input_name in input_names:
17101706
if input_name in self._inputs:
17111707
self._sort_inputs[input_name] = self._inputs[input_name]

0 commit comments

Comments
 (0)