Skip to content

Commit 0033d66

Browse files
Xingyu Zhouylc
authored andcommitted
[Frontend, Tensorflow2] Added support for TensorList ops (apache#8454)
1 parent e110f00 commit 0033d66

File tree

5 files changed

+583
-5
lines changed

5 files changed

+583
-5
lines changed

python/tvm/relay/frontend/tensorflow2.py

Lines changed: 201 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17-
# pylint: disable=invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except
17+
# pylint: disable=invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except, too-many-nested-blocks
1818
"""Tensorflow2.x graph to relay converter.
1919
2020
If model is constructed using tf2.x API, then use this converter:
@@ -38,12 +38,20 @@
3838
from .common import infer_type as _infer_type
3939

4040
from .tensorflow_ops import _convert_map as _convert_map_common
41-
from .tensorflow_ops import _need_prelude_for_shape_inference
41+
from .tensorflow_ops import _get_more_static_shape_rank
42+
from .tensorflow2_ops import _convert_map as _convert_map_tf2
43+
from .tensorflow2_ops import _need_prelude_for_shape_inference
4244

4345
from ..ty import Any
4446

4547
__all__ = ["from_tensorflow"]
4648

49+
# A map to record tensor list write ops and input tl/tensor indices
50+
# Value is (index of tensor list, index of written node)
51+
_tensor_list_write_ops = {
52+
"TensorListSetItem": (0, 2),
53+
}
54+
4755

4856
def _infer_type_with_prelude(val, prelude):
4957
body = _infer_type(val, prelude.mod)
@@ -66,6 +74,11 @@ def set_span(sym, node_name):
6674
return sym
6775

6876

77+
def is_tensor_list_constuctor(tf_node):
78+
"""Check whether is tensor list constructor node."""
79+
return tf_node.op == "TensorListReserve"
80+
81+
6982
def convert_const_node(node, shape):
7083
"""convert tf const node into relay const or var"""
7184

@@ -196,6 +209,10 @@ def __init__(self, module):
196209
self._output_shapes = {}
197210
self._tf_node_map = {}
198211
self._gdef_lib = {}
212+
self._tensor_list_shapes = {}
213+
self._tensor_list_shape_nodes = {}
214+
self._sub_map = {}
215+
self._sub_input_idx_map = {}
199216

200217
def from_tensorflow(
201218
self, graph, layout="NHWC", shape=None, outputs=None, input_types=None, gdef_lib=None
@@ -215,10 +232,134 @@ def from_tensorflow(
215232
)
216233
return func, self._params
217234

235+
def _analysis_tensor_list_op(
236+
self,
237+
graph,
238+
node,
239+
tl_write_nodes,
240+
tl_stack_nodes,
241+
tl_construct_nodes,
242+
sub_func_name="",
243+
root_node="",
244+
):
245+
if sub_func_name and sub_func_name not in self._sub_input_idx_map:
246+
self._sub_input_idx_map[sub_func_name] = {}
247+
248+
if node.op == "Placeholder":
249+
# record placeholder node in sub functions
250+
self._sub_map[sub_func_name] = node
251+
self._sub_input_idx_map[sub_func_name][node.name] = len(
252+
self._sub_input_idx_map[sub_func_name]
253+
)
254+
255+
if node.op.startswith("TensorList"):
256+
if is_tensor_list_constuctor(node):
257+
tl_construct_nodes.append(node)
258+
else:
259+
for tl_write_name, idx in _tensor_list_write_ops.items():
260+
if node.op.startswith(tl_write_name):
261+
tl_write_nodes.append((node, idx, sub_func_name, root_node))
262+
if node.op.startswith("TensorListStack"):
263+
tl_stack_nodes.append(node)
264+
elif node.op.startswith("StatelessWhile"):
265+
root_node = node.name
266+
cond_fn_name, body_fn_name = [
267+
parse_attr(node.attr).get(x).name for x in ["cond", "body"]
268+
]
269+
for fn_name in [cond_fn_name, body_fn_name]:
270+
subfunction = self._gdef_lib[fn_name]
271+
sub_func_name = fn_name
272+
for sub_node in subfunction.node:
273+
# bypass const node
274+
if sub_node.op == "Const":
275+
continue
276+
self._tf_node_map[sub_node.name] = sub_node
277+
self._analysis_tensor_list_op(
278+
subfunction,
279+
sub_node,
280+
tl_write_nodes,
281+
tl_stack_nodes,
282+
tl_construct_nodes,
283+
sub_func_name=sub_func_name,
284+
root_node=root_node,
285+
)
286+
287+
def _infer_static_shape_stack_node(self, tl_stack_nodes):
288+
for stack_node in tl_stack_nodes:
289+
if len(stack_node.input) < 2:
290+
# Stack node does not have shape
291+
continue
292+
input_shape_name = stack_node.input[1].split(":")[0]
293+
input_shape_node = self._tf_node_map[input_shape_name]
294+
stack = [self._tf_node_map[stack_node.input[0].split(":")[0]]]
295+
in_idx = -1
296+
while stack:
297+
cnode = stack.pop(0)
298+
if not cnode.op.startswith("TensorList"):
299+
if in_idx and cnode.op.startswith("StatelessWhile"):
300+
stack.append(self._tf_node_map[cnode.input[in_idx].split(":")[0]])
301+
else:
302+
for iname in cnode.input:
303+
if self._tf_node_map[iname.split(":")[0]].op.startswith(
304+
"StatelessWhile"
305+
):
306+
# identify input index based on output index
307+
if iname.split(":")[1]:
308+
in_idx = int(iname.split(":")[1])
309+
stack.append(self._tf_node_map[iname.split(":")[0]])
310+
# identify the corresponding constructor node and add shape to _tensor_list_shapes
311+
elif cnode.name != stack_node.name:
312+
if is_tensor_list_constuctor(cnode):
313+
shape_attr = parse_attr(input_shape_node.attr)
314+
if "value" not in shape_attr:
315+
continue
316+
raw_elem_shape = tensor_util.MakeNdarray(shape_attr["value"])
317+
elem_shape = []
318+
for dim in raw_elem_shape:
319+
if dim < 0:
320+
elem_shape.append(Any())
321+
else:
322+
elem_shape.append(int(dim))
323+
self._tensor_list_shapes[cnode.name] = elem_shape
324+
break
325+
326+
def _infer_static_shape_write_node(self, tl_write_nodes):
327+
for item in tl_write_nodes:
328+
wnode = item[0]
329+
ta_idx, inode_idx = item[1]
330+
sub_func_name = item[2]
331+
root_name = item[3]
332+
stack = [self._tf_node_map[wnode.input[ta_idx].split(":")[0]]]
333+
while stack:
334+
cnode = stack.pop(0)
335+
336+
if not cnode.op.startswith("TensorList"):
337+
if cnode.op == "Placeholder" and sub_func_name:
338+
# need to map subfunction
339+
input_idx = self._sub_input_idx_map[sub_func_name][cnode.name]
340+
stack.append(
341+
self._tf_node_map[
342+
self._tf_node_map[root_name].input[input_idx].split(":")[0]
343+
]
344+
)
345+
else:
346+
for iname in cnode.input:
347+
stack.append(self._tf_node_map[iname.split(":")[0]])
348+
# identify the corresponding constructor node and add it to _tensor_list_shape_nodes
349+
elif cnode.name != wnode.name:
350+
if is_tensor_list_constuctor(cnode):
351+
inode = self._tf_node_map[wnode.input[inode_idx].split(":")[0]]
352+
tn = wnode.input[inode_idx].split(":")
353+
output_index = int(tn[1]) if len(tn) > 1 else 0
354+
self._tensor_list_shape_nodes[cnode.name] = (inode, wnode.op, output_index)
355+
break
356+
218357
def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None, input_types=None):
219358
if input_types is None:
220359
input_types = {}
221-
360+
tl_write_nodes = []
361+
tl_stack_nodes = []
362+
tl_construct_nodes = []
222363
self._layout = layout
223364
for node in graph.node:
224365
name = node.name
@@ -235,6 +376,18 @@ def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None, input_
235376
self._nodes[node.name] = sym
236377
if param:
237378
self._params[node.name] = param
379+
# recursivly iterate tensorlist op if seen while loop
380+
else:
381+
self._analysis_tensor_list_op(
382+
graph, node, tl_write_nodes, tl_stack_nodes, tl_construct_nodes
383+
)
384+
385+
# Use tensor list stack to infer static tensor list shape
386+
self._infer_static_shape_stack_node(tl_stack_nodes)
387+
388+
# Fetch node contains static tensor list shape
389+
self._infer_static_shape_write_node(tl_write_nodes)
390+
238391
for node in graph.node:
239392
self._backtrack_construct(graph, node.name)
240393

@@ -321,16 +474,36 @@ def _convert_operator(self, graph, op_name, node_name, inputs, attrs):
321474
gdef_lib=self._gdef_lib,
322475
)
323476
elif op_name in _convert_map_common:
477+
# assert op are exclusive
478+
assert not set(_convert_map_common.keys()) & set(_convert_map_tf2.keys())
324479
if _need_prelude_for_shape_inference(op_name):
325480
sym = _convert_map_common[op_name](inputs, attrs, self._params, self._prelude)
326481
else:
327482
sym = _convert_map_common[op_name](inputs, attrs, self._params, self._module.mod)
483+
elif op_name in _convert_map_tf2:
484+
if _need_prelude_for_shape_inference(op_name):
485+
sym = _convert_map_tf2[op_name](inputs, attrs, self._params, self._prelude)
486+
else:
487+
sym = _convert_map_tf2[op_name](inputs, attrs, self._params, self._module.mod)
328488
else:
329489
raise NotImplementedError("Operator {} not implemented.".format(op_name))
330490

331491
sym = set_span(sym, node_name)
332492
return sym
333493

494+
def _parse_element_shape(self, elem_shape, shape_attr):
495+
if "value" in shape_attr:
496+
raw_elem_shape = tensor_util.MakeNdarray(shape_attr["value"])
497+
498+
if raw_elem_shape.size == 1 and raw_elem_shape == -1:
499+
elem_shape.append(Any())
500+
else:
501+
for dim in raw_elem_shape:
502+
if dim < 0:
503+
elem_shape.append(Any())
504+
else:
505+
elem_shape.append(dim)
506+
334507
def _backtrack_construct(self, graph, node_name):
335508
"""Convert a specific tensorflow node to relay expression.
336509
@@ -370,8 +543,8 @@ def _backtrack_construct(self, graph, node_name):
370543
CallNode(Op(add), [Var(x, ty=TensorType([], float32)), Constant(1.0)], (nullptr), [])
371544
372545
"""
373-
374546
input_op_name = node_name.split(":")[0].split("^")[-1]
547+
375548
if input_op_name not in self._nodes:
376549
node = self._tf_node_map[input_op_name]
377550
attr = parse_attr(node.attr)
@@ -386,8 +559,31 @@ def _backtrack_construct(self, graph, node_name):
386559
attr["_node_name"] = node.name
387560
attr["_target_layout"] = self._layout
388561
inputs = [self._backtrack_construct(graph, iname) for iname in node.input]
389-
op = self._convert_operator(graph, node.op, node.name, inputs, attr)
390562

563+
# infer shape for TensorList op
564+
if is_tensor_list_constuctor(node):
565+
input_shape_name = (
566+
node.input[1] if "TensorListFromTensor" in node.op else node.input[0]
567+
)
568+
input_shape_name = input_shape_name.split(":")[0]
569+
input_shape_node = self._tf_node_map[input_shape_name]
570+
shape_attr = parse_attr(input_shape_node.attr)
571+
elem_shape = []
572+
573+
self._parse_element_shape(elem_shape, shape_attr)
574+
575+
if elem_shape:
576+
attr["shape"] = elem_shape
577+
if (
578+
"identical_element_shapes" in attr and attr["identical_element_shapes"]
579+
) or elem_shape:
580+
shape = elem_shape
581+
if node.name in self._tensor_list_shapes:
582+
preset_shape = self._tensor_list_shapes[node.name]
583+
shape = _get_more_static_shape_rank(shape, preset_shape)
584+
attr["shape"] = shape
585+
586+
op = self._convert_operator(graph, node.op, node.name, inputs, attr)
391587
if isinstance(op, np.ndarray):
392588
self._params[node.name] = tvm.nd.array(op)
393589
op = [

0 commit comments

Comments
 (0)