Skip to content

Commit 7d71dd8

Browse files
MarisaKirisamevinx13
authored andcommitted
[Relay][Training] Add gradient for Crossentropy (#3925)
* save save redo max test save address comment fix * address comment * increase rtol * address review comment
1 parent 59d8d40 commit 7d71dd8

File tree

9 files changed

+143
-8
lines changed

9 files changed

+143
-8
lines changed

python/tvm/relay/op/_reduce.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,4 @@ def _schedule_reduce(_, outs, target):
3636
_reg.register_schedule("prod", _schedule_reduce)
3737
_reg.register_schedule("mean", _schedule_reduce)
3838
_reg.register_schedule("variance", _schedule_reduce)
39+
_reg.register_schedule("nn.cross_entropy", _schedule_reduce)

python/tvm/relay/op/_tensor_grad.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,26 @@
2525
from . import nn as _nn
2626
from .op import register_gradient
2727
from .reduce import sum as _sum
28-
from .tensor import cos, exp, less, negative, ones_like, power, sin, zeros_like, equal
28+
from .tensor import (
29+
cos,
30+
exp,
31+
less,
32+
negative,
33+
ones_like,
34+
power,
35+
sin,
36+
zeros_like,
37+
equal,
38+
shape_of,
39+
log)
2940
from .transform import (
3041
broadcast_to_like,
3142
collapse_sum_like,
3243
cast_like,
3344
reshape,
3445
reshape_like,
3546
strided_slice,
47+
take,
3648
tile,
3749
transpose,
3850
where,
@@ -353,3 +365,12 @@ def sum_grad(orig, grad):
353365
"""Returns grad broadcasted to data dims"""
354366
data = orig.args[0]
355367
return [broadcast_to_like(grad, data)]
368+
369+
370+
@register_gradient("nn.cross_entropy")
def cross_entropy_grad(orig, grad):
    """Returns the gradients of nn.cross_entropy w.r.t. predictions and targets.

    The forward op computes ``-sum(log(x) * y) / batch_size`` where
    ``batch_size`` is the extent of the first axis of ``x``.

    Parameters
    ----------
    orig : tvm.relay.Expr
        The original call; its two arguments are the predictions ``x`` and
        the targets ``y``.

    grad : tvm.relay.Expr
        The gradient flowing into the (scalar) output of the call.

    Returns
    -------
    grads : list of tvm.relay.Expr
        ``[d/dx, d/dy]``, i.e. ``[-grad * y / x, -grad * log(x)]`` with
        ``grad`` already divided by the batch size.
    """
    x, y = orig.args
    shape = shape_of(x)
    batch_size = take(shape, const(0, dtype='int32'), axis=0)
    # Cast the batch size to the dtype of the incoming gradient instead of a
    # hard-coded float32, so that float16/float64 inputs do not produce a
    # dtype mismatch in the division.
    grad = grad / cast_like(batch_size, grad)
    return [-grad * y / x, -grad * log(x)]

python/tvm/relay/op/nn/_nn.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -745,3 +745,12 @@ def schedule_bitserial_dense(attrs, outputs, target):
745745

746746

747747
reg.register_pattern("nn.bitserial_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
748+
749+
750+
# The whole loss collapses to a scalar, so keep the op out of fusion.
# Spelled via `reg.OpPattern` to match the other pattern registrations here.
reg.register_pattern("nn.cross_entropy", reg.OpPattern.OPAQUE)
751+
752+
753+
@reg.register_compute("nn.cross_entropy")
def compute_cross_entropy(attrs, inputs, out_dtype, target):
    """Compute definition of nn.cross_entropy.

    Sums ``log(predictions) * targets`` over every element, negates the sum
    and divides it by the extent of the first axis of the predictions.
    ``attrs``, ``out_dtype`` and ``target`` are not used.
    """
    predictions, targets = inputs
    weighted_log = topi.log(predictions) * targets
    batch_size = predictions.shape[0]
    return [-topi.sum(weighted_log) / batch_size]

python/tvm/relay/op/nn/nn.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1758,3 +1758,22 @@ def bitserial_dense(data,
17581758
"""
17591759
return _make.bitserial_dense(data, weight, units, data_bits, weight_bits,
17601760
pack_dtype, out_dtype, unipolar)
1761+
1762+
1763+
def cross_entropy(predictions, targets):
    """Cross-entropy loss taken directly on predictions (not on logits).

    The logarithm is applied to ``predictions`` as given. Both inputs are
    required by the operator's type relation to be 2-D tensors of the same
    shape, and the result is a scalar.

    Parameters
    ----------
    predictions : tvm.relay.Expr
        The predictions.

    targets : tvm.relay.Expr
        The targets.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    loss = _make.cross_entropy(predictions, targets)
    return loss

python/tvm/relay/testing/__init__.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,11 @@ def run_infer_type(expr):
5656
return run_opt_pass(expr, transform.InferType())
5757

5858

59-
def _np_randn_from_type(t, scale=1):
60-
return (scale * np.random.randn(*(int(d) for d in t.shape))).astype(t.dtype)
59+
def _np_randn_from_type(t, scale=1, mean=0):
60+
return (mean + (scale * np.random.randn(*(int(d) for d in t.shape)))).astype(t.dtype)
6161

6262

63-
def check_grad(func, inputs=None, eps=1e-6, atol=1e-5, rtol=1e-3):
63+
def check_grad(func, inputs=None, eps=1e-6, atol=1e-5, rtol=1e-3, scale=None, mean=0):
6464
"""Perform numerical gradient checking given a relay function.
6565
6666
Compare analytical gradients to numerical gradients derived from two-sided approximation. Note
@@ -86,15 +86,23 @@ def check_grad(func, inputs=None, eps=1e-6, atol=1e-5, rtol=1e-3):
8686
The relative tolerance on difference between numerical and analytical gradients. Note that
8787
this needs to be scaled appropriately relative to the chosen eps.
8888
89+
scale: float
90+
The standard deviation of the inputs.
91+
92+
mean: float
93+
The mean of the inputs.
8994
"""
9095

9196
fwd_func = run_infer_type(func)
9297
bwd_func = run_infer_type(gradient(fwd_func))
9398

99+
if scale is None:
100+
scale = 10 * eps
101+
94102
if inputs is None:
95103
params = fwd_func.params
96104
# Generate random inputs on the same scale as epsilon to avoid numerical precision loss.
97-
inputs = [_np_randn_from_type(x.checked_type, scale=(10 * eps)) for x in params]
105+
inputs = [_np_randn_from_type(x.checked_type, scale=scale, mean=mean) for x in params]
98106

99107
for target, ctx in ctx_list():
100108
intrp = relay.create_executor(ctx=ctx, target=target)

src/relay/op/nn/nn.cc

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -817,5 +817,54 @@ are data in batch.
817817
.add_type_rel("BatchMatmul", BatchMatmulRel);
818818

819819

820+
// relay.nn.cross_entropy
821+
bool CrossEntropyRel(const Array<Type>& types,
822+
int num_inputs,
823+
const Attrs& attrs,
824+
const TypeReporter& reporter) {
825+
CHECK_EQ(types.size(), 3);
826+
const auto* x = types[0].as<TensorTypeNode>();
827+
const auto* y = types[1].as<TensorTypeNode>();
828+
if (x == nullptr || y == nullptr) return false;
829+
CHECK(x->shape.size() == 2 && y->shape.size() == 2)
830+
<< "CrossEntropy: shapes of x and y is inconsistent, "
831+
<< "x shape = " << x->shape << ", "
832+
<< "y shape = " << y->shape;
833+
CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]))
834+
<< "CrossEntropy: shapes of x and y is inconsistent, "
835+
<< "x shape = " << x->shape << ", "
836+
<< "y shape = " << y->shape;
837+
CHECK(reporter->AssertEQ(x->shape[1], y->shape[1]))
838+
<< "CrossEntropy: shapes of x and y is inconsistent, "
839+
<< "x shape = " << x->shape << ", "
840+
<< "y shape = " << y->shape;
841+
// assign output type
842+
reporter->Assign(types[2], TensorTypeNode::make({}, x->dtype));
843+
return true;
844+
}
845+
846+
// Positional relay function to create cross_entropy operator used by frontend FFI.
847+
// Builds a call to the nn.cross_entropy operator from its two positional
// arguments; the call carries empty attributes and no type arguments.
Expr MakeCrossEntropy(Expr predictions, Expr targets) {
  static const Op& cross_entropy_op = Op::Get("nn.cross_entropy");
  Array<Expr> call_args = {predictions, targets};
  return CallNode::make(cross_entropy_op, call_args, Attrs(), {});
}
851+
852+
853+
TVM_REGISTER_API("relay.op.nn._make.cross_entropy")
854+
.set_body_typed(MakeCrossEntropy);
855+
856+
857+
// Operator registration for nn.cross_entropy. The argument descriptions
// state 2D tensors because CrossEntropyRel rejects any other rank.
RELAY_REGISTER_OP("nn.cross_entropy")
.describe(R"code(
Computes cross entropy given predictions and targets.
Do log on the data - do not accept logits.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("x", "2D Tensor", "Predictions.")
.add_argument("y", "2D Tensor", "Targets.")
.set_support_level(10)
.add_type_rel("CrossEntropy", CrossEntropyRel);
867+
868+
820869
} // namespace relay
821870
} // namespace tvm
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
from tvm import relay
18+
from tvm.relay.testing import check_grad
19+
20+
21+
def test_cross_entropy_grad():
    """Numerically check the registered gradient of nn.cross_entropy."""
    shape = (1, 5)
    x = relay.var("x", shape=shape)
    y = relay.var("y", shape=shape)
    loss = relay.op.nn.cross_entropy(x, y)
    fwd_func = relay.Function([x, y], loss)
    # Inputs are drawn around mean 1 with a small spread; the op takes
    # log(x), so this presumably keeps the predictions away from zero.
    check_grad(fwd_func, eps=0.01, scale=0.1, mean=1)
25+
26+
27+
if __name__ == "__main__":
28+
test_cross_entropy_grad()

tests/python/relay/test_op_grad_level4.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,14 @@ def test_sum_grad():
3232

3333

3434
def test_max_grad():
35-
s = (5, 10)
35+
s = (10, 10)
3636
t = relay.TensorType(s)
3737
x = relay.var("x", t)
3838
axis = 0
3939
z = relay.max(x, axis)
4040

4141
fwd_func = relay.Function([x], z)
42-
check_grad(fwd_func, eps=1e-7, rtol=1)
42+
check_grad(fwd_func, scale=1e-3)
4343

4444

4545
if __name__ == "__main__":

tests/python/relay/test_op_level5.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def verify_resize(dshape, scale, method, layout):
6767
for kind in ["graph", "debug"]:
6868
intrp = relay.create_executor(kind, ctx=ctx, target=target)
6969
op_res = intrp.evaluate(func)(x_data)
70-
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
70+
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4)
7171
for method in ["bilinear", "nearest_neighbor"]:
7272
for layout in ["NHWC", "NCHW"]:
7373
verify_resize((1, 4, 4, 4), 2, method, layout)

0 commit comments

Comments
 (0)