Skip to content

Commit 9f2c4d1

Browse files
committed
Get tags of saved model automatically
Remove exception trail in TF parser error message. Fix lint. Fix comments.
1 parent ec1000e commit 9f2c4d1

File tree

2 files changed

+74
-75
lines changed

2 files changed

+74
-75
lines changed

nnvm/python/nnvm/frontend/tensorflow.py

Lines changed: 47 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -938,8 +938,6 @@ def _expand_dims_0d_aware(data, attr, axis, num_newaxis=1):
938938
'Split' : _split(False),
939939
'SplitV' : _split(True),
940940
'Unpack' : _unpack(),
941-
'QueueDequeueManyV2' : _undef(),
942-
'FIFOQueueV2' : _undef(),
943941
}
944942

945943
# _convert_map_rnn defines maps of rnn operator name to
@@ -1184,42 +1182,59 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
11841182
if missing_operators:
11851183
raise NotImplementedError( \
11861184
"The following operators are not implemented: {}".format(missing_operators))
1185+
11871186
for node in graph.node:
11881187
if node.op == 'Placeholder':
1189-
self._input_shapes[node.name] = tensor_util.TensorShapeProtoToList(node.attr['shape'].shape)
1190-
self._input_shapes[node.name][0] = 1
1188+
self._input_shapes[node.name] = \
1189+
tensor_util.TensorShapeProtoToList(node.attr['shape'].shape)
1190+
1191+
if shape and node.name in shape:
1192+
for idx, dim in enumerate(self._input_shapes[node.name]):
1193+
if dim < 0:
1194+
self._input_shapes[node.name][idx] = shape[node.name][idx]
1195+
if not self._input_shapes[node.name]:
1196+
self._input_shapes[node.name] = list(shape[node.name])
1197+
assert self._input_shapes[node.name] == list(shape[node.name])
1198+
else:
1199+
for idx, dim in enumerate(self._input_shapes[node.name]):
1200+
if dim < 0:
1201+
self._input_shapes[node.name][idx] = 1
1202+
11911203
elif node.op == 'Const':
11921204
tensor_value = node.attr['value'].tensor
1193-
self._input_shapes[node.name] = tensor_util.TensorShapeProtoToList(tensor_value.tensor_shape)
1205+
self._input_shapes[node.name] = \
1206+
tensor_util.TensorShapeProtoToList(tensor_value.tensor_shape)
1207+
if self._input_shapes[node.name] and shape and node.name in shape:
1208+
assert self._input_shapes[node.name] == list(shape[node.name])
11941209

11951210
final_op = None
11961211
# Parse the nodes to re-create TF graph using Symbol API of NNVM
11971212
for node in graph.node:
1198-
# Tensorflow doesn't have seperate list for params extraction.
1213+
# Tensorflow doesn't have separate list for params extraction.
11991214
# Operator name 'Const' is treated as a parameter to build NNVM params dict.
12001215

12011216
input_shapes = {}
12021217
input_0d_mismatch = set()
12031218
attr = self._parse_attr(node.attr)
12041219

1205-
#Variable converted to Const will not have only value attr
1220+
# Variable converted to Const will not have only value attr
12061221
if 'value' in attr and node.op == 'Const':
12071222
self._output_shapes[node.name] = [self._input_shapes[node.name]]
1223+
elif shape and node.name in shape:
1224+
# Give priority to user argument.
1225+
self._output_shapes[node.name] = [shape[node.name]]
12081226
elif node.op == 'Placeholder':
12091227
self._output_shapes[node.name] = [self._input_shapes[node.name]]
1210-
elif shape and node.name in shape:
1211-
# Give priority to user argument.
1212-
self._output_shapes[node.name] = [shape[node.name]]
12131228
elif '_output_shapes' in attr:
12141229
self._output_shapes[node.name] = \
12151230
[tensor_util.TensorShapeProtoToList(tshape) \
12161231
for tshape in attr['_output_shapes']]
1217-
elif shape:
1232+
else:
12181233
# Keep the list indexable to avoid key error.
12191234
# Actual value will be filled after node creation.
1235+
# Will infer shapes if the graph is not frozen with add_shapes=True
12201236
self._output_shapes[node.name] = [None]
1221-
else:
1222-
self._output_shapes[node.name] = None
1237+
12231238
self._outputs_are_0d[node.name] = [ \
12241239
not tshape if isinstance(tshape, list) else False \
12251240
for tshape in self._output_shapes[node.name]]
@@ -1241,7 +1256,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
12411256

