Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions python/tvm/autotvm/task/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,15 +219,15 @@ def args_to_workload(x, topi_compute_func=None):
workload = get_const_tuple(x.shape) + (x.dtype, )
elif isinstance(x, (tuple, list, container.Array)):
workload = tuple([args_to_workload(a) for a in x])
elif isinstance(x, (str, int, float, np.int, np.float)):
elif isinstance(x, (str, int, float, np.int, np.float, expr.Var)):
workload = x
elif isinstance(x, (expr.StringImm, expr.UIntImm, expr.IntImm, expr.FloatImm)):
workload = x.value
elif x is None:
workload = 0
else:
raise RuntimeError('Do not support type "%s" in argument. Consider to use '
'primitive types only' % type(x))
raise RuntimeError('Do not support type "%s" in argument. Consider to use'
'primitive types or tvm.expr.Var only' % type(x))
return (get_func_name(topi_compute_func), ) + workload if topi_compute_func else workload

def template(func):
Expand Down
14 changes: 12 additions & 2 deletions python/tvm/autotvm/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def get_const_int(exp):


def get_const_tuple(in_tuple):
"""Verifies input tuple is IntImm, returns tuple of int.
"""Verifies input tuple is IntImm or Var, returns tuple of int or Var.

Parameters
----------
Expand All @@ -175,4 +175,14 @@ def get_const_tuple(in_tuple):
out_tuple : tuple of int
The output.
"""
return tuple(get_const_int(x) for x in in_tuple)
ret = []
for elem in in_tuple:
if isinstance(elem, expr.Var):
ret.append(elem)
elif not isinstance(elem, (expr.IntImm, expr.UIntImm, int)):
elem = ir_pass.Simplify(elem)
if not isinstance(elem, (expr.IntImm, expr.UIntImm)):
ret.append(elem)
else:
ret.append(get_const_int(elem))
return tuple(ret)
68 changes: 68 additions & 0 deletions python/tvm/relay/op/_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
from __future__ import absolute_import

import topi

from topi.util import get_const_int, get_const_tuple
from . import op as _reg
from ...api import convert
from ...hybrid import script


def _schedule_reduce(_, outs, target):
Expand All @@ -39,3 +43,67 @@ def _schedule_reduce(_, outs, target):
_reg.register_schedule("variance", _schedule_reduce)
_reg.register_schedule("nn.cross_entropy", _schedule_reduce)
_reg.register_schedule("nn.cross_entropy_with_logits", _schedule_reduce)


def _create_axis_record(attrs, inputs):
axes = attrs.axis if attrs.axis is None else list(get_const_tuple(attrs.axis))
exclude = get_const_int(attrs.exclude) > 0
keepdims = get_const_int(attrs.keepdims) > 0
data_shape = inputs[0]
shape_size = data_shape.shape[0].value
axis_record = [-1] * shape_size
if axes is None:
axes = list(range(shape_size))

for i, axis in enumerate(axes):
if axis < 0:
axes[i] = shape_size + axis

if exclude:
ex_axes = []
for i in range(shape_size):
if i not in axes:
ex_axes.append(i)
axes = ex_axes

for i in range(shape_size):
if i not in axes:
axis_record[i] = i

if not keepdims:
tmp = []
for i in axis_record:
if i >= 0:
tmp.append(i)
axis_record = tmp

return axis_record


@script
def _reduce_shape_func(data_shape, axis_record):
out = output_tensor((len(axis_record),), "int64")
for i in const_range(len(axis_record)):
if axis_record[i] >= 0:
out[i] = data_shape[axis_record[i]]
else:
out[i] = int64(1)

return out

def reduce_shape_func(attrs, inputs, _):
"""
Shape function for reduce op.
"""
axis_record = _create_axis_record(attrs, inputs)
return [_reduce_shape_func(inputs[0], convert(axis_record))]

_reg.register_shape_func("argmax", False, reduce_shape_func)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I get it correctly, all function in this file has reduce_schedule and reduce_shape_func.
can we use a for loop instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registering schedule/shapefunc line by line should be cleaner when later we introduce some ops which require different sch/shapefunc implementation. Also this way is adopted across all relay op files. I think it's a good idea to keep them consistent

_reg.register_shape_func("argmin", False, reduce_shape_func)
_reg.register_shape_func("all", False, reduce_shape_func)
_reg.register_shape_func("sum", False, reduce_shape_func)
_reg.register_shape_func("max", False, reduce_shape_func)
_reg.register_shape_func("min", False, reduce_shape_func)
_reg.register_shape_func("prod", False, reduce_shape_func)
_reg.register_shape_func("mean", False, reduce_shape_func)
_reg.register_shape_func("variance", False, reduce_shape_func)
25 changes: 12 additions & 13 deletions python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,18 +119,6 @@ def _cast_shape_function(x):
def cast_shape_func(attrs, inputs, out_ndims):
return [_cast_shape_function(*inputs)]

@script
def _expand_dims_shape_func(x):
ndim = len(x.shape)
out = output_tensor((ndim+1,), "int64")
out[0] = int64(1)
for i in const_range(0, ndim):
out[i+1] = int64(x.shape[i])
return out

def expand_dims_shape_func(attrs, inputs, out_ndims):
return [_expand_dims_shape_func(*inputs)]

# shape func
@script
def _broadcast_shape_func(x, y, ndim):
Expand Down Expand Up @@ -161,9 +149,17 @@ def _broadcast_shape_func(x, y, ndim):
return out

def broadcast_shape_func(attrs, inputs, out_ndims):
"""
Shape function for broadcast op.
"""
return [_broadcast_shape_func(*inputs, out_ndims[0])]

register_shape_func("expand_dims", False, expand_dims_shape_func)
def elemwise_shape_func(attrs, inputs, _):
"""
Shape function for elemwise op.
"""
return [topi.math.identity(inputs[0])]

register_shape_func("cast", False, cast_shape_func)

register_shape_func("add", False, broadcast_shape_func)
Expand All @@ -179,3 +175,6 @@ def broadcast_shape_func(attrs, inputs, out_ndims):
register_shape_func("less_equal", False, broadcast_shape_func)
register_shape_func("greater", False, broadcast_shape_func)
register_shape_func("greater_equal", False, broadcast_shape_func)

register_shape_func("sqrt", False, elemwise_shape_func)
register_shape_func("negative", False, elemwise_shape_func)
194 changes: 193 additions & 1 deletion python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Backend compiler related feature registration"""
# pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks
# pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks, too-many-local-variables, too-many-arguments
from __future__ import absolute_import
import tvm
import topi
Expand Down Expand Up @@ -303,3 +303,195 @@ def compute_argwhere(attrs, inputs, output_type, _):
output_shape.append(tvm.var("any_dim", "int32"))
new_output_type = tvm.relay.ty.TensorType(output_shape, "int32")
return [topi.argwhere(new_output_type, inputs[0])]

