@@ -938,8 +938,6 @@ def _expand_dims_0d_aware(data, attr, axis, num_newaxis=1):
     'Split'              : _split(False),
     'SplitV'             : _split(True),
     'Unpack'             : _unpack(),
-    'QueueDequeueManyV2' : _undef(),
-    'FIFOQueueV2'        : _undef(),
 }

 # _convert_map_rnn defines maps of rnn operator name to
@@ -1184,42 +1182,59 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
         if missing_operators:
             raise NotImplementedError( \
                 "The following operators are not implemented: {}".format(missing_operators))
+
         for node in graph.node:
             if node.op == 'Placeholder':
-                self._input_shapes[node.name] = tensor_util.TensorShapeProtoToList(node.attr['shape'].shape)
-                self._input_shapes[node.name][0] = 1
+                self._input_shapes[node.name] = \
+                    tensor_util.TensorShapeProtoToList(node.attr['shape'].shape)
+
+                if shape and node.name in shape:
+                    for idx, dim in enumerate(self._input_shapes[node.name]):
+                        if dim < 0:
+                            self._input_shapes[node.name][idx] = shape[node.name][idx]
+                    if not self._input_shapes[node.name]:
+                        self._input_shapes[node.name] = list(shape[node.name])
+                    assert self._input_shapes[node.name] == list(shape[node.name])
+                else:
+                    for idx, dim in enumerate(self._input_shapes[node.name]):
+                        if dim < 0:
+                            self._input_shapes[node.name][idx] = 1
+
             elif node.op == 'Const':
                 tensor_value = node.attr['value'].tensor
-                self._input_shapes[node.name] = tensor_util.TensorShapeProtoToList(tensor_value.tensor_shape)
+                self._input_shapes[node.name] = \
+                    tensor_util.TensorShapeProtoToList(tensor_value.tensor_shape)
+                if self._input_shapes[node.name] and shape and node.name in shape:
+                    assert self._input_shapes[node.name] == list(shape[node.name])

         final_op = None
         # Parse the nodes to re-create TF graph using Symbol API of NNVM
         for node in graph.node:
-            # Tensorflow doesn't have seperate list for params extraction.
+            # Tensorflow doesn't have separate list for params extraction.
             # Operator name 'Const' is treated as a parameter to build NNVM params dict.

             input_shapes = {}
             input_0d_mismatch = set()
             attr = self._parse_attr(node.attr)

-            #Variable converted to Const will not have only value attr
+            # Variable converted to Const will not have only value attr
             if 'value' in attr and node.op == 'Const':
                 self._output_shapes[node.name] = [self._input_shapes[node.name]]
+            elif shape and node.name in shape:
+                # Give priority to user argument.
+                self._output_shapes[node.name] = [shape[node.name]]
             elif node.op == 'Placeholder':
                 self._output_shapes[node.name] = [self._input_shapes[node.name]]
-            elif shape and node.name in shape:
-                # Give priority to user argument.
-                self._output_shapes[node.name] = [shape[node.name]]
             elif '_output_shapes' in attr:
                 self._output_shapes[node.name] = \
                     [tensor_util.TensorShapeProtoToList(tshape) \
                     for tshape in attr['_output_shapes']]
-            elif shape:
+            else:
                 # Keep the list indexable to avoid key error.
                 # Actual value will be filled after node creation.
+                # Will infer shapes if the graph is not frozen with add_shapes=True
                 self._output_shapes[node.name] = [None]
-            else:
-                self._output_shapes[node.name] = None
+
             self._outputs_are_0d[node.name] = [ \
                 not tshape if isinstance(tshape, list) else False \
                 for tshape in self._output_shapes[node.name]]
@@ -1241,7 +1256,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):

             else:
                 # Pass the parsed shapes instead
-                output_shapes = self._output_shapes[node.name]
+                attr["_output_shapes"] = output_shapes = self._output_shapes[node.name]

             # Pass the node name too in attr
             attr["_node_name"] = node.name
@@ -1291,20 +1306,27 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
                 # Assuming only one output.
                 self._nodes[node.name] = op
                 final_op = op
-                # Infer shapes if passed explicitely
-                node_output = self._nodes[node.name]
-                if shape:
-                    g = _graph.create(node_output)
-                    shape_dict = {k: v.shape for k, v in self._params.items()}
-                    shape_dict.update(shape)
-                    _, out_shapes = graph_util.infer_shape(g, **shape_dict)
-                    self._output_shapes[node.name] = out_shapes
-                elif output_shapes == None:
-                    g = _graph.create(node_output)
-                    self._output_shapes[node.name] = list(graph_util.infer_shape(g, **self._input_shapes))[-1]
+
+                # Infer shapes even without specifying "add_shapes=True"
+                if output_shapes == [None]:
+                    g = _graph.create(final_op)
+                    self._output_shapes[node.name] = \
+                        list(graph_util.infer_shape(g, **self._input_shapes))[-1]
                 else:
                     self._output_shapes[node.name] = output_shapes

+            if self._output_shapes[node.name] and shape and node.name in shape:
+                assert self._input_shapes[node.name] == list(shape[node.name])
+
+            # Infer shapes if passed explicitly
+            node_output = self._nodes[node.name]
+            if shape:
+                g = _graph.create(node_output)
+                shape_dict = {k: v.shape for k, v in self._params.items()}
+                shape_dict.update(shape)
+                _, out_shapes = graph_util.infer_shape(g, **shape_dict)
+                self._output_shapes[node.name] = out_shapes
+
         out = []
         if outputs is None:
             out.append(final_op)
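Usage context (not part of the commit): a minimal sketch of how the reworked shape handling above is exercised from user code, assuming the TF 1.x GraphDef API and a frozen graph at "model.pb" whose Placeholder is named "input" with a serialized shape of [-1, 224, 224, 3]. The file name, node name, and dimensions are illustrative, not taken from the diff.

# Minimal usage sketch; names and shapes below are assumptions for illustration.
import tensorflow as tf
from nnvm.frontend import from_tensorflow

graph_def = tf.GraphDef()
with open("model.pb", "rb") as f:
    graph_def.ParseFromString(f.read())

# With the change above, a -1 (unknown) dimension in the Placeholder shape is
# filled from this dict instead of being forced to 1; dimensions not covered by
# the dict still default to 1, and a mismatching user shape trips the added asserts.
shape_dict = {"input": (4, 224, 224, 3)}
sym, params = from_tensorflow(graph_def, layout="NHWC", shape=shape_dict)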