Skip to content

Commit db6915d

Browse files
author
cheng wen
committed
fix clip unsqueeze opset implement
1 parent 94e83f2 commit db6915d

File tree

2 files changed

+45
-2
lines changed

2 files changed

+45
-2
lines changed

python/tvm/relax/frontend/onnx/onnx_frontend.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ class Unsqueeze(OnnxOpConverter):
293293
"""Converts an onnx Unsqueeze node into an equivalent Relax expression."""
294294

295295
@classmethod
296-
def _impl_v11(cls, bb, inputs, attr, params):
296+
def _impl_v1(cls, bb, inputs, attr, params):
297297
axes = list(attr.get("axes"))
298298
inputs = inputs + [relax.const(axes, "int64")]
299299
return cls._impl_v13(bb, inputs, attr, params)
@@ -570,6 +570,15 @@ def _impl_v16(cls, bb, inputs, attr, params):
570570
class Clip(OnnxOpConverter):
571571
"""Converts an onnx Clip node into an equivalent Relax expression."""
572572

573+
@classmethod
574+
def _impl_v1(cls, bb, inputs, attr, params):
575+
min = float(attr.get("min", -_np.inf))
576+
max = float(attr.get("max", _np.inf))
577+
results = inputs[0]
578+
results = bb.emit_te(topi.maximum, results, min)
579+
results = bb.emit_te(topi.minimum, results, max)
580+
return results
581+
573582
@classmethod
574583
def _impl_v13(cls, bb, inputs, attr, params):
575584
results = inputs[0]

tests/python/relax/test_frontend_onnx.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,6 @@ def check_correctness(
148148
tvm_num_outputs = 1
149149

150150
# Check that number of outputs match.
151-
152151
assert tvm_num_outputs == len(ort_output), "Unequal number of outputs"
153152

154153
for (tvm_out, ort_out) in zip(tvm_output, ort_output):
@@ -435,6 +434,22 @@ def test_unsqueeze():
435434
check_correctness(model)
436435

437436

437+
def test_unsqueeze_v1():
438+
# https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Unsqueeze-1
439+
unsqueeze_node = helper.make_node("Unsqueeze", ["a"], ["b"], axes=[0, 2, 3])
440+
graph = helper.make_graph(
441+
[unsqueeze_node],
442+
"unsqueeze_v1",
443+
inputs=[helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32])],
444+
outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, [1, 32, 1, 1, 32])],
445+
)
446+
447+
model = helper.make_model(
448+
graph, producer_name="unsqueeze_v1_test", opset_imports=[helper.make_opsetid("", 6)]
449+
)
450+
check_correctness(model, opset=10)
451+
452+
438453
def test_gelu():
439454
verify_unary("Gelu", [32, 32], domain="com.microsoft")
440455

@@ -490,6 +505,25 @@ def test_clip(min, max):
490505
check_correctness(model)
491506

492507

508+
@pytest.mark.parametrize("min", [-6.0, 0.0])
509+
@pytest.mark.parametrize("max", [6.0])
510+
def test_clip_v6(max, min):
511+
# https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Clip-6
512+
clip_node = helper.make_node("Clip", ["input"], ["output"], max=max, min=min)
513+
inputs = [helper.make_tensor_value_info("input", TensorProto.FLOAT, [32, 64])]
514+
graph = helper.make_graph(
515+
[clip_node],
516+
"clip_v6_test",
517+
inputs=inputs,
518+
outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [32, 64])],
519+
)
520+
model = helper.make_model(
521+
graph, producer_name="clip_v6_test", opset_imports=[helper.make_opsetid("", 6)]
522+
)
523+
onnx.save(model, "a.onnx")
524+
check_correctness(model, opset=10)
525+
526+
493527
def test_equal():
494528
equal_node = helper.make_node("Equal", ["a", "b"], ["output"])
495529

0 commit comments

Comments
 (0)