Skip to content

Commit 4c6398b

Browse files
committed
Add op argwhere
1 parent 8eb3157 commit 4c6398b

File tree

12 files changed

+390
-2
lines changed

12 files changed

+390
-2
lines changed

include/tvm/relay/attrs/algorithm.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@
3131
namespace tvm {
3232
namespace relay {
3333

34+
/*! \brief Attributes for ArgWhere operator */
35+
struct ArgWhereAttrs : public tvm::AttrsNode<ArgWhereAttrs> {
36+
TVM_DECLARE_ATTRS(ArgWhereAttrs, "relay.attrs.ArgWhereAttrs") {
37+
}
38+
};
39+
3440
/*! \brief Attributes used in argsort operators */
3541
struct ArgsortAttrs : public tvm::AttrsNode<ArgsortAttrs> {
3642
int axis;

python/tvm/relay/op/_algorithm.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
"Definition of classic algorithms"
1818
# pylint: disable=invalid-name,unused-argument
1919
from __future__ import absolute_import
20+
import tvm
21+
from tvm.relay.ty import TensorType
2022

2123
import topi
2224
from topi.util import get_const_int
@@ -41,6 +43,26 @@ def compute_argsort(attrs, inputs, _, target):
4143

4244
register_pattern("argsort", OpPattern.OPAQUE)
4345

46+
# argwhere
47+
@register_schedule("argwhere")
48+
def schedule_argwhere(_, outs, target):
49+
"""Schedule definition of argwhere"""
50+
with target:
51+
return topi.generic.schedule_argwhere(outs)
52+
53+
54+
@register_compute("argwhere")
55+
def compute_argwhere(attrs, inputs, output_type, _):
56+
"""Compute definition of argwhere"""
57+
output_shape = []
58+
for s in output_type.shape:
59+
if hasattr(s, "value"):
60+
output_shape.append(s)
61+
else:
62+
# see Any, replace it with a var
63+
output_shape.append(tvm.var("any_dim", "int32"))
64+
new_output_type = TensorType(output_shape, "int32")
65+
return [topi.argwhere(new_output_type, inputs[0])]
4466

4567
@register_schedule("topk")
4668
def schedule_topk(_, outs, target):

python/tvm/relay/op/_transform.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""Backend compiler related feature registration"""
18-
# pylint: disable=invalid-name,unused-argument, len-as-condition
18+
# pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks
1919
from __future__ import absolute_import
2020
from topi.util import get_const_int, get_const_tuple
2121
from . import op as _reg
@@ -204,3 +204,68 @@ def take_shape_func(attrs, inputs, out_ndims):
204204
axis += data_ndim
205205
assert 0 <= axis < data_ndim
206206
return [_take_with_axis_shape_func(*inputs, convert(axis), out_ndims[0])]
207+
208+
@script
209+
def _argwhere_shape_func_2d(condition):
210+
out = output_tensor((2, ), "int64")
211+
out[0] = int64(0)
212+
out[1] = int64(2)
213+
for i1 in range(condition.shape[0]):
214+
for i2 in range(condition.shape[1]):
215+
if condition[i1, i2]:
216+
out[0] += int64(1)
217+
return out
218+
219+
@script
220+
def _argwhere_shape_func_3d(condition):
221+
out = output_tensor((2, ), "int64")
222+
out[0] = int64(0)
223+
out[1] = int64(3)
224+
for i1 in range(condition.shape[0]):
225+
for i2 in range(condition.shape[1]):
226+
for i3 in range(condition.shape[2]):
227+
if condition[i1, i2, i3]:
228+
out[0] += int64(1)
229+
return out
230+
231+
@script
232+
def _argwhere_shape_func_4d(condition):
233+
out = output_tensor((2, ), "int64")
234+
out[0] = int64(0)
235+
out[1] = int64(4)
236+
for i1 in range(condition.shape[0]):
237+
for i2 in range(condition.shape[1]):
238+
for i3 in range(condition.shape[2]):
239+
for i4 in range(condition.shape[3]):
240+
if condition[i1, i2, i3, i4]:
241+
out[0] += int64(1)
242+
return out
243+
244+
@script
245+
def _argwhere_shape_func_5d(condition):
246+
out = output_tensor((2, ), "int64")
247+
out[0] = int64(0)
248+
out[1] = int64(5)
249+
for i1 in range(condition.shape[0]):
250+
for i2 in range(condition.shape[1]):
251+
for i3 in range(condition.shape[2]):
252+
for i4 in range(condition.shape[3]):
253+
for i5 in range(condition.shape[4]):
254+
if condition[i1, i2, i3, i4, i5]:
255+
out[0] += int64(1)
256+
return out
257+
258+
@_reg.register_shape_func("argwhere", True)
259+
def argwhere_shape_func(attrs, inputs, out_ndims):
260+
"""
261+
Shape function for argwhere.
262+
"""
263+
if len(inputs[0].shape) == 2:
264+
return [_argwhere_shape_func_2d(inputs[0])]
265+
elif len(inputs[0].shape) == 3:
266+
return [_argwhere_shape_func_3d(inputs[0])]
267+
elif len(inputs[0].shape) == 4:
268+
return [_argwhere_shape_func_4d(inputs[0])]
269+
elif len(inputs[0].shape) == 5:
270+
return [_argwhere_shape_func_5d(inputs[0])]
271+
return []

python/tvm/relay/op/transform.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,8 @@ def squeeze(data, axis=None):
144144
"""
145145
return _make.squeeze(data, axis)
146146

147+
def argwhere(condition):
148+
return _make.argwhere(condition)
147149

148150
def reshape(data, newshape):
149151
"""Reshapes the input array.

src/relay/op/algorithm/argwhere.cc

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* Copyright (c) 2019 by Contributors
22+
* \file argwhere.cc
23+
* \brief Argwhere operators
24+
*/
25+
#include <tvm/relay/op.h>
26+
#include <tvm/relay/attrs/algorithm.h>
27+
#include <tvm/relay/op_attr_types.h>
28+
29+
namespace tvm {
30+
namespace relay {
31+
32+
// ArgWhere
33+
bool ArgWhereRel(const Array<Type>& types,
34+
int num_inputs,
35+
const Attrs& attrs,
36+
const TypeReporter& reporter) {
37+
CHECK_EQ(num_inputs, 1);
38+
auto tt = types[0].as<TensorTypeNode>();
39+
CHECK(tt != nullptr);
40+
const auto& input_shape = tt->shape;
41+
const auto& input_rank = input_shape.size();
42+
std::vector<IndexExpr> result_shape;
43+
result_shape.push_back(Any::make());
44+
result_shape.push_back(IntImm::make(Int(32), input_rank));
45+
reporter->Assign(types[1], TensorTypeNode::make(result_shape, Int(32)));
46+
return true;
47+
}
48+
49+
TVM_REGISTER_API("relay.op._make.argwhere")
50+
.set_body_typed<Expr(Expr)>([](Expr data) {
51+
static const Op& op = Op::Get("argwhere");
52+
auto attrs = make_node<ArgWhereAttrs>();
53+
return CallNode::make(op, {data}, Attrs(attrs), {});
54+
});
55+
56+
RELAY_REGISTER_OP("argwhere")
57+
.describe(R"doc(Find the indices of elements of a tensor that are
58+
non-zero)doc" TVM_ADD_FILELINE)
59+
.set_num_inputs(1)
60+
.set_attrs_type_key("relay.attrs.ArgWhereAttrs")
61+
.add_argument("condition", "Tensor", "The input condition tensor.")
62+
.add_type_rel("ArgWhere", ArgWhereRel)
63+
.set_attr<TOpIsStateful>("TOpIsStateful", false)
64+
.set_attr<TOpPattern>("TOpPattern", kOpaque)
65+
.set_support_level(10);
66+
67+
} // namespace relay
68+
} // namespace tvm

src/relay/op/tensor/unary.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,6 @@ RELAY_REGISTER_OP("shape_of")
326326
.set_support_level(10)
327327
.set_attr<FTVMCompute>("FTVMCompute", ShapeOfCompute);
328328

329-
330329
TVM_REGISTER_NODE_TYPE(NdarraySizeAttrs);
331330

332331
bool NdarraySizeRel(const Array<Type>& types,

tests/python/relay/test_any.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,25 @@ def test_any_reshape():
9292
verify_any_reshape(any_dims(3), (-4, 2, -1, -2), (6, 3, 4), (2, 3, 3, 4))
9393
verify_any_reshape(any_dims(3), (-4, -1, 2, -3), (6, 3, 4), (3, 2, 12))
9494

95+
def verify_any_argwhere(x_shape, x_np_shape, out_shape):
96+
x = relay.var('x', shape=x_shape, dtype="bool")
97+
y = relay.argwhere(x)
98+
mod = relay.module.Module()
99+
mod["main"] = relay.Function([x], y)
100+
data = np.random.choice([True, False], size=x_np_shape)
101+
for kind in ["debug", "vm"]:
102+
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
103+
result = ex.evaluate()(data).asnumpy()
104+
expected = np.argwhere(data)
105+
assert result.shape == expected.shape
106+
tvm.testing.assert_allclose(result.flatten(), expected.flatten())
107+
108+
def test_any_argwhere():
109+
verify_any_argwhere(any_dims(2), (5, 5), None)
110+
verify_any_argwhere(any_dims(3), (5, 5, 5), None)
111+
verify_any_argwhere(any_dims(4), (5, 5, 5, 5), None)
112+
verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5), None)
113+
95114
def verify_any_take(data_shape, indices_shape, axis, data_np_shape, indices_np_shape):
96115
mod = relay.Module()
97116
data = relay.var('data', shape=data_shape, dtype='float32')

topi/python/topi/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from .transform import *
2323
from .broadcast import *
2424
from .sort import *
25+
from .where import *
2526
from . import nn
2627
from . import x86
2728
from . import cuda

topi/python/topi/generic/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@
2020
from .extern import *
2121
from .vision import *
2222
from .sort import *
23+
from .where import *

topi/python/topi/generic/where.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=invalid-name, no-member
18+
"""Generic vision operators"""
19+
from __future__ import absolute_import as _abs
20+
import tvm
21+
from .vision import _default_schedule
22+
23+
@tvm.target.generic_func
24+
def schedule_argwhere(outs):
25+
"""Schedule for argwhere operator.
26+
27+
Parameters
28+
----------
29+
outs: Array of Tensor
30+
The indices that would sort an input array along
31+
the given axis.
32+
33+
Returns
34+
-------
35+
s: Schedule
36+
The computation schedule for the op.
37+
"""
38+
return _default_schedule(outs, False)

0 commit comments

Comments
 (0)