Skip to content

Commit f94942d

Browse files
jwfrommtrevor-m
authored andcommitted
Allow condition in if op to be an array. (apache#7215)
1 parent 9d5ee9c commit f94942d

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

python/tvm/relay/frontend/onnx.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2266,6 +2266,9 @@ class If(OnnxOpConverter):
22662266
@classmethod
22672267
def _impl_v1(cls, inputs, attr, params):
22682268
cond = inputs[0]
2269+
# Convert array to bool if needed.
2270+
if len(infer_shape(cond)) > 0:
2271+
cond = _op.take(cond, _expr.const(0, dtype="int64"))
22692272
then_branch = attr.get("then_branch", None)
22702273
else_branch = attr.get("else_branch", None)
22712274
assert then_branch is not None and else_branch is not None

tests/python/frontend/onnx/test_forward.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3969,8 +3969,7 @@ def test_loop():
39693969
verify_count_loop()
39703970

39713971

3972-
@tvm.testing.uses_gpu
3973-
def test_if():
3972+
def verify_if(cond_array):
39743973
# Given a bool scalar input cond.
39753974
# return constant tensor x if cond is True, otherwise return constant tensor y.
39763975
then_out = onnx.helper.make_tensor_value_info("then_out", onnx.TensorProto.FLOAT, [5])
@@ -4007,7 +4006,10 @@ def test_if():
40074006
)
40084007

40094008
if_model = onnx.helper.make_model(if_graph)
4010-
cond = np.array(1).astype("bool")
4009+
if cond_array:
4010+
cond = np.array([1]).astype("bool")
4011+
else:
4012+
cond = np.array(1).astype("bool")
40114013
correct_out = x if cond else y
40124014

40134015
for target, ctx in tvm.testing.enabled_targets():
@@ -4016,6 +4018,13 @@ def test_if():
40164018
tvm.testing.assert_allclose(correct_out[i], tvm_out[i], rtol=1e-05, atol=1e-05)
40174019

40184020

4021+
@tvm.testing.uses_gpu
4022+
def test_if():
4023+
# Confirm that if works with cond as an array or scalar.
4024+
verify_if(cond_array=False)
4025+
verify_if(cond_array=True)
4026+
4027+
40194028
@tvm.testing.uses_gpu
40204029
def test_size():
40214030
def verify_size(indata):

0 commit comments

Comments
 (0)