@script
def _layout_transform_shape_func(data_shape,
out_layout_len,
dst_equal_list,
dst_mul_list,
dst_div_list,
dst_mix_list):
out = output_tensor((out_layout_len,), "int64")
for i in const_range(len(dst_equal_list)):
out[dst_equal_list[i][0]] = data_shape[dst_equal_list[i][1]]
for i in const_range(len(dst_mul_list)):
out[dst_mul_list[i][0]] = data_shape[dst_mul_list[i][1]] * \
data_shape[dst_mul_list[i][2]]
for i in const_range(len(dst_div_list)):
out[dst_div_list[i][0]] = data_shape[dst_div_list[i][1]] \
// dst_div_list[i][3]
out[dst_div_list[i][2]] = int64(dst_div_list[i][3])
for i in const_range(len(dst_mix_list)):
out[dst_mix_list[i][0]] = data_shape[dst_mix_list[i][1]] * \
dst_mix_list[i][2] // dst_mix_list[i][4]
out[dst_mix_list[i][3]] = int64(dst_mix_list[i][4])

return out

@_reg.register_shape_func("layout_transform", False)
def layout_transform_shape_func(attrs, inputs, _):
"""
Shape function for layout_transform op.
"""
def _fetch_axis(layout):
major_axes = []
minor_axes = {}
num_start = -1
for i, item in enumerate(layout):
if "A" <= item <= "Z":
major_axes.append(item)
elif "a" <= item <= "z":
last_num = int(layout[num_start:i])
minor_axes[item] = last_num
num_start = -1
elif num_start < 0:
num_start = i
return major_axes, minor_axes

