1 change: 1 addition & 0 deletions python/tvm/relay/op/_tensor.py
@@ -230,6 +230,7 @@ def elemwise_shape_func(attrs, inputs, _):


register_shape_func("cast", False, elemwise_shape_func)
register_shape_func("cast_like", False, elemwise_shape_func)
register_shape_func("zeros", False, no_data_full_shape_func)
register_shape_func("zeros_like", False, elemwise_shape_func)
register_shape_func("ones", False, no_data_full_shape_func)
141 changes: 140 additions & 1 deletion python/tvm/relay/op/_tensor_grad.py
@@ -20,8 +20,11 @@

from tvm.topi.nn.util import get_pad_tuple
from tvm.topi.util import get_const_tuple
from tvm.error import OpError

from ..expr import Tuple, TupleGetItem, const
from ..expr import Tuple, TupleGetItem, const, Var
from ..ty import TensorType
from ..loops import while_loop
from . import nn as _nn
from .op import register_gradient
from .reduce import sum as _sum
@@ -40,6 +43,7 @@
equal,
shape_of,
log,
concatenate,
)
from .transform import (
broadcast_to_like,
@@ -55,6 +59,10 @@
repeat,
expand_dims,
full_like,
split,
squeeze,
strided_set,
arange,
)


@@ -665,3 +673,134 @@ def cross_entropy_with_logits_grad(orig, grad):
batch_size = take(shape, const(0, dtype="int32"), axis=0)
grad = grad / batch_size.astype(x.checked_type.dtype)
return [-grad * y, -grad * x]


@register_gradient("take")
def take_grad(orig, grad):
Review comments on this function:

Contributor: You can get by with defining a 'put' operator that puts a scalar into an index of a tensor and leaves the other places unchanged. put and take have some classic properties which I assume will be better for the optimizer. It also allows other optimizations (e.g. put and reduce_sum): using grad + (put val_a at idx_a in a zero array) + (put val_b at idx_b in a zero array) will be collapsed into a long chain of puts on grad, allowing copy-on-write to kick in so all the take-grad updates mutate in place instead of creating another tensor.

Contributor: @jroesch please look and comment as well.

Contributor (author): This is a good point that I was wondering about. The loop is basically just implementing a put operation (like I described in the comment), so it would make sense to have it be a separate op, since I imagine it will be useful in general. Should I remove this gradient for now, or keep it and replace it with put once I implement it?

Contributor: Both are fine.

"""
Returns the gradient of take.
"""

def make_scalar_tensor(v):
if isinstance(v, int):
v = const(v, dtype="int32")
return reshape(v, (1,))

# TODO(@altanh): we currently assume indices are in range
data, indices = orig.args
axis = orig.attrs.axis
zero, one = map(make_scalar_tensor, [0, 1])
data_grad = zeros_like(data)
try:
data_shape = data.checked_type.concrete_shape
except TypeError as ty_err:
raise OpError("currently take_grad only supports data with concrete shape") from ty_err
if axis is None:
axis = 0
data_grad = reshape(data_grad, (-1,))
data_shape = 1
for dim in data.checked_type.concrete_shape:
data_shape *= dim
data_shape = (data_shape,)
else:
axis = int(axis)
strides = [1] * len(data_shape)

if len(indices.checked_type.shape) == 0:
# axis on grad has been squeezed in this case
num_indices = one
indices = reshape(indices, (1,))
grad = expand_dims(grad, int(axis))
elif len(indices.checked_type.shape) == 1:
num_indices = take(shape_of(indices), zero, axis=0)
else:
raise OpError("take_grad only supports scalar or 1D indices")

def loop_cond(data_grad, i):
return squeeze(less(i, num_indices))

def loop_body(data_grad, i):
index = take(indices, i, axis=0)
grad_slice = take(grad, i, axis=axis)
begin, end = [], []
for ax, size in enumerate(data_shape):
size = make_scalar_tensor(size)
begin.append(zero if ax != axis else index)
end.append(size if ax != axis else index + one)
begin, end = concatenate(begin, axis=0), concatenate(end, axis=0)
# data_grad[:,...,index at axis,...,:] += grad_slice
update = strided_slice(data_grad, begin, end, strides=strides)
update = update + grad_slice # no need to expand grad_slice since i has shape (1,)
next_data_grad = strided_set(data_grad, update, begin, end, strides=strides)
return (next_data_grad, i + one)

loop_vars = [
Var("data_grad", type_annotation=TensorType(data_shape, data.checked_type.dtype)),
Var("i", type_annotation=TensorType((1,), "int32")),
]

loop = while_loop(loop_cond, loop_vars, loop_body)
result = loop(data_grad, zero)
data_grad = TupleGetItem(result, 0)

if orig.attrs.axis is None:
data_grad = reshape_like(data_grad, data)

return [data_grad, zeros_like(orig.args[1])]
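
To make the scatter semantics concrete, here is a small NumPy-only illustration (the shapes mirror the take test added below; nothing in it comes from the PR) of the accumulation the loop performs, and which a 'put'/scatter-add operator would express directly:

import numpy as np

data = np.random.randn(3, 4, 5)
indices = np.array([1, 2])
grad = np.ones((3, 2, 5))                  # upstream gradient of take(data, indices, axis=1)
data_grad = np.zeros_like(data)
# "put" grad into a zero tensor at the taken indices along axis 1;
# np.add.at accumulates on repeated indices, which is what the gradient needs
np.add.at(data_grad, (slice(None), indices), grad)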


@register_gradient("contrib_reverse_reshape")
def reverse_reshape_grad(orig, grad):
"""
Returns the gradient of reverse_reshape (same as reshape).
"""
return [reshape_like(grad, orig.args[0])]


@register_gradient("stack")
def stack_grad(orig, grad):
"""
Returns grad split across stacked inputs.
"""
stack_axis = int(orig.attrs.axis)
sections = len(orig.args[0].checked_type.fields)
splits = split(grad, sections, stack_axis)
splits = Tuple([squeeze(x, axis=[stack_axis]) for x in splits])
return [splits]
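
As a shape-only illustration in NumPy (not from the diff): stacking three (2, 3, 4) inputs along axis 0 gives a (3, 2, 3, 4) output, so the incoming gradient is split into three (1, 2, 3, 4) pieces and the stack axis is squeezed away, yielding one gradient per input.

import numpy as np

grad = np.random.randn(3, 2, 3, 4)         # gradient of stack([x, y, z], axis=0)
pieces = np.split(grad, 3, axis=0)         # three (1, 2, 3, 4) slices
input_grads = [np.squeeze(p, axis=0) for p in pieces]
assert all(g.shape == (2, 3, 4) for g in input_grads)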


@register_gradient("squeeze")
def squeeze_grad(orig, grad):
"""
Returns grad expanded to input size.
"""
# this should work, can't use expand_dims since we lose
# squeeze information when axis=None
return [reshape_like(grad, orig.args[0])]
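
A brief NumPy sketch of why reshape works here (illustration only): squeeze with axis=None drops every size-1 dimension, so the gradient alone no longer records which axes were removed, but reshaping it back to the input's shape recovers them since the element count is unchanged.

import numpy as np

x = np.random.randn(2, 1, 1, 3, 4, 1)
grad = np.ones_like(np.squeeze(x))       # gradient has the squeezed shape (2, 3, 4)
grad_x = grad.reshape(x.shape)           # reshape_like: back to (2, 1, 1, 3, 4, 1)
assert grad_x.shape == x.shape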


@register_gradient("expand_dims")
def expand_dims_grad(orig, grad):
"""
Returns grad squeezed on expanded dims.
"""
axis = int(orig.attrs.axis)
for _ in range(orig.attrs.num_newaxis):
grad = squeeze(grad, axis=[axis])
return [grad]


@register_gradient("arange")
def arange_grad(orig, grad):
"""
Returns the gradient of arange.
"""
start, stop, step = orig.args
length = take(shape_of(orig), const(0, dtype="int32"), axis=0)

grad_start = cast_like(_sum(grad), start)
grad_stop = zeros_like(stop)
grad_step = cast_like(arange(length, dtype="int32"), grad) * grad
grad_step = cast_like(_sum(grad_step), step)

return [grad_start, grad_stop, grad_step]
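
As a worked check (NumPy, not part of the PR): each output element is a_i = start + i * step, so dL/d(start) = sum_i grad_i, dL/d(step) = sum_i i * grad_i, and stop only affects the (non-differentiable) output length, matching the three values returned above.

import numpy as np

start, stop, step = 2.5, 9.5, 1.8
out = np.arange(start, stop, step)       # a_i = start + i * step
grad = np.ones_like(out)                 # upstream gradient
i = np.arange(len(out))
grad_start = grad.sum()                  # every element depends on start with coefficient 1
grad_step = (i * grad).sum()             # element i depends on step with coefficient i
grad_stop = 0.0                          # stop only changes how many elements exist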
6 changes: 5 additions & 1 deletion python/tvm/relay/testing/__init__.py
@@ -64,7 +64,11 @@ def run_infer_type(expr):


def _np_randn_from_type(t, scale=1, mean=0):
return (mean + (scale * np.random.randn(*(int(d) for d in t.shape)))).astype(t.dtype)
res = mean + (scale * np.random.randn(*(int(d) for d in t.shape)))
# if t.shape == (), then randn returns a scalar so we need to wrap for dtype conversion
if np.isscalar(res):
res = np.array(res)
return res.astype(t.dtype)


def check_grad(
2 changes: 1 addition & 1 deletion src/relay/analysis/type_solver.cc
@@ -615,7 +615,7 @@ bool TypeSolver::Solve() {

rnode->resolved = resolved;
} catch (const Error& err) {
this->diag_ctx_.Emit(Diagnostic::Error(rnode->span) << "err");
this->diag_ctx_.Emit(Diagnostic::Error(rnode->span) << err.what());
rnode->resolved = false;
} catch (const dmlc::Error& e) {
ICHECK(false) << e.what();
6 changes: 6 additions & 0 deletions tests/python/relay/test_op_grad_level1.py
@@ -144,5 +144,11 @@ def test_bias_add_grad():
verify_bias_add((4, 8), (8,))


def test_expand_dims_grad():
data = relay.var("data", shape=(2, 3), dtype="float64")
fwd_func = relay.Function([data], relay.expand_dims(data, axis=1, num_newaxis=2))
check_grad(fwd_func)


if __name__ == "__main__":
pytest.main([__file__])
5 changes: 5 additions & 0 deletions tests/python/relay/test_op_grad_level10.py
@@ -67,5 +67,10 @@ def test_batch_matmul_grad():
check_grad(relay.Function([x, y], relay.op.nn.batch_matmul(x, y)))


def test_reverse_reshape_grad():
x = relay.var("x", shape=(3, 4, 5), dtype="float64")
check_grad(relay.Function([x], relay.op.reverse_reshape(x, (-1, 0))))


if __name__ == "__main__":
pytest.main([__file__])
44 changes: 43 additions & 1 deletion tests/python/relay/test_op_grad_level3.py
@@ -20,7 +20,7 @@
import tvm
from tvm import te
from tvm import relay
from tvm.relay.testing import check_grad, run_infer_type
from tvm.relay.testing import check_grad, run_infer_type, _np_randn_from_type
from tvm.relay.transform import gradient
import tvm.testing

@@ -75,5 +75,47 @@ def test_copy_grad():
check_grad(fwd_func)


def test_take_grad():
data_dtype = relay.TensorType((3, 4, 5), "float64")
data = relay.var("data", data_dtype)
indices = relay.var("indices", relay.TensorType((relay.Any(),), "int32"))
inputs = [_np_randn_from_type(data_dtype, scale=1e-5), np.array([1, 2], dtype="int32")]
test_inputs = [inputs[0]]

# take on axis
fwd_func = relay.Function([data, indices], relay.take(data, indices, axis=1))
check_grad(fwd_func, inputs=inputs, test_inputs=test_inputs)

# take on flattened
fwd_func = relay.Function([data, indices], relay.take(data, indices, axis=None))
check_grad(fwd_func, inputs=inputs, test_inputs=test_inputs)


def test_stack_grad():
args = [relay.var(c, shape=(2, 3, 4), dtype="float64") for c in "xyz"]
fwd_func = relay.Function(args, relay.stack(args, axis=0))
check_grad(fwd_func)


def test_squeeze_grad():
data = relay.var("data", shape=(2, 1, 1, 3, 4, 1), dtype="float64")
fwd_func = relay.Function([data], relay.squeeze(data))
fwd_func_subset = relay.Function([data], relay.squeeze(data, axis=[1, -1]))
check_grad(fwd_func)
check_grad(fwd_func_subset)


def test_arange_grad():
# TODO: testing arange numerically is strange because two-sided approx can
# produce different output shapes
dtype = "float64"
start = relay.var("start", relay.TensorType((), dtype))
stop = relay.var("stop", relay.TensorType((), dtype))
step = relay.var("step", relay.TensorType((), dtype))
values = [np.array(v, dtype=dtype) for v in [2.5, 9.5, 1.8]]
fwd_func = relay.Function([start, stop, step], relay.arange(start, stop, step, dtype))
check_grad(fwd_func, inputs=values)


if __name__ == "__main__":
pytest.main()