12421257
else:
12431258
# Pass the parsed shapes instead
1244-
output_shapes = self._output_shapes[node.name]
1259+
attr["_output_shapes"] = output_shapes = self._output_shapes[node.name]
12451260

12461261
# Pass the node name too in attr
12471262
attr["_node_name"] = node.name
@@ -1291,20 +1306,27 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
12911306
# Assuming only one output.
12921307
self._nodes[node.name] = op
12931308
final_op = op
1294-
# Infer shapes if passed explicitely
1295-
node_output = self._nodes[node.name]
1296-
if shape:
1297-
g = _graph.create(node_output)
1298-
shape_dict = {k: v.shape for k, v in self._params.items()}
1299-
shape_dict.update(shape)
1300-
_, out_shapes = graph_util.infer_shape(g, **shape_dict)
1301-
self._output_shapes[node.name] = out_shapes
1302-
elif output_shapes == None:
1303-
g = _graph.create(node_output)
1304-
self._output_shapes[node.name] = list(graph_util.infer_shape(g, **self._input_shapes))[-1]
1309+
1310+
# Infer shapes even without specifying "add_shapes=True"
1311+
if output_shapes == [None]:
1312+
g = _graph.create(final_op)
1313+
self._output_shapes[node.name] = \
1314+
list(graph_util.infer_shape(g, **self._input_shapes))[-1]
13051315
else:
13061316
self._output_shapes[node.name] = output_shapes
13071317

1318+
if self._output_shapes[node.name] and shape and node.name in shape:
1319+
assert self._input_shapes[node.name] == list(shape[node.name])
1320+
1321+
# Infer shapes if passed explicitely
1322+
node_output = self._nodes[node.name]
1323+
if shape:
1324+
g = _graph.create(node_output)
1325+
shape_dict = {k: v.shape for k, v in self._params.items()}
1326+
shape_dict.update(shape)
1327+
_, out_shapes = graph_util.infer_shape(g, **shape_dict)
1328+
self._output_shapes[node.name] = out_shapes
1329+
13081330
out = []
13091331
if outputs is None:
13101332
out.append(final_op)

nnvm/python/nnvm/frontend/util/tensorflow_parser.py

Lines changed: 27 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -2,32 +2,13 @@
22
from __future__ import absolute_import as _abs
33
from __future__ import print_function
44
import os
5-
6-
try:
7-
from tensorflow.core.framework import graph_pb2
8-
except ImportError as e:
9-
from nnvm.frontend.protobuf import graph_pb2
10-
11-
12-
try:
13-
from tempfile import TemporaryDirectory
14-
except ImportError:
15-
import tempfile
16-
import shutil
17-
18-
class TemporaryDirectory(object):
19-
def __enter__(self):
20-
self.name = tempfile.mkdtemp()
21-
return self.name
22-
23-
def __exit__(self, exc, value, tb):
24-
shutil.rmtree(self.name)
5+
from tensorflow.core.framework import graph_pb2
6+
from tvm.contrib import util
257

268

279
class TFParser(object):
2810
"""A Wrapper to handle tensorflow models parsing
29-
Works w/o installing tensorflow,
30-
Protocol Buffer is needed
11+
TensorFlow is needed
3112
```
3213
parser = TfParser(model_dir)
3314
graph = parser.parse()
@@ -39,7 +20,7 @@ class TFParser(object):
3920
"""
4021

4122
def __init__(self, model_dir):
42-
self._tmp_dir = TemporaryDirectory()
23+
self._tmp_dir = util.tempdir()
4324
self._model_dir = model_dir
4425
self._graph = graph_pb2.GraphDef()
4526

@@ -51,41 +32,37 @@ def _get_graph(self):
5132
"""Get Graph"""
5233
return self._graph
5334

