 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except
+# pylint: disable=invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except, too-many-nested-blocks
 """Tensorflow2.x graph to relay converter.

 If the model is constructed using the tf2.x API, then use this converter:
 from .common import infer_type as _infer_type

 from .tensorflow_ops import _convert_map as _convert_map_common
-from .tensorflow_ops import _need_prelude_for_shape_inference
+from .tensorflow_ops import _get_more_static_shape_rank
+from .tensorflow2_ops import _convert_map as _convert_map_tf2
+from .tensorflow2_ops import _need_prelude_for_shape_inference

 from ..ty import Any

 __all__ = ["from_tensorflow"]

+# Maps each tensor list write op to the indices of its relevant inputs:
+# value is (index of the tensor list input, index of the written tensor input).
+_tensor_list_write_ops = {
+    "TensorListSetItem": (0, 2),
+}
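+# e.g. TensorListSetItem(input_handle, index, item) carries the list handle at
+# input 0 and the written tensor at input 2.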
+

 def _infer_type_with_prelude(val, prelude):
     body = _infer_type(val, prelude.mod)
@@ -66,6 +74,11 @@ def set_span(sym, node_name):
     return sym


+def is_tensor_list_constructor(tf_node):
+    """Check whether the given node constructs a tensor list."""
+    return tf_node.op == "TensorListReserve"
+
+
 def convert_const_node(node, shape):
     """convert tf const node into relay const or var"""

@@ -196,6 +209,10 @@ def __init__(self, module):
         self._output_shapes = {}
         self._tf_node_map = {}
         self._gdef_lib = {}
+        # constructor node name -> statically inferred tensor list element shape
+        self._tensor_list_shapes = {}
+        # constructor node name -> (written node, write op name, output index)
+        self._tensor_list_shape_nodes = {}
+        # sub-function name -> placeholder node seen in that sub-function
+        self._sub_map = {}
+        # sub-function name -> {placeholder node name: positional input index}
+        self._sub_input_idx_map = {}

     def from_tensorflow(
         self, graph, layout="NHWC", shape=None, outputs=None, input_types=None, gdef_lib=None
@@ -215,10 +232,134 @@ def from_tensorflow(
         )
         return func, self._params

+    def _analysis_tensor_list_op(
+        self,
+        graph,
+        node,
+        tl_write_nodes,
+        tl_stack_nodes,
+        tl_construct_nodes,
+        sub_func_name="",
+        root_node="",
+    ):
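+        """Recursively collect TensorList-related nodes from the graph.
+
+        Constructor nodes (TensorListReserve), write nodes (e.g. TensorListSetItem),
+        and TensorListStack nodes are gathered into the given lists. StatelessWhile
+        nodes are followed into their cond/body sub-functions so that TensorList
+        ops used inside while loops are recorded as well.
+        """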
+        if sub_func_name and sub_func_name not in self._sub_input_idx_map:
+            self._sub_input_idx_map[sub_func_name] = {}
+
+        if node.op == "Placeholder":
+            # record the placeholder node and its positional index in the sub-function
+            self._sub_map[sub_func_name] = node
+            self._sub_input_idx_map[sub_func_name][node.name] = len(
+                self._sub_input_idx_map[sub_func_name]
+            )
+
+        if node.op.startswith("TensorList"):
+            if is_tensor_list_constructor(node):
+                tl_construct_nodes.append(node)
+            else:
+                for tl_write_name, idx in _tensor_list_write_ops.items():
+                    if node.op.startswith(tl_write_name):
+                        tl_write_nodes.append((node, idx, sub_func_name, root_node))
+                if node.op.startswith("TensorListStack"):
+                    tl_stack_nodes.append(node)
+        elif node.op.startswith("StatelessWhile"):
+            root_node = node.name
+            cond_fn_name, body_fn_name = [
+                parse_attr(node.attr).get(x).name for x in ["cond", "body"]
+            ]
+            for fn_name in [cond_fn_name, body_fn_name]:
+                subfunction = self._gdef_lib[fn_name]
+                sub_func_name = fn_name
+                for sub_node in subfunction.node:
+                    # bypass const nodes
+                    if sub_node.op == "Const":
+                        continue
+                    self._tf_node_map[sub_node.name] = sub_node
+                    self._analysis_tensor_list_op(
+                        subfunction,
+                        sub_node,
+                        tl_write_nodes,
+                        tl_stack_nodes,
+                        tl_construct_nodes,
+                        sub_func_name=sub_func_name,
+                        root_node=root_node,
+                    )
+
+    def _infer_static_shape_stack_node(self, tl_stack_nodes):
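+        """Infer static tensor list element shapes from TensorListStack nodes.
+
+        For each stack node, walk backwards through the graph (following
+        StatelessWhile inputs by index) to the originating tensor list
+        constructor and record the element shape carried by the stack node's
+        element_shape input in self._tensor_list_shapes.
+        """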
+        for stack_node in tl_stack_nodes:
+            if len(stack_node.input) < 2:
+                # stack node does not carry an element shape
+                continue
+            input_shape_name = stack_node.input[1].split(":")[0]
+            input_shape_node = self._tf_node_map[input_shape_name]
+            stack = [self._tf_node_map[stack_node.input[0].split(":")[0]]]
+            in_idx = -1
+            while stack:
+                cnode = stack.pop(0)
+                if not cnode.op.startswith("TensorList"):
+                    if in_idx and cnode.op.startswith("StatelessWhile"):
+                        stack.append(self._tf_node_map[cnode.input[in_idx].split(":")[0]])
+                    else:
+                        for iname in cnode.input:
+                            if self._tf_node_map[iname.split(":")[0]].op.startswith(
+                                "StatelessWhile"
+                            ):
+                                # identify the input index based on the output index
+                                if iname.split(":")[1]:
+                                    in_idx = int(iname.split(":")[1])
+                            stack.append(self._tf_node_map[iname.split(":")[0]])
+                # identify the corresponding constructor node and add its shape
+                # to _tensor_list_shapes
+                elif cnode.name != stack_node.name:
+                    if is_tensor_list_constructor(cnode):
+                        shape_attr = parse_attr(input_shape_node.attr)
+                        if "value" not in shape_attr:
+                            continue
+                        raw_elem_shape = tensor_util.MakeNdarray(shape_attr["value"])
+                        elem_shape = []
+                        for dim in raw_elem_shape:
+                            if dim < 0:
+                                elem_shape.append(Any())
+                            else:
+                                elem_shape.append(int(dim))
+                        self._tensor_list_shapes[cnode.name] = elem_shape
+                        break
+
+    def _infer_static_shape_write_node(self, tl_write_nodes):
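+        """Associate each tensor list constructor with the node written into it.
+
+        For each recorded write node, walk backwards to the tensor list
+        constructor, translating sub-function placeholders back to the matching
+        StatelessWhile input, and store (written node, write op name, output
+        index) in self._tensor_list_shape_nodes.
+        """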
+        for item in tl_write_nodes:
+            wnode = item[0]
+            ta_idx, inode_idx = item[1]
+            sub_func_name = item[2]
+            root_name = item[3]
+            stack = [self._tf_node_map[wnode.input[ta_idx].split(":")[0]]]
+            while stack:
+                cnode = stack.pop(0)
+
+                if not cnode.op.startswith("TensorList"):
+                    if cnode.op == "Placeholder" and sub_func_name:
+                        # map the sub-function placeholder back to the while loop input
+                        input_idx = self._sub_input_idx_map[sub_func_name][cnode.name]
+                        stack.append(
+                            self._tf_node_map[
+                                self._tf_node_map[root_name].input[input_idx].split(":")[0]
+                            ]
+                        )
+                    else:
+                        for iname in cnode.input:
+                            stack.append(self._tf_node_map[iname.split(":")[0]])
+                # identify the corresponding constructor node and add it to
+                # _tensor_list_shape_nodes
+                elif cnode.name != wnode.name:
+                    if is_tensor_list_constructor(cnode):
+                        inode = self._tf_node_map[wnode.input[inode_idx].split(":")[0]]
+                        tn = wnode.input[inode_idx].split(":")
+                        output_index = int(tn[1]) if len(tn) > 1 else 0
+                        self._tensor_list_shape_nodes[cnode.name] = (inode, wnode.op, output_index)
+                        break
+
     def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None, input_types=None):
         if input_types is None:
             input_types = {}
-
+        tl_write_nodes = []
+        tl_stack_nodes = []
+        tl_construct_nodes = []
         self._layout = layout
         for node in graph.node:
             name = node.name
@@ -235,6 +376,18 @@ def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None, input_
                 self._nodes[node.name] = sym
                 if param:
                     self._params[node.name] = param
+            # recursively collect tensor list ops, descending into while loops
+            else:
+                self._analysis_tensor_list_op(
+                    graph, node, tl_write_nodes, tl_stack_nodes, tl_construct_nodes
+                )
+
+        # Use TensorListStack nodes to infer static tensor list element shapes
+        self._infer_static_shape_stack_node(tl_stack_nodes)
+
+        # Fetch the nodes that carry static tensor list shapes
+        self._infer_static_shape_write_node(tl_write_nodes)
+
         for node in graph.node:
             self._backtrack_construct(graph, node.name)

@@ -321,16 +474,36 @@ def _convert_operator(self, graph, op_name, node_name, inputs, attrs):
                 gdef_lib=self._gdef_lib,
             )
         elif op_name in _convert_map_common:
+            # the common and tf2-specific convert maps must not overlap
+            assert not set(_convert_map_common.keys()) & set(_convert_map_tf2.keys())
             if _need_prelude_for_shape_inference(op_name):
                 sym = _convert_map_common[op_name](inputs, attrs, self._params, self._prelude)
             else:
                 sym = _convert_map_common[op_name](inputs, attrs, self._params, self._module.mod)
+        elif op_name in _convert_map_tf2:
+            if _need_prelude_for_shape_inference(op_name):
+                sym = _convert_map_tf2[op_name](inputs, attrs, self._params, self._prelude)
+            else:
+                sym = _convert_map_tf2[op_name](inputs, attrs, self._params, self._module.mod)
         else:
             raise NotImplementedError("Operator {} not implemented.".format(op_name))

         sym = set_span(sym, node_name)
         return sym

+    def _parse_element_shape(self, elem_shape, shape_attr):
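+        """Append the element shape encoded in shape_attr to elem_shape.
+
+        A scalar -1 denotes a fully unknown shape and maps to a single Any();
+        otherwise each negative dimension becomes Any() and static dimensions
+        are kept as-is.
+        """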
495+ if "value" in shape_attr :
496+ raw_elem_shape = tensor_util .MakeNdarray (shape_attr ["value" ])
497+
498+ if raw_elem_shape .size == 1 and raw_elem_shape == - 1 :
499+ elem_shape .append (Any ())
500+ else :
501+ for dim in raw_elem_shape :
502+ if dim < 0 :
503+ elem_shape .append (Any ())
504+ else :
505+ elem_shape .append (dim )
506+
     def _backtrack_construct(self, graph, node_name):
         """Convert a specific tensorflow node to relay expression.

@@ -370,8 +543,8 @@ def _backtrack_construct(self, graph, node_name):
         CallNode(Op(add), [Var(x, ty=TensorType([], float32)), Constant(1.0)], (nullptr), [])

         """
-
         input_op_name = node_name.split(":")[0].split("^")[-1]
+
         if input_op_name not in self._nodes:
             node = self._tf_node_map[input_op_name]
             attr = parse_attr(node.attr)
@@ -386,8 +559,31 @@ def _backtrack_construct(self, graph, node_name):
             attr["_node_name"] = node.name
             attr["_target_layout"] = self._layout
             inputs = [self._backtrack_construct(graph, iname) for iname in node.input]
-            op = self._convert_operator(graph, node.op, node.name, inputs, attr)

+            # infer the element shape for tensor list constructor ops
+            if is_tensor_list_constructor(node):
+                input_shape_name = (
+                    node.input[1] if "TensorListFromTensor" in node.op else node.input[0]
+                )
+                input_shape_name = input_shape_name.split(":")[0]
+                input_shape_node = self._tf_node_map[input_shape_name]
+                shape_attr = parse_attr(input_shape_node.attr)
+                elem_shape = []
+
+                self._parse_element_shape(elem_shape, shape_attr)
+
+                if elem_shape:
+                    attr["shape"] = elem_shape
+                if (
+                    "identical_element_shapes" in attr and attr["identical_element_shapes"]
+                ) or elem_shape:
+                    # prefer the more static of the constructor's own element shape
+                    # and the shape recorded during tensor list analysis
+                    shape = elem_shape
+                    if node.name in self._tensor_list_shapes:
+                        preset_shape = self._tensor_list_shapes[node.name]
+                        shape = _get_more_static_shape_rank(shape, preset_shape)
+                    attr["shape"] = shape
+
+            op = self._convert_operator(graph, node.op, node.name, inputs, attr)
             if isinstance(op, np.ndarray):
                 self._params[node.name] = tvm.nd.array(op)
                 op = [