Skip to content

Commit 4560d10

Browse files
author
Matthew Brookhart
committed
add onnx resize v10 and unit test
1 parent 1831c17 commit 4560d10

File tree

2 files changed

+50
-3
lines changed

2 files changed

+50
-3
lines changed

python/tvm/relay/frontend/onnx.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1870,6 +1870,25 @@ def _impl_v7(cls, inputs, attr, params):
18701870
class Resize(OnnxOpConverter):
18711871
"""Operator converter for Resize"""
18721872

1873+
@classmethod
1874+
def _impl_v10(cls, inputs, attr, params):
1875+
mode = attr.get("mode")
1876+
if mode == b"nearest":
1877+
method = "nearest_neighbor"
1878+
elif mode == b"linear":
1879+
method = "bilinear"
1880+
else:
1881+
raise tvm.error.OpAttributeInvalid(
1882+
'Value {} in attribute "mode" of operator Resize is not valid.'.format(mode)
1883+
)
1884+
1885+
scale = inputs[1]
1886+
size = _op.cast(_op.shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale
1887+
1888+
layout = "NCHW" # ONNX assumes NCHW layout
1889+
out_size = _op.strided_slice(size, [2], [4])
1890+
return _op.image.resize(inputs[0], out_size, layout, method, "asymmetric")
1891+
18731892
@classmethod
18741893
def _impl_v11(cls, inputs, attr, params):
18751894
mode = attr.get("mode")
@@ -1891,9 +1910,7 @@ def _impl_v11(cls, inputs, attr, params):
18911910
size = inputs[3]
18921911
else:
18931912
assert len(scale_shape) != 0, "One of scale or size should be passed."
1894-
size = (
1895-
_op.cast(_op.shape_of(inputs[0]), infer_type(scale).type_annotation.dtype) * scale
1896-
)
1913+
size = _op.cast(_op.shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale
18971914

18981915
coord_trans = attr.get("coordinate_transformation_mode")
18991916
if coord_trans in [b"pytorch_half_pixel", b"half_pixel"]:

tests/python/frontend/onnx/test_forward.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3525,6 +3525,36 @@ def verify(ishape, oshape, scales, mode, coord_trans):
35253525
verify([1, 16, 32, 32], [], [1, 1, 2, 2], "nearest", "asymmetric")
35263526
verify([1, 16, 32, 32], [], [1, 1, 0.5, 0.5], "linear", "half_pixel")
35273527

3528+
def verify_opset_10(ishape, scales, mode):
3529+
nodes = [
3530+
make_constant_node("scales", onnx.TensorProto.FLOAT, (len(scales),), scales),
3531+
]
3532+
input_names = ["X", "scales"]
3533+
nodes.append(
3534+
helper.make_node(
3535+
"Resize",
3536+
inputs=input_names,
3537+
outputs=["Y"],
3538+
mode=mode,
3539+
)
3540+
)
3541+
3542+
oshape = [round(dim * scale) for (dim, scale) in zip(ishape, scales)]
3543+
graph = helper.make_graph(
3544+
nodes,
3545+
"resize_test",
3546+
inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, ishape)],
3547+
outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, oshape)],
3548+
)
3549+
3550+
model = helper.make_model(graph, producer_name="resize_test")
3551+
model.opset_import[0].version = 10
3552+
3553+
verify_with_ort(model, [ishape], oshape, use_vm=True, freeze_params=True)
3554+
3555+
verify_opset_10([1, 16, 32, 32], [1, 1, 2, 2], "nearest")
3556+
verify_opset_10([1, 16, 32, 32], [1, 1, 0.5, 0.5], "linear")
3557+
35283558

35293559
@tvm.testing.uses_gpu
35303560
def test_nonzero():

0 commit comments

Comments
 (0)