_, src_minor_axes = _fetch_axis(attrs.src_layout)
dst_major_axes, dst_minor_axes = _fetch_axis(attrs.dst_layout)
src_letter_list = []
dst_letter_list = []
for item in attrs.src_layout:
if "A" <= item <= "Z" or "a" <= item <= "z":
src_letter_list.append(item)
for item in attrs.dst_layout:
if "A" <= item <= "Z" or "a" <= item <= "z":
dst_letter_list.append(item)
out_layout_len = len(dst_major_axes) + len(dst_minor_axes)
dst_equal_list = []
dst_mul_list = []
dst_div_list = []
dst_mix_list = []

for key in dst_major_axes:
if key.lower() not in dst_minor_axes:
if key.lower() not in src_minor_axes:
dst_equal_list.append((dst_letter_list.index(key),
src_letter_list.index(key)))
else:
dst_mul_list.append((dst_letter_list.index(key),
src_letter_list.index(key),
src_letter_list.index(key.lower())))
else:
if key.lower() not in src_minor_axes:
dst_div_list.append((dst_letter_list.index(key),
src_letter_list.index(key),
dst_letter_list.index(key.lower()),
dst_minor_axes[key.lower()]))
else:
dst_mix_list.append((dst_letter_list.index(key),
src_letter_list.index(key),
src_minor_axes[key.lower()],
dst_letter_list.index(key.lower()),
dst_minor_axes[key.lower()]))

return [_layout_transform_shape_func(inputs[0],
convert(out_layout_len),
convert(dst_equal_list),
convert(dst_mul_list),
convert(dst_div_list),
convert(dst_mix_list))]

@script
def _expand_dim_shape_func(data_shape, ndim, axis, num_newaxis):
out = output_tensor((ndim + num_newaxis,), "int64")
for i in const_range(out.shape[0]):
if i < axis:
out[i] = data_shape[i]
elif i < axis + num_newaxis:
out[i] = int64(1)
else:
out[i] = data_shape[i - num_newaxis]

return out

@_reg.register_shape_func("expand_dims", False)
def expand_dim_shape_func(attrs, inputs, _):
"""
Shape function for expand_dim op.
"""
axis = get_const_int(attrs.axis)
num_newaxis = get_const_int(attrs.num_newaxis)
if axis < 0:
axis = inputs[0].shape[0] + axis + 1
ndim = inputs[0].shape[0] if inputs[0].shape else 0
return [_expand_dim_shape_func(inputs[0],
convert(ndim),
convert(axis),
convert(num_newaxis))]

@script
def _transpose_shape_func(data_shape, axes):
out = output_tensor((data_shape.shape[0],), "int64")
for i in const_range(len(axes)):
out[i] = data_shape[axes[i]]

return out

@_reg.register_shape_func("transpose", False)
def transpose_shape_func(attrs, inputs, _):
"""
Shape function for transpose op.
"""
axes = attrs.axes if attrs.axes is None else get_const_tuple(attrs.axes)
if axes is None:
axes = list(range(inputs[0].shape[0].value))
axes.reverse()
for i, axis in enumerate(axes):
if axis < 0:
axes[i] = inputs[0].shape[0] - axis
return [_transpose_shape_func(inputs[0], convert(axes))]

@script
def _squeeze_shape_func(data_shape, keep_axes):
out = output_tensor((len(keep_axes),), "int64")
if len(keep_axes) == 0:
out_size = 0
for i in const_range(data_shape.shape[0]):
if data_shape[i] != 1:
out_size += 1

if out_size == 0:
out_size = 1
out = output_tensor((out_size,), "int64")
out[0] = int64(1)
pos = 0
for i in const_range(data_shape.shape[0]):
if data_shape[i] != 1:
out[pos] = data_shape[i]
pos += 1
else:
for i in const_range(len(keep_axes)):
out[i] = data_shape[keep_axes[i]]

return out

@_reg.register_shape_func("squeeze", False)
def squeeze_shape_func(attrs, inputs, _):
"""
Shape function for squeeze op.
"""
axis = attrs.axis if attrs.axis is None else get_const_tuple(attrs.axis)
keep_axes = []
if axis is not None:
for i in range(inputs[0].shape[0].value):
if i not in axis:
keep_axes.append(i)

return [_squeeze_shape_func(inputs[0], convert(keep_axes))]

@script
def _reshape_like_shape_func(target_shape):
out = output_tensor((target_shape.shape[0],), "int64")
for i in const_range(target_shape.shape[0]):
out[i] = target_shape[i]

return out

@_reg.register_shape_func("reshape_like", False)
def reshape_like_shape_func(attrs, inputs, _):
"""
Shape function for reshape_like op.
"""
return [_reshape_like_shape_func(inputs[1])]
Loading