54-
def _output_graph(self):
55-
import logging
56-
logging.basicConfig(level=logging.DEBUG)
57-
for node in self._get_graph().node:
58-
logging.info("Name: {}".format(node.name))
59-
logging.info("\top: {}".format(node.op))
60-
for input in node.input:
61-
logging.info("\t\tinput: {}".format(input))
62-
logging.info("\t\tdevice: {}".format(node.device))
63-
logging.info("\t\tAttrValue: ")
64-
for key in node.attr.keys():
65-
logging.info("\t\t\tkey: {} => value: {}"
66-
.format(key, node.attr[key]))
67-
logging.info(node.attr['shape'].shape)
68-
6935
def _load_pb_file(self):
7036
"""Load single pb file"""
7137
graph = self._get_graph()
7238
with open(self._model_dir, "rb") as f:
7339
graph.ParseFromString(f.read())
7440
return graph
7541

76-
def _get_output_names(self, model_path):
42+
def _get_tag_set(self):
43+
"""Return the tag set of saved model, multiple metagraphs are not supported"""
44+
try:
45+
from tensorflow.contrib.saved_model.python.saved_model import reader
46+
except ImportError:
47+
raise ImportError(
48+
"InputConfiguration: Unable to import saved_model.reader which is "
49+
"required to get tag set from saved model.")
50+
tag_sets = reader.get_saved_model_tag_sets(self._model_dir)
51+
return tag_sets[0]
52+
53+
def _get_output_names(self):
7754
"""Return the concatenated output names"""
7855
try:
7956
import tensorflow as tf
80-
except ImportError as e:
57+
except ImportError:
8158
raise ImportError(
8259
"InputConfiguration: Unable to import tensorflow which is "
83-
"required to restore from saved model. {}".format(e))
84-
60+
"required to restore from saved model.")
61+
tags = self._get_tag_set()
8562
with tf.Session() as sess:
8663
meta_graph_def = tf.saved_model.loader.load(sess,
87-
[tf.saved_model.tag_constants.SERVING],
88-
model_path)
64+
tags,
65+
self._model_dir)
8966
output_names = set()
9067
for k in meta_graph_def.signature_def.keys():
9168
outputs_tensor_info = meta_graph_def.signature_def[k].outputs
@@ -97,19 +74,18 @@ def _get_output_names(self, model_path):
9774
def _load_saved_model(self):
9875
"""Load the tensorflow saved model."""
9976
try:
100-
import tensorflow as tf
10177
from tensorflow.python.tools import freeze_graph
10278
from tensorflow.python.framework import ops
10379
from tensorflow.python.framework import graph_util
104-
except ImportError as e:
80+
except ImportError:
10581
raise ImportError(
10682
"InputConfiguration: Unable to import tensorflow which is "
107-
"required to restore from saved model. {}".format(e))
83+
"required to restore from saved model.")
10884

10985
saved_model_dir = self._model_dir
110-
output_graph_filename = os.path.join(self._tmp_dir.name, "neo_frozen_model.pb")
86+
output_graph_filename = self._tmp_dir.relpath("tf_frozen_model.pb")
11187
input_saved_model_dir = saved_model_dir
112-
output_node_names = self._get_output_names(self._model_dir)
88+
output_node_names = self._get_output_names()
11389

11490
input_binary = False
11591
input_saver_def_path = False
@@ -119,7 +95,7 @@ def _load_saved_model(self):
11995
input_meta_graph = False
12096
checkpoint_path = None
12197
input_graph_filename = None
122-
saved_model_tags = tf.saved_model.tag_constants.SERVING
98+
saved_model_tags = ",".join(self._get_tag_set())
12399

124100
freeze_graph.freeze_graph(input_graph_filename, input_saver_def_path,
125101
input_binary, checkpoint_path, output_node_names,
@@ -145,6 +121,7 @@ def parse(self):
145121
file.
146122
"""
147123
graph = None
124+
148125
if os.path.isdir(self._model_dir):
149126
ckpt = os.path.join(self._model_dir, "checkpoint")
150127
if not os.path.isfile(ckpt):

0 commit comments

Comments (0)