
Commit cc773a1

kaitingwang authored and wweic committed
[Relay] Clip gradient: grad * (select(x < min || max < x, 0, 1)) (apache#3509)
1 parent 03e5458 commit cc773a1

File tree

2 files changed (+59 -0 lines)


python/tvm/relay/op/_tensor_grad.py

Lines changed: 12 additions & 0 deletions
@@ -118,3 +118,15 @@ def abs_grad(orig, grad):
     zeros = zeros_like(x)
     ones = ones_like(x)
     return [where(less(x, zeros), -ones * grad, ones * grad)]
+
+@register_gradient("clip")
+def clip_grad(orig, grad):
+    """Returns grad * (select(x < min || max < x, 0, 1))."""
+    x = orig.args[0]
+    a_min = orig.attrs.get_int("a_min")
+    a_max = orig.attrs.get_int("a_max")
+    a_mins = broadcast_to_like(const(a_min), x)
+    a_maxs = broadcast_to_like(const(a_max), x)
+    zeros = zeros_like(x)
+    ones = ones_like(x)
+    return [where(less(x, a_mins), zeros, where(less(a_maxs, x), zeros, ones * grad))]
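
For illustration only (not part of the commit), the rule above amounts to masking the incoming gradient wherever the input falls outside [a_min, a_max] and passing it through unchanged otherwise. A minimal NumPy sketch of the same behavior, with hypothetical names:

import numpy as np

def clip_grad_reference(x, grad, a_min, a_max):
    # Sketch only: zero the gradient where x < a_min or x > a_max,
    # pass it through where a_min <= x <= a_max.
    mask = np.where((x < a_min) | (x > a_max), 0.0, 1.0)
    return grad * mask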
Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import tvm
+from tvm import relay
+from tvm.relay.transform import gradient
+from tvm.relay.testing import ctx_list
+
+
+def run_infer_type(expr):
+    mod = relay.Module.from_expr(expr)
+    mod = relay.transform.InferType()(mod)
+    return mod["main"]
+
+def test_clip():
+    ref = (lambda x: np.where(x > 10.0, np.zeros_like(x),
+                              np.where(x < 1.0, np.zeros_like(x), np.ones_like(x))))
+    x = relay.var("x", relay.TensorType((10, 4), "float32"))
+    y = tvm.relay.clip(x, 1.0, 10.0)
+
+    data = np.random.rand(10, 4).astype("float32") * 11.0
+    ref_grad = ref(data)
+    fwd_func = relay.Function([x], y)
+    bwd_func = run_infer_type(gradient(fwd_func))
+
+    for target, ctx in ctx_list():
+        intrp = relay.create_executor(ctx=ctx, target=target)
+        op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data)
+        np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)
+
+
+if __name__ == "__main__":
+    test_clip()
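
Note on the test design: the input data np.random.rand(10, 4) is scaled by 11.0 so that values fall both inside and outside the clip range [1.0, 10.0], and the Relay gradient is compared against the NumPy reference mask with rtol=0.01 on every target returned by ctx_list().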
