Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 39 additions & 34 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# pylint: disable=invalid-name, import-self, len-as-condition, unused-argument, too-many-lines
# pylint: disable=import-outside-toplevel
"""ONNX: Open Neural Network Exchange frontend for Relay."""
import copy
import warnings
import numpy as np
import tvm
Expand Down Expand Up @@ -1028,10 +1029,6 @@ def _impl_v9(cls, inputs, attr, params):
'Value {} in attribute "mode" of operator Upsample is not valid.'.format(mode)
)

if method == "nearest_neighbor":
align_corners = False
else:
align_corners = True
# in 3d case, we use the purely static op
if dims == 5:
if isinstance(scales, _expr.Call):
Expand Down Expand Up @@ -1065,7 +1062,7 @@ def _impl_v9(cls, inputs, attr, params):
scale_w,
layout=layout,
method=method,
align_corners=align_corners,
align_corners=False,
)
return out

Expand Down Expand Up @@ -1111,17 +1108,22 @@ class Split(OnnxOpConverter):

@classmethod
def _impl_v1(cls, inputs, attr, params):
splits = attr.get("split", False)
if splits:
splits = attr.get("split", None)
if splits is not None:
indices = []
attr["indices_or_sections"] = []
index = 0
for i in splits[:-1]:
index += i
attr["indices_or_sections"].append(index)
indices.append(index)
# When splits isnt specified divide evenly over axis.
else:
attr["indices_or_sections"] = attr["tvm_custom"]["num_outputs"]
return AttrCvt("split", ignores=["split"])(inputs, attr, params)
indices = attr["tvm_custom"]["num_outputs"]
output = _op.split(inputs[0], indices, attr.get("axis", 0))
# If the output of split is a single value, unpack if from the TupleWrapper
if len(output) == 1:
output = output[0]
return output


class Slice(OnnxOpConverter):
Expand Down Expand Up @@ -1227,7 +1229,9 @@ class GatherND(OnnxOpConverter):

@classmethod
def _impl_v1(cls, inputs, attr, params):
return _op.gather_nd(inputs[0], inputs[1])
indices_dims = len(infer_shape(inputs[1]))
indices = _op.transpose(inputs[1], axes=[-1] + list(range(indices_dims - 1)))
return _op.gather_nd(inputs[0], indices)


class Scatter(OnnxOpConverter):
Expand Down Expand Up @@ -1538,15 +1542,6 @@ def _impl_v1(cls, inputs, attr, params):
class Tile(Elemwise):
"""Operator converter for Tile"""

@classmethod
def _impl_v1(cls, inputs, attr, params):
if "repeats" not in attr:
raise tvm.error.OpAttributeInvalid(
'Attribute "repeats" should be set ' "for operator Tile."
)
reps = attr.pop("repeats") # The number of times repeating the tensor data.
return _op.tile(inputs[0], reps)

@classmethod
def _impl_v6(cls, inputs, attr, params):
return _op.tile(inputs[0], inputs[1])
Expand Down Expand Up @@ -2113,7 +2108,9 @@ def _impl_v11(cls, inputs, attr, params):
cond = inputs[1]
loop_deps = inputs[2:]
num_deps = len(loop_deps)
body = attr["body"]
# Create a copy of the body function to prevent the original
# from being modified.
body = copy.copy(attr["body"])
iter_dtype = infer_type(max_loop_count).checked_type.dtype

# Determine what condition mode we're in.
Expand Down Expand Up @@ -2150,6 +2147,8 @@ def get_var(name, val, scan=False):
checked_type = infer_type(val)
if hasattr(checked_type, "type_annotation"):
checked_type = checked_type.type_annotation
if hasattr(checked_type, "checked_type"):
checked_type = checked_type.checked_type
shape = get_const_tuple(checked_type.shape)
actual_shape = []
for dim in shape:
Expand Down Expand Up @@ -2185,8 +2184,14 @@ def get_var(name, val, scan=False):
scan_output_init = []
for i in range(num_scan_outputs):
name, shape, dtype, _ = get_info(body.output[i + 1 + num_deps])
scan_output_vars.append(_expr.var(name, shape=([_ty.Any()] + shape), dtype=dtype))
scan_output_init.append(_op.reshape(_expr.const([]), [0] + shape))
if dtype == "float":
dtype = "float32"
scan_output_vars.append(
_expr.var(name, shape=([_ty.Any()] * (len(shape) + 1)), dtype=dtype)
)
scan_output_init.append(
_op.reshape(_expr.const(np.array([]).astype(dtype)), [0] + [1] * len(shape))
)

# Now we can remove loop iter variables from our inner loop's inputs.
# This is kind of a hack since we have graph inputs that we don't
Expand Down Expand Up @@ -2219,18 +2224,18 @@ def body_fn(*loop_inputs):
new_loop_vars = [loop_outputs[i] for i in range(1, 1 + num_deps)]
new_scan_outputs = [loop_outputs[i] for i in range(1 + num_deps, len(loop_outputs))]

# Increment counter.
if max_loop_count is not None:
incr = _expr.const(1, dtype=iter_dtype)
loop_count = loop_count + incr

# Add new scan outputs to tracking
combined_scan_outputs = []
for i, scan in enumerate(scan_outputs):
new_scan = _op.expand_dims(new_scan_outputs[i], axis=0)
combined_scan = _op.concatenate([scan, new_scan], axis=0)
combined_scan_outputs.append(combined_scan)

# Increment counter.
if max_loop_count is not None:
incr = _expr.const(1, dtype=iter_dtype)
loop_count = loop_count + incr

# Pack loop outputs for next iteration
# [iter_count, cond, loop_deps, loop_scans]
return [loop_count, max_count, new_cond] + new_loop_vars + combined_scan_outputs
Expand Down Expand Up @@ -2630,12 +2635,12 @@ def _get_convert_map(opset):
"Greater": Greater.get_converter(opset),
"Less": Less.get_converter(opset),
"Log": Renamer("log"),
"ACos": Renamer("acos"),
"ACosh": Renamer("acosh"),
"ASin": Renamer("asin"),
"ASinh": Renamer("asinh"),
"ATan": Renamer("atan"),
"ATanh": Renamer("atanh"),
"Acos": Renamer("acos"),
"Acosh": Renamer("acosh"),
"Asin": Renamer("asin"),
"Asinh": Renamer("asinh"),
"Atan": Renamer("atan"),
"Atanh": Renamer("atanh"),
"Cos": Renamer("cos"),
"Cosh": Renamer("cosh"),
"Sin": Renamer("sin"),
Expand Down
Loading