|
15 | 15 | # specific language governing permissions and limitations |
16 | 16 | # under the License. |
17 | 17 | """Backend compiler related feature registration""" |
18 | | -# pylint: disable=invalid-name,unused-argument, len-as-condition |
| 18 | +# pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks |
19 | 19 | from __future__ import absolute_import |
20 | 20 | from topi.util import get_const_int, get_const_tuple |
21 | 21 | from . import op as _reg |
@@ -204,3 +204,68 @@ def take_shape_func(attrs, inputs, out_ndims): |
204 | 204 | axis += data_ndim |
205 | 205 | assert 0 <= axis < data_ndim |
206 | 206 | return [_take_with_axis_shape_func(*inputs, convert(axis), out_ndims[0])] |
| 207 | + |
| 208 | +@script |
| 209 | +def _argwhere_shape_func_2d(condition): |
| 210 | + out = output_tensor((2, ), "int64") |
| 211 | + out[0] = int64(0) |
| 212 | + out[1] = int64(2) |
| 213 | + for i1 in range(condition.shape[0]): |
| 214 | + for i2 in range(condition.shape[1]): |
| 215 | + if condition[i1, i2]: |
| 216 | + out[0] += int64(1) |
| 217 | + return out |
| 218 | + |
| 219 | +@script |
| 220 | +def _argwhere_shape_func_3d(condition): |
| 221 | + out = output_tensor((2, ), "int64") |
| 222 | + out[0] = int64(0) |
| 223 | + out[1] = int64(3) |
| 224 | + for i1 in range(condition.shape[0]): |
| 225 | + for i2 in range(condition.shape[1]): |
| 226 | + for i3 in range(condition.shape[2]): |
| 227 | + if condition[i1, i2, i3]: |
| 228 | + out[0] += int64(1) |
| 229 | + return out |
| 230 | + |
| 231 | +@script |
| 232 | +def _argwhere_shape_func_4d(condition): |
| 233 | + out = output_tensor((2, ), "int64") |
| 234 | + out[0] = int64(0) |
| 235 | + out[1] = int64(4) |
| 236 | + for i1 in range(condition.shape[0]): |
| 237 | + for i2 in range(condition.shape[1]): |
| 238 | + for i3 in range(condition.shape[2]): |
| 239 | + for i4 in range(condition.shape[3]): |
| 240 | + if condition[i1, i2, i3, i4]: |
| 241 | + out[0] += int64(1) |
| 242 | + return out |
| 243 | + |
| 244 | +@script |
| 245 | +def _argwhere_shape_func_5d(condition): |
| 246 | + out = output_tensor((2, ), "int64") |
| 247 | + out[0] = int64(0) |
| 248 | + out[1] = int64(5) |
| 249 | + for i1 in range(condition.shape[0]): |
| 250 | + for i2 in range(condition.shape[1]): |
| 251 | + for i3 in range(condition.shape[2]): |
| 252 | + for i4 in range(condition.shape[3]): |
| 253 | + for i5 in range(condition.shape[4]): |
| 254 | + if condition[i1, i2, i3, i4, i5]: |
| 255 | + out[0] += int64(1) |
| 256 | + return out |
| 257 | + |
| 258 | +@_reg.register_shape_func("argwhere", True) |
| 259 | +def argwhere_shape_func(attrs, inputs, out_ndims): |
| 260 | + """ |
| 261 | + Shape function for argwhere. |
| 262 | + """ |
| 263 | + if len(inputs[0].shape) == 2: |
| 264 | + return [_argwhere_shape_func_2d(inputs[0])] |
| 265 | + elif len(inputs[0].shape) == 3: |
| 266 | + return [_argwhere_shape_func_3d(inputs[0])] |
| 267 | + elif len(inputs[0].shape) == 4: |
| 268 | + return [_argwhere_shape_func_4d(inputs[0])] |
| 269 | + elif len(inputs[0].shape) == 5: |
| 270 | + return [_argwhere_shape_func_5d(inputs[0])] |
| 271 | + return [] |
0 commit comments