Skip to content

Commit d8833bd

Browse files
vvchernovValery Chernov
andauthored
[ONNX] Support SequenceEmpty op (#13866)
* add SequenceEmpty * add SequenceEmpty test * pylint fix --------- Co-authored-by: Valery Chernov <[email protected]>
1 parent 0d5baac commit d8833bd

File tree

2 files changed

+42
-0
lines changed

2 files changed

+42
-0
lines changed

python/tvm/relay/frontend/onnx.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6148,6 +6148,15 @@ def _impl_v11(cls, inputs, attr, params):
61486148
return _expr.Tuple(inputs)
61496149

61506150

6151+
class SequenceEmpty(OnnxOpConverter):
6152+
"""Operator converter for sequence empty op."""
6153+
6154+
@classmethod
6155+
def _impl_v11(cls, inputs, attr, params):
6156+
# Construct an empty tuple.
6157+
return _expr.Tuple([])
6158+
6159+
61516160
class SequenceErase(OnnxOpConverter):
61526161
"""Operator converter for sequence erase op."""
61536162

@@ -6523,6 +6532,7 @@ def _get_convert_map(opset):
65236532
"LinearRegressor": LinearRegressor.get_converter(opset),
65246533
# Sequence operators
65256534
"SequenceConstruct": SequenceConstruct.get_converter(opset),
6535+
"SequenceEmpty": SequenceEmpty.get_converter(opset),
65266536
"SequenceErase": SequenceErase.get_converter(opset),
65276537
"SequenceInsert": SequenceInsert.get_converter(opset),
65286538
"SequenceLength": SequenceLength.get_converter(opset),

tests/python/frontend/onnx/test_forward.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7829,6 +7829,38 @@ def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=0, new_axis=
78297829
verify_sequence_ops((3, 3, 3, 3), 4, axis=2, new_axis=1)
78307830

78317831

7832+
@tvm.testing.parametrize_targets
7833+
def test_empty_sequence(target, dev):
7834+
"""test_empty_sequence"""
7835+
7836+
# Test creating an empty tensor sequence.
7837+
empty_node = helper.make_node(
7838+
"SequenceEmpty",
7839+
inputs=[],
7840+
outputs=["empty_sequence"],
7841+
)
7842+
7843+
length_node = helper.make_node("SequenceLength", inputs=["empty_sequence"], outputs=["output"])
7844+
7845+
graph_outputs = [helper.make_tensor_value_info("output", TensorProto.INT64, [])]
7846+
7847+
graph_nodes = [empty_node, length_node]
7848+
7849+
graph = helper.make_graph(
7850+
graph_nodes,
7851+
"Sequence_empty_test",
7852+
inputs=[],
7853+
outputs=graph_outputs,
7854+
)
7855+
7856+
model = helper.make_model(
7857+
graph,
7858+
producer_name="Sequence_empty_test",
7859+
)
7860+
7861+
verify_with_ort_with_inputs(model, [], target=target, dev=dev)
7862+
7863+
78327864
def test_exporting_node_renamed_model():
78337865
"""test exproting model when export_node_renamed_model is set"""
78347866

0 commit comments

Comments
 (0)