65 changes: 64 additions & 1 deletion python/tvm/relay/op/_tensor_grad.py
@@ -17,9 +17,13 @@
#pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration"""
from __future__ import absolute_import
from topi.util import get_const_tuple
from topi.nn.util import get_pad_tuple
from ..expr import const, Tuple, TupleGetItem
from .op import register_gradient
from .transform import collapse_sum_like, broadcast_to_like, where
from .reduce import sum as _sum
from .transform import collapse_sum_like, broadcast_to_like, where, transpose, reshape, tile, \
strided_slice
from .tensor import exp, negative, power, less, cos, sin
from .tensor import zeros_like, ones_like
from . import nn as _nn
@@ -187,3 +191,62 @@ def concatenate_grad(orig, grad):
# Assume only two elements in the tuple right now.
# In the real implementation, concatenate_grad probably needs to be implemented as an operator.
return [Tuple([zeros_like(x), zeros_like(y)])]

@register_gradient("nn.conv2d")
def conv2d_grad(orig, grad):
"""Gradient of conv2d"""
attrs = orig.attrs
data, weight = orig.args
data_shape = get_const_tuple(data.checked_type.shape)
weight_shape = get_const_tuple(weight.checked_type.shape)
_, _, grad_h, grad_w = get_const_tuple(orig.checked_type.shape)
batch, in_channel, in_h, in_w = data_shape
out_channel, _, filter_h, filter_w = weight_shape

# infer output_padding
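    # out_h/out_w is the size conv2d_transpose(grad) would produce with zero
    # output_padding; the gap to the true input size (in_h, in_w) is exactly the
    # output_padding needed so backward_data matches the shape of `data`.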
fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(get_const_tuple(attrs.padding),
(filter_h, filter_w))
stride_h, stride_w = get_const_tuple(attrs.strides)
dilation_h, dilation_w = get_const_tuple(attrs.dilation)
out_h = (grad_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h
out_w = (grad_w - 1) * stride_w - fpad_left - fpad_right + filter_w
output_padding = (in_h - out_h, in_w - out_w)

assert attrs.data_layout == 'NCHW', 'only support NCHW data layout'
assert attrs.kernel_layout == 'OIHW', 'only support OIHW kernel layout'
assert attrs.out_layout in ['', 'NCHW'], 'only support NCHW output layout'


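    # dL/d(data): transposed convolution of the output gradient with the weight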
backward_data = _nn.conv2d_transpose(grad, weight,
strides=attrs.strides,
padding=attrs.padding,
dilation=attrs.dilation,
groups=attrs.groups,
output_padding=output_padding)
grad = tile(grad, [1, in_channel // attrs.groups, 1, 1])
grad = reshape(grad, [-1, 1, 0, 0]) # batch * oc * ic // groups, 1, oh, ow
data = reshape(data, [1, -1, 0, 0]) # 1, batch * ic, ih, iw

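    # dL/d(weight): convolve the reshaped data with the reshaped output gradient,
    # one group per (batch, input channel) pair, with the forward strides and
    # dilation swapped, so each group yields one slice of the weight gradient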
backward_weight = _nn.conv2d(data, grad,
strides=attrs.dilation,
padding=attrs.padding,
dilation=attrs.strides,
groups=in_channel * batch)
# infer shape of backward_weight
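    # the grouped convolution can come out larger than (filter_h, filter_w) when
    # the forward conv does not exactly cover the padded input; compute its size
    # so the surplus rows/columns can be sliced off below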
padded_weight_grad_h = (in_h - (grad_h - 1) * stride_h - 1 + fpad_top + fpad_bottom) \
// dilation_h + 1
padded_weight_grad_w = (in_w - (grad_w - 1) * stride_w - 1 + fpad_left + fpad_right) \
// dilation_w + 1
backward_weight = reshape(backward_weight,
[batch, in_channel // attrs.groups, out_channel,
padded_weight_grad_h, padded_weight_grad_w])
backward_weight = _sum(backward_weight, axis=0)
backward_weight = transpose(backward_weight, [1, 0, 2, 3])

assert padded_weight_grad_h >= filter_h
assert padded_weight_grad_w >= filter_w
if padded_weight_grad_h > filter_h or padded_weight_grad_w > filter_w:
backward_weight = strided_slice(backward_weight, begin=[0, 0, 0, 0],
end=[None, None, filter_h, filter_w])

return [backward_data, backward_weight]
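
The backward_data path relies on a standard identity: for a "valid" cross-correlation, the gradient with respect to the input is the "full" convolution of the upstream gradient with the kernel, which is what conv2d_transpose computes. Below is a minimal NumPy/SciPy sanity check of that identity, independent of TVM and assuming stride 1, no padding, and a single channel and batch.

import numpy as np
from scipy.signal import correlate2d, convolve2d

rng = np.random.default_rng(0)
x = rng.standard_normal((6, 6))
w = rng.standard_normal((3, 3))

def forward(inp):
    # deep-learning "conv2d" is a cross-correlation; 'valid' means no padding
    return correlate2d(inp, w, mode="valid")

# take L = sum(forward(x)), so the upstream gradient dL/dy is all ones
g = np.ones_like(forward(x))
# what the backward_data path computes (a transposed convolution)
analytic = convolve2d(g, w, mode="full")

# central finite differences for dL/dx
numeric = np.zeros_like(x)
eps = 1e-5
for i in range(x.shape[0]):
    for j in range(x.shape[1]):
        xp, xm = x.copy(), x.copy()
        xp[i, j] += eps
        xm[i, j] -= eps
        numeric[i, j] = (forward(xp).sum() - forward(xm).sum()) / (2 * eps)

np.testing.assert_allclose(analytic, numeric, rtol=1e-4, atol=1e-4)
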
6 changes: 5 additions & 1 deletion python/tvm/relay/op/nn/nn.py
@@ -116,6 +116,7 @@ def conv2d_transpose(data,
kernel_size=None,
data_layout="NCHW",
kernel_layout="OIHW",
out_layout="",
output_padding=(0, 0),
out_dtype=""):
"""Two dimensional transposed convolution operator.
@@ -152,6 +153,9 @@ def conv2d_transpose(data,
kernel_layout : str, optional
Layout of the weight.

out_layout : Optional[str]
Layout of the output. By default, out_layout is the same as data_layout.

output_padding : Tuple[int], optional
Additional zero-padding to be added to one side of the output.

@@ -165,7 +169,7 @@
"""
return _make.conv2d_transpose(data, weight, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
kernel_layout, output_padding, out_dtype)
kernel_layout, out_layout, output_padding, out_dtype)


def softmax(data, axis=-1):
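
A hedged usage sketch for the new out_layout argument; the shapes are illustrative, and the weight layout follows the convention exercised by conv2d_grad above, where the first weight axis matches the input channels of the transposed convolution.

from tvm import relay
from tvm.relay.testing import run_infer_type

data = relay.var("data", shape=(1, 16, 8, 8), dtype="float32")
weight = relay.var("weight", shape=(16, 4, 3, 3), dtype="float32")
out = relay.nn.conv2d_transpose(data, weight,
                                strides=(2, 2),
                                padding=(1, 1),
                                kernel_size=(3, 3),
                                channels=4,
                                data_layout="NCHW",
                                kernel_layout="OIHW",
                                out_layout="NCHW",  # leaving it "" falls back to data_layout
                                output_padding=(1, 1))
print(run_infer_type(relay.Function([data, weight], out)))
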
2 changes: 2 additions & 0 deletions src/relay/op/nn/convolution.cc
@@ -307,6 +307,7 @@ Expr MakeConv2DTranspose(Expr data,
Array<IndexExpr> kernel_size,
std::string data_layout,
std::string kernel_layout,
std::string out_layout,
Array<IndexExpr> output_padding,
DataType out_dtype) {
auto attrs = make_node<Conv2DTransposeAttrs>();
@@ -319,6 +320,7 @@
attrs->groups = groups;
attrs->data_layout = std::move(data_layout);
attrs->kernel_layout = std::move(kernel_layout);
attrs->out_layout = std::move(out_layout);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("nn.conv2d_transpose");
return CallNode::make(op, {data, weight}, Attrs(attrs), {});
36 changes: 25 additions & 11 deletions src/relay/pass/gradient.cc
@@ -109,18 +109,22 @@ struct ADTensor : ADValueNode {
Expr forward;
mutable Expr reverse; // must be a variable to avoid duplication
ADTensor(LetList* ll, const Expr& forward) :
forward(ll->Push(forward)), reverse(ll->Push(ZerosLike(this->forward))) { }
forward(ll->Push(forward)), reverse(ll->Push(ZerosLike(this->forward))) {
this->forward->checked_type_ = forward->checked_type();
}
};

/*! \brief A staged representation of the program, we reflect
* Relay functions into a function over fragments of AD. We
* can compute away this function to obtain a reverse mode program.
*/
struct ADFunction : ADValueNode {
std::function<ADValue(const std::vector<ADValue>&,
std::function<ADValue(const Type&,
const std::vector<ADValue>&,
const Attrs&,
const tvm::Array<Type>&)> func;
explicit ADFunction(const std::function<ADValue(const std::vector<ADValue>&,
explicit ADFunction(const std::function<ADValue(const Type&,
const std::vector<ADValue>&,
const Attrs&,
const tvm::Array<Type>&)>& func) :
func(func) { }
@@ -139,14 +143,16 @@ struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr &)> {
Op op_ref = GetRef<Op>(op);
CHECK(rev_map.count(op_ref))
<< op->name << " does not have reverse mode defined";
return std::make_shared<ADFunction>([this, op_ref](const std::vector<ADValue>& args,
return std::make_shared<ADFunction>([this, op_ref](const Type& orig_type,
const std::vector<ADValue>& args,
const Attrs& attrs,
const tvm::Array<Type>& type_args) {
std::vector<Expr> call_args;
for (const ADValue& adval : args) {
call_args.push_back(adval->get<ADTensor>().forward);
}
auto orig = CallNode::make(op_ref, call_args, attrs, type_args);
orig->checked_type_ = orig_type;
auto ret = std::make_shared<ADTensor>(ll, orig);
backprop_actions.push_back([this, args, orig, ret, op_ref](LetList* ll) {
tvm::Array<Expr> rev = rev_map[op_ref](orig, ret->reverse);
@@ -171,13 +177,14 @@ struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr &)> {
for (const auto& arg : op->args) {
args.push_back(VisitExpr(arg));
}
return f->get<ADFunction>().func(args, op->attrs, op->type_args);
return f->get<ADFunction>().func(op->checked_type(), args, op->attrs, op->type_args);
}

ADValue VisitExpr_(const FunctionNode* op) final {
Function f = GetRef<Function>(op);
// todo: assert no closure
return std::make_shared<ADFunction>([this, f](const std::vector<ADValue>& args,
return std::make_shared<ADFunction>([this, f](const Type& orig_type,
const std::vector<ADValue>& args,
const Attrs& attrs,
const tvm::Array<Type>& type_args) {
CHECK_EQ(f->params.size(), args.size());
@@ -227,7 +234,7 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) {
for (const auto& p : f->params) {
args.push_back(std::make_shared<ADTensor>(ll, p));
}
auto c = rev->get<ADFunction>().func(args, Attrs(), {});
auto c = rev->get<ADFunction>().func(f->checked_type(), args, Attrs(), {});
const auto& res = c->get<ADTensor>();
Expr grad = LetList::With([&](LetList* ll) {
res.reverse = OnesLike(res.forward);
@@ -271,7 +278,9 @@ Expr LiftTensor(const std::function<Expr(const Expr& t)>& f,
LetList* ll) {
CHECK(IsAtomic(e)) << e;
if (t.as<TensorTypeNode>()) {
return f(e);
auto ret = f(e);
ret->checked_type_ = t;
return ret;

[Review comment from the author]: @MarisaKirisame please review this change

} else if (auto* tt = t.as<TupleTypeNode>()) {
tvm::Array<Expr> fields;
for (size_t i = 0; i < tt->fields.size(); ++i) {
Expand All @@ -280,7 +289,9 @@ Expr LiftTensor(const std::function<Expr(const Expr& t)>& f,
ll->Push(GetField(e, i)),
ll));
}
return TupleNode::make(fields);
auto ret = TupleNode::make(fields);
ret->checked_type_ = t;
return std::move(ret);
} else {
LOG(FATAL) << "unsupported input/output type: " << tt;
throw;
@@ -348,11 +359,14 @@ struct ReverseAD : ExprMutator {
args.push_back(ll->Push(VisitExpr(arg)));
}
std::vector<Expr> orig_args;
for (size_t i = 0; i < args.size(); ++i) {
for (size_t i = 0; i < args.size(); i++) {
orig_args.push_back(GetValue(op->args[i]->checked_type(), args[i], ll));
}
Expr orig = CallNode::make(op->op, orig_args, op->attrs, op->type_args);
auto ret = ll->Push(GetRev(op->checked_type(), ll->Push(orig), ll));
orig->checked_type_ = op->checked_type();
Var orig_var = ll->Push(orig);
orig_var->checked_type_ = op->checked_type();
auto ret = ll->Push(GetRev(op->checked_type(), orig_var, ll));
auto bpv = ll->Push(RefReadNode::make(bp));
Expr nbp = FunctionNode::make(
{},
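
The gradient.cc changes above all serve one purpose: threading checked_type_ through both AD paths so that gradient rules needing static shapes (conv2d_grad reads checked_type.shape via get_const_tuple) always see typed expressions. A minimal driver sketch mirroring the test below, where 'first_order' exercises FirstOrderReverseAD and 'higher_order' exercises ReverseAD:

from tvm import relay
from tvm.relay.transform import gradient
from tvm.relay.testing import run_infer_type

data = relay.var("data", shape=(1, 4, 16, 16), dtype="float32")
weight = relay.var("weight", shape=(16, 4, 3, 3), dtype="float32")
fwd = run_infer_type(relay.Function([data, weight],
                                    relay.nn.conv2d(data, weight, padding=(1, 1))))
# both modes now see fully typed calls, so conv2d_grad can query input shapes
bwd_first = run_infer_type(gradient(fwd, mode="first_order"))
bwd_higher = run_infer_type(gradient(fwd, mode="higher_order"))
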
49 changes: 48 additions & 1 deletion tests/python/relay/test_op_grad_level2.py
@@ -20,7 +20,7 @@
import topi.testing
from tvm import relay
from tvm.relay.transform import gradient
from tvm.relay.testing import ctx_list
from tvm.relay.testing import ctx_list, check_grad
from tvm.relay.testing import run_infer_type


@@ -83,6 +83,53 @@ def test_avg_pool2d_grad():
ceil_mode=False, count_include_pad=False)


def verify_conv2d_grad(dshape, wshape, strides, padding, dilation, groups=1, mode='higher_order'):
try:
import torch
import torch.nn.functional as F
except ImportError:
print('Skip because pytorch is not installed')
return

dtype = 'float32'
data = relay.var('data', shape=dshape, dtype=dtype)
weight = relay.var('weight', shape=wshape, dtype=dtype)
conv = relay.nn.conv2d(data, weight, strides=strides, padding=padding, dilation=dilation,
groups=groups)
fwd_func = relay.Function([data, weight], conv)
fwd_func = run_infer_type(fwd_func)
bwd_func = run_infer_type(gradient(fwd_func, mode=mode))

data_pt = torch.randn(*dshape, dtype=torch.float32, requires_grad=True)
weight_pt = torch.randn(*wshape, dtype=torch.float32, requires_grad=True)
out_pt = F.conv2d(data_pt, weight_pt, stride=strides, padding=padding, dilation=dilation,
groups=groups)
grad_output_pt = torch.ones(out_pt.shape)
grad_input_pt = F.grad.conv2d_input(dshape, weight_pt, grad_output_pt, stride=strides,
padding=padding, dilation=dilation, groups=groups) \
.detach().numpy()
grad_weight_pt = F.grad.conv2d_weight(data_pt, wshape, grad_output_pt, stride=strides,
padding=padding, dilation=dilation, groups=groups) \
.detach().numpy()


for target, ctx in ctx_list():
data = tvm.nd.array(data_pt.detach().numpy(), ctx)
weight = tvm.nd.array(weight_pt.detach().numpy(), ctx)
intrp = relay.create_executor(ctx=ctx, target=target)
op_res, (grad_input, grad_weight) = intrp.evaluate(bwd_func)(data, weight)
np.testing.assert_allclose(grad_input.asnumpy(), grad_input_pt, rtol=1e-4, atol=1e-4)
np.testing.assert_allclose(grad_weight.asnumpy(), grad_weight_pt, rtol=1e-4, atol=1e-4)


def test_conv2d_grad():
verify_conv2d_grad((1, 4, 16, 16), (16, 4, 3, 3), [1, 1], [1, 1], [1, 1])
verify_conv2d_grad((1, 4, 16, 16), (16, 4, 1, 1), [1, 1], [0, 0], [1, 1])
verify_conv2d_grad((1, 4, 16, 16), (16, 4, 1, 1), [2, 2], [0, 0], [1, 1])
verify_conv2d_grad((1, 4, 16, 16), (16, 4, 3, 3), [1, 1], [1, 1], [1, 1], mode='first_order')


if __name__ == "__main__":
test_max_pool2d_grad()
test_avg_pool2d_grad()
test_conv2d_grad()