Skip to content

Commit 00c56e2

Browse files
do it
1 parent d3244c0 commit 00c56e2

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

tests/python/relay/test_pass_gradient.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,13 +156,20 @@ def test_tuple():
156156
def test_pow():
157157
mod = relay.Module()
158158
p = Prelude(mod)
159-
shape = ()
159+
shape = (10, 10)
160160
dtype = 'float32'
161161
t = relay.TensorType(shape, dtype)
162162
x = relay.var("x", t)
163163
double = relay.Function([x], x + x)
164164
i = relay.var("i", t)
165-
func = relay.Function([x], p.iterate)
165+
func = relay.Function([i], relay.Call(p.iterate(double, p.s(p.s(p.s(p.z())))), [i]))
166+
back_func = relay.ir_pass.infer_type(gradient(func, mod=mod), mod=mod)
167+
assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
168+
i_nd = rand(dtype, *shape)
169+
ex = create_executor(mod=mod)
170+
forward, (grad_i,) = ex.evaluate(back_func)(i_nd)
171+
np.testing.assert_allclose(forward.asnumpy(), 8 * i_nd.asnumpy())
172+
np.testing.assert_allclose(grad_i.asnumpy(), 8 * np.ones_like(grad_i.asnumpy()))
166173

167174
if __name__ == "__main__":
168175
test_id()

0 commit comments

Comments
 (0)