220 changes: 220 additions & 0 deletions python/tvm/relay/frontend/onnx.py
@@ -3195,6 +3195,225 @@ def _impl_v1(cls, inputs, attr, params):
return ret


class Scan(OnnxOpConverter):
"""Operator converter for Scan"""

@classmethod
def _impl_v8(cls, inputs, attr, params):
new_inputs = inputs[1:]
batch_num = infer_shape(inputs[1])[0]
out = []
for i in range(batch_num):
v9_inputs = [
_op.take(new_inputs[j], _expr.const(i), axis=0) for j in range(len(new_inputs))
]
results = cls._impl_v9(v9_inputs, attr, params)
results = [_op.expand_dims(results[j], axis=0) for j in range(len(results))]
if i == 0:
out = results
else:
out = [_op.concatenate([out[j], results[j]], axis=0) for j in range(len(results))]

out = _expr.TupleWrapper(_expr.Tuple(out), len(out))
return out
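
For reference, each per-batch call into _impl_v9 follows the standard ONNX Scan semantics, which can be sketched in NumPy (a minimal sketch, assuming a single state variable and a single scan input scanned forward along axis 0):

import numpy as np

def scan_reference(body_fn, init_state, xs):
    # body_fn maps (state, x_t) -> (new_state, y_t); xs is scanned along axis 0.
    state = init_state
    ys = []
    for t in range(xs.shape[0]):
        state, y = body_fn(state, xs[t])
        ys.append(y)
    # Return the final state plus the scan outputs stacked on a new leading axis.
    return state, np.stack(ys, axis=0)

# Example: a running sum; final == [4.] and ys == [[1.], [2.], [3.], [4.]].
final, ys = scan_reference(lambda s, x: (s + x, s + x), np.zeros(1), np.ones((4, 1)))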

@classmethod
def _impl_v9(cls, inputs, attr, params):
body = attr.get("body")
num_scan_inputs = attr.get("num_scan_inputs")
num_all_inputs = len(inputs)
num_state_inputs = len(body.input) - num_scan_inputs
num_state_outputs = num_state_inputs
num_all_outputs = len(body.output)
num_scan_outputs = num_all_outputs - num_state_outputs
scan_input_axes = attr.get("scan_input_axes", [0] * num_scan_inputs)
scan_input_directions = attr.get("scan_input_directions", [0] * num_scan_inputs)
scan_output_axes = list(attr.get("scan_output_axes", [0] * num_scan_outputs))
scan_output_directions = attr.get("scan_output_directions", [0] * num_scan_outputs)
# The loop count is the same for all scan inputs, so take it from the first scan input.
# strided_slice does not support dynamic axes, so assume input shapes are static.
max_loop_count = infer_shape(inputs[num_state_inputs])[scan_input_axes[0]]

# Create a copy of the body function to prevent the original
# from being modified.
body = copy.copy(attr["body"])

# Loop inputs will be packed as
# [iter_count, loop_deps, scan_outputs]
def cond_fn(*loop_inputs):
i = loop_inputs[0]
return _op.less(i, relay.const(max_loop_count, "int32"))

# Get the current graph proto and create a clone for the subgraph
graph_scope = GraphProto.current
subgraph_scope = GraphProto(
graph_scope._shape, graph_scope._dtype, graph_scope._freeze_params
)
# Load nodes from outer graph into inner graph.
subgraph_scope._nodes = graph_scope._nodes.copy()

# Create a list of variables for each value updated in the loop.
def get_var(name, val, scan=False):
checked_type = infer_type(val)
if hasattr(checked_type, "type_annotation"):
checked_type = checked_type.type_annotation
if hasattr(checked_type, "checked_type"):
checked_type = checked_type.checked_type
shape = get_const_tuple(checked_type.shape)
actual_shape = []
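# Dimensions of size 0 (e.g. an initially empty tensor) are treated as
# dynamic, since loop-carried values may change size across iterations.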
for dim in shape:
if isinstance(dim, int) and dim == 0:
actual_shape.append(_ty.Any())
else:
actual_shape.append(dim)
if scan:
return _expr.var(name, shape=[_ty.Any()] + actual_shape, dtype=checked_type.dtype)

return _expr.var(name, shape=actual_shape, dtype=checked_type.dtype)

# Construct variables and initial empty tensors for any scan outputs.
# To do this, we'll figure out the output shapes of the body subgraph by importing
# it and doing type inference.
scan_output_vars = []
scan_output_init = []
if num_scan_outputs > 0:
with subgraph_scope:
loop_outputs = subgraph_scope.from_onnx(
body, graph_scope.opset, get_output_expr=True
)
loop_outputs = _expr.TupleWrapper(loop_outputs, len(body.output))

for i in range(num_scan_outputs):
name, _, _, _ = get_info(body.output[i + num_state_outputs])
output_node = infer_type(loop_outputs[i + num_state_outputs])
shape = list(get_const_tuple(output_node.checked_type.shape))
if scan_output_axes[i] < 0:
scan_output_axes[i] = len(shape) + scan_output_axes[i] + 1
shape.insert(scan_output_axes[i], max_loop_count)
dtype = output_node.checked_type.dtype
scan_output_vars.append(_expr.var(name, shape=shape, dtype=dtype))
scan_output_init.append(_op.zeros(shape, dtype))

Contributor Author commented:
It seems there is a more elegant method (similar to the existing Loop operator implementation):
scan_output_vars.append(_expr.var(name, shape=[_ty.Any()] * len(shape), dtype=dtype))
scan_output_init.append(_op.reshape(_expr.const(np.array([]).astype(dtype)), [0] + [1] * (len(shape) - 1)))
With that, the following _op.strided_slice could be removed, since scan_output_init starts out empty.
But I did not manage to make it work; could you take a look?

# loop vars = [iter_count, scan_state, scan_out]
loop_vars = [
_expr.var("iter", shape=(), dtype="int32"), # iteration count
]
loop_vars += [
get_var(body.input[i].name, v) for i, v in enumerate(inputs) if i < num_state_inputs
]
loop_vars += scan_output_vars
body_input_var_names = ["iter"] + [body.input[i].name for i in range(len(body.input))]

# Now we can remove loop iter variables from our inner loop's inputs.
# This is kind of a hack since we have graph inputs that we don't
# want to treat as actual inputs.
while len(body.input) != 0:
body.input.pop(0)

# Define the loop body, in this function we need to unpack loop inputs,
# convert the loop subgraph, and pack outputs for the next iteration.
def body_fn(*loop_inputs):
# Unpack inputs
loop_count = loop_inputs[0]
state_vars = list(loop_inputs[1 : 1 + num_state_inputs])
scan_vars = list(loop_inputs[1 + num_state_inputs :])
# The body consumes one slice of each scan input per iteration, taken along
# its scan axis and honoring scan_input_directions.
input_scan_exprs = []
for i in range(num_state_inputs, num_all_inputs):
if scan_input_directions[i - num_state_inputs] != 0:
input_scan_exprs.append(
relay.take(
inputs[i],
relay.const(max_loop_count - 1, "int32") - loop_count,
axis=scan_input_axes[i - num_state_inputs],
)
)
else:
input_scan_exprs.append(
relay.take(
inputs[i],
loop_count,
axis=scan_input_axes[i - num_state_inputs],
)
)

# Prepare body inputs by adding them to node dictionary.
body_inputs = [loop_count] + state_vars + input_scan_exprs
for i, inp in enumerate(body_inputs):
subgraph_scope._nodes[body_input_var_names[i]] = inp

# Get the output of the current loop using the updated inputs.
with subgraph_scope:
loop_outputs = subgraph_scope.from_onnx(
body, graph_scope.opset, get_output_expr=True
)
# Unpack the body outputs and prepare variables for next iteration.
new_state_vars = [loop_outputs[i] for i in range(num_state_outputs)]
new_scan_vars = [loop_outputs[i] for i in range(num_state_outputs, num_all_outputs)]

# Add new scan outputs to tracking
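# Relay's while_loop requires loop variables to keep a consistent type, so
# the scan outputs were pre-allocated above as zero tensors of their final
# shape rather than grown dynamically. Each iteration concatenates the new
# slice on one end and strided_slices the result back to max_loop_count,
# shifting one zero out; e.g. with max_loop_count == 3 and forward direction:
# [0, 0, 0] -> [0, 0, y1] -> [0, y1, y2] -> [y1, y2, y3].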
combined_scan_outputs = []
for i in range(num_scan_outputs):
if scan_output_directions[i] == 0:
# append new scan output
combined_scan = _op.concatenate(
[scan_vars[i], _op.expand_dims(new_scan_vars[i], axis=scan_output_axes[i])],
axis=scan_output_axes[i],
)
# pop head scan output
combined_scan = _op.strided_slice(
combined_scan,
begin=[1],
end=[max_loop_count + 1],
strides=[1],
axes=[scan_output_axes[i]],
)
else:
# prepend new scan output
combined_scan = _op.concatenate(
[_op.expand_dims(new_scan_vars[i], axis=scan_output_axes[i]), scan_vars[i]],
axis=scan_output_axes[i],
)
# pop tail scan output
combined_scan = _op.strided_slice(
combined_scan,
begin=[0],
end=[max_loop_count],
strides=[1],
axes=[scan_output_axes[i]],
)
combined_scan_outputs.append(combined_scan)

incr = _expr.const(1, dtype="int32")
loop_count = loop_count + incr

# Pack loop outputs for next iteration
# [iter_count, state_var, scan_var]
return [loop_count] + new_state_vars + combined_scan_outputs

# Create the loop function.
loop = fold_constant(_loops.while_loop(cond_fn, loop_vars, body_fn))

# Now need to run initial values through the graph.
init_count = _expr.const(0, dtype="int32")

input_states = [inputs[i] for i in range(num_state_inputs)]
loop_vals = loop(init_count, *input_states, *scan_output_init)

outputs = _expr.TupleWrapper(
_expr.Tuple([_expr.TupleGetItem(loop_vals, i + 1) for i in range(num_all_outputs)]),
num_all_outputs,
)

# Update outer graph with constants found in the subgraph.
free_vars = analysis.free_vars(loop)
graph_scope._params.update(subgraph_scope._params)
graph_scope._nodes.update(subgraph_scope._nodes)
for var in free_vars:
graph_scope._nodes.update({var.name_hint: var})
return outputs
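
As a usage sketch (the file name, input names, and shapes below are hypothetical), a model containing Scan is imported like any other ONNX model once the converter is registered in _get_convert_map:

import onnx
from tvm import relay

# Hypothetical model with one state input and one scan input; the converter
# assumes static input shapes (see the strided_slice comment above).
onnx_model = onnx.load("model_with_scan.onnx")
shape_dict = {"init_state": (1, 2), "scan_in": (5, 1, 2)}
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)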


class NonMaxSuppression(OnnxOpConverter):
"""Operator converter for NonMaxSuppression."""

@@ -4512,6 +4731,7 @@ def _get_convert_map(opset):
"Adagrad": Adagrad.get_converter(opset),
"Adam": Adam.get_converter(opset),
"Momentum": Momentum.get_converter(opset),
"Scan": Scan.get_converter(opset),
}


1 change: 0 additions & 1 deletion src/target/target_kind.cc
@@ -402,7 +402,6 @@ TVM_REGISTER_TARGET_KIND("hybrid", kDLCPU) // line break

TVM_REGISTER_TARGET_KIND("composite", kDLCPU).add_attr_option<Array<Target>>("devices");


/********** Registry **********/

TVM_REGISTER_GLOBAL("target.ListTargetKinds").set_body_typed(TargetKindRegEntry::ListTargetKinds);