Skip to content

Commit 1ffec42

Browse files
soiferjMarisaKirisame
authored andcommitted
[Relay][Topi][TensorFlow][ONNX][Lang] Add support for Any op (apache#4205)
* Add support for Any op * Support ONNX frontend * Add doc * Add to relay docs * Dummy change to retrigger CI
1 parent ed38791 commit 1ffec42

File tree

17 files changed

+256
-4
lines changed

17 files changed

+256
-4
lines changed

docs/api/python/topi.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ List of operators
9191
topi.greater_equal
9292
topi.less_equal
9393
topi.all
94+
topi.any
9495
topi.logical_and
9596
topi.logical_or
9697
topi.logical_not
@@ -151,6 +152,7 @@ topi
151152
.. autofunction:: topi.full
152153
.. autofunction:: topi.full_like
153154
.. autofunction:: topi.all
155+
.. autofunction:: topi.any
154156
.. autofunction:: topi.max
155157
.. autofunction:: topi.sum
156158
.. autofunction:: topi.min

docs/frontend/tensorflow.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ Supported Ops
116116
- Abs
117117
- Add
118118
- All
119+
- Any
119120
- ArgMax
120121
- ArgMin
121122
- AvgPool

docs/langref/relay_op.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ This level enables additional math and transform operators.
137137
tvm.relay.less
138138
tvm.relay.less_equal
139139
tvm.relay.all
140+
tvm.relay.any
140141
tvm.relay.logical_and
141142
tvm.relay.logical_or
142143
tvm.relay.logical_not
@@ -300,6 +301,7 @@ Level 4 Definitions
300301
.. autofunction:: tvm.relay.less
301302
.. autofunction:: tvm.relay.less_equal
302303
.. autofunction:: tvm.relay.all
304+
.. autofunction:: tvm.relay.any
303305
.. autofunction:: tvm.relay.logical_and
304306
.. autofunction:: tvm.relay.logical_or
305307
.. autofunction:: tvm.relay.logical_not

include/tvm/expr_operator.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,13 @@ TVM_DLL Expr sum(Expr source, Array<IterVar> axis);
519519
*/
520520
TVM_DLL Expr all(Expr source, Array<IterVar> axis);
521521

522+
/*!
523+
* \brief logical Or of of source expression over axis
524+
* \param source The source expression.
525+
* \param axis List of iteration variables that will be used for reduction.
526+
*/
527+
TVM_DLL Expr any(Expr source, Array<IterVar> axis);
528+
522529
/*!
523530
* \brief max of of source expression over axis
524531
* \param source The source expression.

python/tvm/relay/frontend/onnx.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -989,6 +989,12 @@ class Where(OnnxOpConverter):
989989
def _impl_v9(cls, inputs, attr, params):
990990
return _op.where(inputs[0], inputs[1], inputs[2])
991991

992+
class Or(Elemwise):
993+
""" Operator converter for Or.
994+
"""
995+
@classmethod
996+
def _impl_v7(cls, inputs, attr, params):
997+
return _op.logical_or(inputs[0], inputs[1])
992998

993999
# compatible operators that do NOT require any conversion.
9941000
_identity_list = []
@@ -1111,7 +1117,8 @@ def _get_convert_map(opset):
11111117
'And': And.get_converter(opset),
11121118
'Tile': Tile.get_converter(opset),
11131119
'Erf': Erf.get_converter(opset),
1114-
'Where': Where.get_converter(opset)
1120+
'Where': Where.get_converter(opset),
1121+
'Or': Or.get_converter(opset)
11151122
}
11161123

11171124

python/tvm/relay/frontend/tensorflow.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1330,6 +1330,7 @@ def _impl(inputs, attr, params):
13301330
'Abs' : AttrCvt('abs'),
13311331
'Add' : _elemwise('add'),
13321332
'All' : _reduce('all'),
1333+
'Any' : _reduce('any'),
13331334
'ArgMax' : _argx(_op.argmax, 'argmax'),
13341335
'ArgMin' : _argx(_op.argmin, 'argmin'),
13351336
'Assert' : _assert(),

python/tvm/relay/op/_reduce.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def _schedule_reduce(_, outs, target):
3131
_reg.register_schedule("argmin", _schedule_reduce)
3232
_reg.register_schedule("sum", _schedule_reduce)
3333
_reg.register_schedule("all", _schedule_reduce)
34+
_reg.register_schedule("any", _schedule_reduce)
3435
_reg.register_schedule("max", _schedule_reduce)
3536
_reg.register_schedule("min", _schedule_reduce)
3637
_reg.register_schedule("prod", _schedule_reduce)

python/tvm/relay/op/reduce.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,58 @@ def all(data, axis=None, keepdims=False, exclude=False):
166166
return _make.all(data, axis, keepdims, exclude)
167167

168168

169+
def any(data, axis=None, keepdims=False, exclude=False):
170+
"""Computes the logical OR of boolean array elements over given axes.
171+
172+
Parameters
173+
----------
174+
data : relay.Expr
175+
The input boolean tensor
176+
177+
axis : None or int or tuple of int
178+
Axis or axes along which a sum is performed. The default, axis=None,
179+
will sum all of the elements of the input array. If axis is
180+
negative it counts from the last to the first axis.
181+
182+
keepdims : bool
183+
If this is set to True, the axes which are reduced are left in the result as
184+
dimensions with size one. With this option, the result will broadcast
185+
correctly against the input array.
186+
187+
exclude : bool
188+
If `exclude` is true, reduction will be performed on the axes that are
189+
NOT in axis instead.
190+
191+
Returns
192+
-------
193+
result : relay.Expr
194+
The computed result.
195+
196+
Examples
197+
--------
198+
.. code-block:: python
199+
200+
data = relay.Constant(tvm.nd.array([[[ True, True, True],
201+
[ True, True, True],
202+
[False, True, False]],
203+
[[ True, False, False],
204+
[ True, True, False],
205+
[False, True, True]]]))
206+
207+
relay.any(data, axis=1)
208+
# [[True, True, True],
209+
# [True, True, True]]
210+
211+
relay.any(data, axis=0)
212+
# [[ True, True, True],
213+
# [ True, True, True],
214+
# [False, True, True]]
215+
216+
"""
217+
axis = [axis] if isinstance(axis, int) else axis
218+
return _make.any(data, axis, keepdims, exclude)
219+
220+
169221
def max(data, axis=None, keepdims=False, exclude=False):
170222
""" Computes the max of array elements over given axes.
171223

src/lang/expr_operator.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,16 @@ Expr all(Expr source, Array<IterVar> rdom) {
486486
return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
487487
}
488488

489+
Expr any(Expr source, Array<IterVar> rdom) {
490+
CHECK(source.type().is_bool());
491+
Var x("x", source.type()), y("y", source.type());
492+
Expr result = ir::Or::make(x, y);
493+
Expr identity_element = make_const(source.type(), false);
494+
ir::CommReducer combiner =
495+
ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
496+
return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
497+
}
498+
489499
Expr max(Expr source, Array<IterVar> rdom) {
490500
Var x("x", source.type()), y("y", source.type());
491501
Expr result = ir::Max::make(x, y);

src/relay/op/tensor/reduce.cc

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,43 @@ Example::
420420
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
421421

422422

423+
Array<Tensor> AnyCompute(const Attrs& attrs,
424+
const Array<Tensor>& inputs,
425+
const Type& out_type,
426+
const Target& target) {
427+
return ReduceCompute(attrs, inputs, out_type, target, topi::any);
428+
}
429+
430+
431+
RELAY_REGISTER_REDUCE_OP("any")
432+
.describe(R"code(Computes the logical OR of boolean array elements over given axes.
433+
434+
Example::
435+
436+
data = [[[ True, True, True],
437+
[ True, True, True],
438+
[False, True, False]],
439+
[[ True, False, False],
440+
[ True, True, False],
441+
[False, True, True]]]
442+
443+
any(data, axis=1)
444+
[[True, True, True],
445+
[True, True, True]]
446+
447+
any(data, axis=0)
448+
[[ True, True, True],
449+
[ True, True, True],
450+
[False, True, True]]
451+
452+
)code" TVM_ADD_FILELINE)
453+
.set_attrs_type<ReduceAttrs>()
454+
.set_support_level(4)
455+
.add_type_rel("Reduce", ReduceRel)
456+
.set_attr<FTVMCompute>("FTVMCompute", AnyCompute)
457+
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
458+
459+
423460
Array<Tensor> MaxCompute(const Attrs& attrs,
424461
const Array<Tensor>& inputs,
425462
const Type& out_type,

0 commit comments

Comments
 (0)