From e6ece4ce657e220568dd350a3b6f21da9da527a5 Mon Sep 17 00:00:00 2001
From: Xingyu Zhou
Date: Thu, 3 Oct 2019 21:25:25 +0000
Subject: [PATCH 1/7] overload half operators for cuda codegen

---
 src/codegen/codegen_cuda.cc | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc
index b48f647688c5..c93dffc859b0 100644
--- a/src/codegen/codegen_cuda.cc
+++ b/src/codegen/codegen_cuda.cc
@@ -50,6 +50,16 @@ void CodeGenCUDA::AddFunction(LoweredFunc f) {
 std::string CodeGenCUDA::Finish() {
   if (enable_fp16_) {
     decl_stream << "#include <cuda_fp16.h>\n";
+    decl_stream << "__device__ half max(const half a, const half b)\n"
+        "{\n return __hgt(__half(a), __half(b)) ? a : b;\n}\n";
+    decl_stream << "__device__ half min(const half a, const half b)\n"
+        "{\n return __hlt(__half(a), __half(b)) ? a : b;\n}\n";
+    decl_stream << "__device__ half operator+(const volatile __half &a, const volatile __half &b)\n"
+        "{\n return __hadd(a, b);\n}\n";
+    decl_stream << "__device__ half operator<=(const volatile __half &a, const volatile __half &b)\n"
+        "{\n return __hlt(a, b);\n}\n";
+    decl_stream << "__device__ half operator*(const volatile __half &a, const volatile __half &b)\n"
+        "{\n return __hmul(a, b);\n}\n";
   }

   if (enable_int8_) {

From 49044dbf80798cd187174c61728bfab6439db29f Mon Sep 17 00:00:00 2001
From: Xingyu Zhou
Date: Thu, 3 Oct 2019 22:33:08 +0000
Subject: [PATCH 2/7] add float16 te test_op_level1

---
 tests/python/relay/test_op_level1.py | 185 ++++++++++++++-------------
 1 file changed, 98 insertions(+), 87 deletions(-)

diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py
index d31f4f46f5d7..7c3ed8454b21 100644
--- a/tests/python/relay/test_op_level1.py
+++ b/tests/python/relay/test_op_level1.py
@@ -41,12 +41,12 @@ def rsqrt(x):
     one = np.ones_like(x)
     return one / np.sqrt(x)

-def test_unary_op():
-    def check_single_op(opfunc, ref):
+def test_unary_op(dtype):
+    def check_single_op(opfunc, ref, dtype):
         shape = (10, 4)
-        dtype = 'float32'
-        tp = relay.TensorType(shape, dtype)
-        x = relay.var("x", tp)
+        dtype = dtype
+        tp = relay.TensorType(shape)
+        x = relay.var("x", tp, dtype=dtype)
         y = opfunc(x)
         # test printer
         assert ("{}(%x)".format(y.op.name)) in y.astext()
@@ -77,22 +77,22 @@ def check_single_op(opfunc, ref):
                     (tvm.relay.cos, np.cos),
                     (tvm.relay.sin, np.sin),
                     (tvm.relay.atan, np.arctan)]:
-        check_single_op(opfunc, ref)
+        check_single_op(opfunc, ref, dtype)


-def test_binary_op():
+def test_binary_op(dtype):
     def inst(vars, sh):
         return [vars.get(s, s) for s in sh]

-    def check_binary_op(opfunc, ref):
+    def check_binary_op(opfunc, ref, dtype):
         # TODO(@jroesch): this piece of code improperly uses type variables.
n = tvm.var("n") s1 = (5, n, 5) s2 = (n, 1) t1 = relay.TensorType(s1) t2 = relay.TensorType(s2) - x = relay.var("x", t1) - y = relay.var("y", t2) + x = relay.var("x", t1, dtype=dtype) + y = relay.var("y", t2, dtype=dtype) z = opfunc(x, y) # test printer assert ("{}(%x, %y)".format(z.op.name)) in z.astext() @@ -102,11 +102,11 @@ def check_binary_op(opfunc, ref): if ref is not None: t1 = relay.TensorType((5, 10, 5)) t2 = relay.TensorType((5, 10, 5)) - x = relay.var("x", t1) - y = relay.var("y", t2) + x = relay.var("x", t1, dtype=dtype) + y = relay.var("y", t2, dtype=dtype) z = opfunc(x, y) - x_data = np.random.rand(5, 10, 5).astype(t1.dtype) - y_data = np.random.rand(5, 10, 5).astype(t2.dtype) + x_data = np.random.rand(5, 10, 5).astype(dtype) + y_data = np.random.rand(5, 10, 5).astype(dtype) ref_res = ref(x_data, y_data) func = relay.Function([x, y], z) @@ -121,10 +121,10 @@ def check_binary_op(opfunc, ref): (relay.subtract, np.subtract), (relay.multiply, np.multiply), (relay.divide, np.divide)]: - check_binary_op(opfunc, ref) + check_binary_op(opfunc, ref, dtype) -def test_expand_dims(): +def test_expand_dims(dtype): # based on topi test def verify_expand_dims(dshape, dtype, oshape, axis, num_newaxis): x = relay.Var("x", relay.TensorType(dshape, dtype)) @@ -140,10 +140,11 @@ def verify_expand_dims(dshape, dtype, oshape, axis, num_newaxis): verify_expand_dims((3, 10), 'float32', (1, 3, 10), -3, 1) -def test_bias_add(): +def test_bias_add(dtype): xshape=(10, 2, 3, 4) bshape=(2,) - dtype="float32" + dtype=dtype + rtol = 1e-2 if dtype is 'float16' else 1e-5 x = relay.var("x", shape=xshape) bias = relay.var("bias") z = relay.nn.bias_add(x, bias) @@ -158,27 +159,30 @@ def test_bias_add(): for target, ctx in ctx_list(): intrp = relay.create_executor("graph", ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data, y_data) - np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) + np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=rtol) -def test_expand_dims_infer_type(): +def test_expand_dims_infer_type(dtype): n, t, d = tvm.var("n"), tvm.var("t"), 100 - x = relay.var("x", shape=(n, t, d)) + x = relay.var("x", shape=(n, t, d), dtype=dtype) y = relay.expand_dims(x, axis=2) assert "axis=2" in y.astext() yy = run_infer_type(y) - assert yy.checked_type == relay.TensorType((n, t, 1, 100)) + assert yy.checked_type == relay.TensorType((n, t, 1, 100), dtype) -def test_softmax(): +def test_softmax(dtype): + # Softmax accuracy for float16 is poor + if dtype == 'float16': + return shape = (10, 4) - x = relay.var("x", shape=shape) + x = relay.var("x", shape=shape, dtype=dtype) y = relay.nn.softmax(x, axis=1) assert "nn.softmax" in y.astext() yy = run_infer_type(y) - assert yy.checked_type == relay.TensorType(shape) + assert yy.checked_type == relay.TensorType(shape, dtype) func = relay.Function([x], y) - x_data = np.random.uniform(size=shape).astype("float32") + x_data = np.random.uniform(size=shape).astype(dtype) ref_res = topi.testing.softmax_python(x_data) for target, ctx in ctx_list(): intrp = relay.create_executor("graph", ctx=ctx, target=target) @@ -186,15 +190,18 @@ def test_softmax(): np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) -def test_log_softmax(): +def test_log_softmax(dtype): + # Softmax accuracy for float16 is poor + if dtype == 'float16': + return shape = (10, 4) - x = relay.var("x", shape=shape) + x = relay.var("x", shape=shape, dtype=dtype) y = relay.nn.log_softmax(x, axis=1) assert "nn.log_softmax" in y.astext() yy = run_infer_type(y) - assert 
yy.checked_type == relay.TensorType(shape) + assert yy.checked_type == relay.TensorType(shape, dtype) func = relay.Function([x], y) - x_data = np.random.uniform(size=shape).astype("float32") + x_data = np.random.uniform(size=shape).astype(dtype) ref_res = topi.testing.log_softmax_python(x_data) for target, ctx in ctx_list(): intrp = relay.create_executor("graph", ctx=ctx, target=target) @@ -202,7 +209,7 @@ def test_log_softmax(): np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) -def test_concatenate(): +def test_concatenate(dtype): n, t, d = tvm.var("n"), tvm.var("t"), 100 x = relay.var("x", shape=(n, t, d)) y = relay.var("y", shape=(n, t, d)) @@ -232,16 +239,16 @@ def test_concatenate(): else: assert False - x = relay.var("x", shape=(10, 5)) - y = relay.var("y", shape=(10, 5)) - t = relay.var("z", shape=()) + x = relay.var("x", shape=(10, 5), dtype=dtype) + y = relay.var("y", shape=(10, 5), dtype=dtype) + t = relay.var("z", shape=(), dtype=dtype) z = relay.concatenate((x, y), axis=1) z = relay.add(z, t) # Check result. func = relay.Function([x, y, t], z) - x_data = np.random.rand(10, 5).astype('float32') - y_data = np.random.rand(10, 5).astype('float32') - t_data = np.random.uniform(size=()).astype('float32') + x_data = np.random.rand(10, 5).astype(dtype) + y_data = np.random.rand(10, 5).astype(dtype) + t_data = np.random.uniform(size=()).astype(dtype) ref_res = np.concatenate((x_data, y_data), axis=1) + t_data for target, ctx in ctx_list(): @@ -252,9 +259,9 @@ def test_concatenate(): op_res2 = intrp2.evaluate(func)(x_data, y_data, t_data) tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=0.01) -def test_dropout(): +def test_dropout(dtype): n, t, d = tvm.var("n"), tvm.var("t"), tvm.var("d") - input_ty = relay.TensorType((n, t, d), "float32") + input_ty = relay.TensorType((n, t, d), dtype) x = relay.var("x", input_ty) y = relay.nn.dropout(x, rate=0.75) assert "rate=" in y.astext() @@ -262,85 +269,88 @@ def test_dropout(): assert yy.checked_type == input_ty -def test_batch_norm(): +def test_batch_norm(dtype): # beta and gamma ignored - data = relay.var("data", relay.TensorType((3, 2, 1))) - beta = relay.var("beta", relay.TensorType((2,))) - gamma = relay.var("gamma", relay.TensorType((2,))) - moving_mean = relay.var("moving_mean", relay.TensorType((2,))) - moving_var = relay.var("moving_var", relay.TensorType((2,))) + data = relay.var("data", relay.TensorType((3, 2, 1), dtype)) + beta = relay.var("beta", relay.TensorType((2,), dtype)) + gamma = relay.var("gamma", relay.TensorType((2,), dtype)) + moving_mean = relay.var("moving_mean", relay.TensorType((2,), dtype)) + moving_var = relay.var("moving_var", relay.TensorType((2,), dtype)) y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var, center=False, scale=False) yy = run_infer_type(y.astuple()) assert "center=" in yy.astext() assert yy.checked_type == relay.ty.TupleType(tvm.convert([ - relay.TensorType((3, 2, 1), "float32"), - relay.TensorType((2,), "float32"), - relay.TensorType((2,), "float32") + relay.TensorType((3, 2, 1), dtype), + relay.TensorType((2,), dtype), + relay.TensorType((2,), dtype) ])) - beta = relay.var("beta", relay.TensorType((3,))) - gamma = relay.var("gamma", relay.TensorType((3,))) - moving_mean = relay.var("moving_mean", relay.TensorType((3,))) - moving_var = relay.var("moving_var", relay.TensorType((3,))) + beta = relay.var("beta", relay.TensorType((3,), dtype)) + gamma = relay.var("gamma", relay.TensorType((3,), dtype)) + moving_mean = relay.var("moving_mean", 
relay.TensorType((3,), dtype)) + moving_var = relay.var("moving_var", relay.TensorType((3,), dtype)) y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var, axis=0, center=False, scale=False) yy = run_infer_type(y.astuple()) assert yy.checked_type == relay.ty.TupleType(tvm.convert([ - relay.ty.TensorType((3, 2, 1), "float32"), - relay.ty.TensorType((3,), "float32"), - relay.ty.TensorType((3,), "float32") + relay.ty.TensorType((3, 2, 1), dtype), + relay.ty.TensorType((3,), dtype), + relay.ty.TensorType((3,), dtype) ])) # axis=-1 - data = relay.var("data", relay.TensorType((1, 2, 3))) - beta = relay.var("beta", relay.TensorType((3,))) - gamma = relay.var("gamma", relay.TensorType((3,))) - moving_mean = relay.var("moving_mean", relay.TensorType((3,))) - moving_var = relay.var("moving_var", relay.TensorType((3,))) + data = relay.var("data", relay.TensorType((1, 2, 3), dtype)) + beta = relay.var("beta", relay.TensorType((3,), dtype)) + gamma = relay.var("gamma", relay.TensorType((3,), dtype)) + moving_mean = relay.var("moving_mean", relay.TensorType((3,), dtype)) + moving_var = relay.var("moving_var", relay.TensorType((3,), dtype)) y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var, axis=-1, center=False, scale=False) yy = run_infer_type(y.astuple()) assert yy.checked_type == relay.ty.TupleType(tvm.convert([ - relay.ty.TensorType((1, 2, 3), "float32"), - relay.ty.TensorType((3,), "float32"), - relay.ty.TensorType((3,), "float32") + relay.ty.TensorType((1, 2, 3), dtype), + relay.ty.TensorType((3,), dtype), + relay.ty.TensorType((3,), dtype) ])) -def test_dense(): +def test_dense(dtype): + # Dense accuracy for float16 is poor + if dtype == 'float16': + return n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") - x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) - w = relay.var("w", relay.TensorType((2, w), "float32")) + x = relay.var("x", relay.TensorType((n, c, h, w), dtype)) + w = relay.var("w", relay.TensorType((2, w), dtype)) y = relay.nn.dense(x, w, units=2) assert "units=2" in y.astext() yy = run_infer_type(y) - assert yy.checked_type == relay.TensorType((n, c, h, 2), "float32") + assert yy.checked_type == relay.TensorType((n, c, h, 2), dtype) n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), 2 - x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) + x = relay.var("x", relay.TensorType((n, c, h, w), dtype)) wh, ww = tvm.var("wh"), tvm.var("ww") - w = relay.var("w", relay.TensorType((ww, wh), "float32")) + w = relay.var("w", relay.TensorType((ww, wh), dtype)) y = relay.nn.dense(x, w) yy = run_infer_type(y) - assert yy.checked_type == relay.TensorType((n, c, h, ww), "float32") + assert yy.checked_type == relay.TensorType((n, c, h, ww), dtype) n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), 2 - x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) + x = relay.var("x", relay.TensorType((n, c, h, w), dtype)) w = relay.var("w", relay.IncompleteType()) y = relay.nn.dense(x, w, units=2) yy = run_infer_type(y) - assert yy.checked_type == relay.TensorType((n, c, h, 2), "float32") + assert yy.checked_type == relay.TensorType((n, c, h, 2), dtype) - x = relay.var("x", shape=(10, 5)) - w = relay.var("w", shape=(2, 5)) + x = relay.var("x", shape=(10, 5), dtype=dtype) + w = relay.var("w", shape=(2, 5), dtype=dtype) z = relay.nn.dense(x, w) # Check result. 
func = relay.Function([x, w], z) - x_data = np.random.rand(10, 5).astype('float32') - w_data = np.random.rand(2, 5).astype('float32') + x_data = np.random.rand(10, 5).astype(dtype) + w_data = np.random.rand(2, 5).astype(dtype) ref_res = np.dot(x_data, w_data.T) for target, ctx in ctx_list(): @@ -363,15 +373,16 @@ def test_bitserial_dense(): if __name__ == "__main__": - test_concatenate() - test_bias_add() - test_unary_op() - test_binary_op() - test_expand_dims_infer_type() - test_expand_dims() - test_softmax() - test_log_softmax() - test_dropout() - test_batch_norm() - test_dense() + for dtype in ['float16', 'float32']: + test_concatenate(dtype) + test_bias_add(dtype) + test_unary_op(dtype) + test_binary_op(dtype) + test_expand_dims_infer_type(dtype) + test_expand_dims(dtype) + test_softmax(dtype) + test_log_softmax(dtype) + test_dropout(dtype) + test_batch_norm(dtype) + test_dense(dtype) test_bitserial_dense() From d85f4529076873279da1a99022bf00dd3834292b Mon Sep 17 00:00:00 2001 From: Xingyu Zhou Date: Thu, 3 Oct 2019 23:50:06 +0000 Subject: [PATCH 3/7] fix test_op_level1.py --- tests/python/relay/test_op_level1.py | 476 ++++++++++++++------------- 1 file changed, 242 insertions(+), 234 deletions(-) diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index 7c3ed8454b21..4b375be0ddad 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -41,7 +41,7 @@ def rsqrt(x): one = np.ones_like(x) return one / np.sqrt(x) -def test_unary_op(dtype): +def test_unary_op(): def check_single_op(opfunc, ref, dtype): shape = (10, 4) dtype = dtype @@ -77,10 +77,11 @@ def check_single_op(opfunc, ref, dtype): (tvm.relay.cos, np.cos), (tvm.relay.sin, np.sin), (tvm.relay.atan, np.arctan)]: - check_single_op(opfunc, ref, dtype) + for dtype in ['float16', 'float32']: + check_single_op(opfunc, ref, dtype) -def test_binary_op(dtype): +def test_binary_op(): def inst(vars, sh): return [vars.get(s, s) for s in sh] @@ -121,10 +122,11 @@ def check_binary_op(opfunc, ref, dtype): (relay.subtract, np.subtract), (relay.multiply, np.multiply), (relay.divide, np.divide)]: - check_binary_op(opfunc, ref, dtype) + for dtype in ['float16', 'float32']: + check_binary_op(opfunc, ref, dtype) -def test_expand_dims(dtype): +def test_expand_dims(): # based on topi test def verify_expand_dims(dshape, dtype, oshape, axis, num_newaxis): x = relay.Var("x", relay.TensorType(dshape, dtype)) @@ -135,231 +137,238 @@ def verify_expand_dims(dshape, dtype, oshape, axis, num_newaxis): intrp = relay.create_executor("graph", ctx=ctx, target=target) op_res = intrp.evaluate(func)(data) np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01) + for dtype in ['float16', 'float32']: + verify_expand_dims((3, 10), dtype, (3, 10, 1, 1), 2, 2) + verify_expand_dims((3, 10), dtype, (1, 3, 10), -3, 1) - verify_expand_dims((3, 10), 'float32', (3, 10, 1, 1), 2, 2) - verify_expand_dims((3, 10), 'float32', (1, 3, 10), -3, 1) - - -def test_bias_add(dtype): - xshape=(10, 2, 3, 4) - bshape=(2,) - dtype=dtype - rtol = 1e-2 if dtype is 'float16' else 1e-5 - x = relay.var("x", shape=xshape) - bias = relay.var("bias") - z = relay.nn.bias_add(x, bias) - zz = run_infer_type(z) - assert "axis=" not in zz.astext() - assert zz.args[1].checked_type == relay.TensorType(bshape) - - func = relay.Function([x, bias], z) - x_data = np.random.uniform(size=xshape).astype(dtype) - y_data = np.random.uniform(size=bshape).astype(dtype) - ref_res = x_data + y_data.reshape((2, 1, 1)) - for target, ctx in 
ctx_list(): - intrp = relay.create_executor("graph", ctx=ctx, target=target) - op_res = intrp.evaluate(func)(x_data, y_data) - np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=rtol) - - -def test_expand_dims_infer_type(dtype): - n, t, d = tvm.var("n"), tvm.var("t"), 100 - x = relay.var("x", shape=(n, t, d), dtype=dtype) - y = relay.expand_dims(x, axis=2) - assert "axis=2" in y.astext() - yy = run_infer_type(y) - assert yy.checked_type == relay.TensorType((n, t, 1, 100), dtype) +def test_bias_add(): + for dtype in ['float16', 'float32']: + xshape=(10, 2, 3, 4) + bshape=(2,) + rtol = 1e-2 if dtype is 'float16' else 1e-5 + x = relay.var("x", shape=xshape, dtype=dtype) + bias = relay.var("bias", dtype=dtype) + z = relay.nn.bias_add(x, bias) + zz = run_infer_type(z) + assert "axis=" not in zz.astext() + assert zz.args[1].checked_type == relay.TensorType(bshape, dtype) -def test_softmax(dtype): - # Softmax accuracy for float16 is poor - if dtype == 'float16': - return - shape = (10, 4) - x = relay.var("x", shape=shape, dtype=dtype) - y = relay.nn.softmax(x, axis=1) - assert "nn.softmax" in y.astext() - yy = run_infer_type(y) - assert yy.checked_type == relay.TensorType(shape, dtype) - func = relay.Function([x], y) - x_data = np.random.uniform(size=shape).astype(dtype) - ref_res = topi.testing.softmax_python(x_data) - for target, ctx in ctx_list(): - intrp = relay.create_executor("graph", ctx=ctx, target=target) - op_res = intrp.evaluate(func)(x_data) - np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) - - -def test_log_softmax(dtype): - # Softmax accuracy for float16 is poor - if dtype == 'float16': - return - shape = (10, 4) - x = relay.var("x", shape=shape, dtype=dtype) - y = relay.nn.log_softmax(x, axis=1) - assert "nn.log_softmax" in y.astext() - yy = run_infer_type(y) - assert yy.checked_type == relay.TensorType(shape, dtype) - func = relay.Function([x], y) - x_data = np.random.uniform(size=shape).astype(dtype) - ref_res = topi.testing.log_softmax_python(x_data) - for target, ctx in ctx_list(): - intrp = relay.create_executor("graph", ctx=ctx, target=target) - op_res = intrp.evaluate(func)(x_data) - np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) - - -def test_concatenate(dtype): - n, t, d = tvm.var("n"), tvm.var("t"), 100 - x = relay.var("x", shape=(n, t, d)) - y = relay.var("y", shape=(n, t, d)) - z = relay.concatenate((x, y), axis=-1) - assert "axis=" in z.astext() - zz = run_infer_type(z) - assert zz.checked_type == relay.TensorType((n, t, 200)) - - x = relay.exp(x) - z = relay.concatenate((x, y), axis=2) - zz = run_infer_type(z) - assert zz.checked_type == relay.TensorType((n, t, 200)) - - z = relay.concatenate((x, y), axis=1) - zz = run_infer_type(z) - assert zz.checked_type == relay.TensorType((n, t + t, 100)) - - # check shape mismatches (the following case is expected to raise tvm._ffi.base.TVMError. - try: - x = relay.var('p1', shape=(2, 5)) - y = relay.var('p2', shape=(2, 3)) - c = relay.concatenate([x, y], axis=0) - func = relay.Function([x, y], c) - zz = run_infer_type(func) - except tvm._ffi.base.TVMError: - pass - else: - assert False - - x = relay.var("x", shape=(10, 5), dtype=dtype) - y = relay.var("y", shape=(10, 5), dtype=dtype) - t = relay.var("z", shape=(), dtype=dtype) - z = relay.concatenate((x, y), axis=1) - z = relay.add(z, t) - # Check result. 
- func = relay.Function([x, y, t], z) - x_data = np.random.rand(10, 5).astype(dtype) - y_data = np.random.rand(10, 5).astype(dtype) - t_data = np.random.uniform(size=()).astype(dtype) - ref_res = np.concatenate((x_data, y_data), axis=1) + t_data - - for target, ctx in ctx_list(): - intrp1 = relay.create_executor("graph", ctx=ctx, target=target) - intrp2 = relay.create_executor("debug", ctx=ctx, target=target) - op_res1 = intrp1.evaluate(func)(x_data, y_data, t_data) - tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=0.01) - op_res2 = intrp2.evaluate(func)(x_data, y_data, t_data) - tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=0.01) - -def test_dropout(dtype): - n, t, d = tvm.var("n"), tvm.var("t"), tvm.var("d") - input_ty = relay.TensorType((n, t, d), dtype) - x = relay.var("x", input_ty) - y = relay.nn.dropout(x, rate=0.75) - assert "rate=" in y.astext() - yy = run_infer_type(y) - assert yy.checked_type == input_ty - - -def test_batch_norm(dtype): - # beta and gamma ignored - data = relay.var("data", relay.TensorType((3, 2, 1), dtype)) - beta = relay.var("beta", relay.TensorType((2,), dtype)) - gamma = relay.var("gamma", relay.TensorType((2,), dtype)) - moving_mean = relay.var("moving_mean", relay.TensorType((2,), dtype)) - moving_var = relay.var("moving_var", relay.TensorType((2,), dtype)) - y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var, - center=False, scale=False) - yy = run_infer_type(y.astuple()) - assert "center=" in yy.astext() - assert yy.checked_type == relay.ty.TupleType(tvm.convert([ - relay.TensorType((3, 2, 1), dtype), - relay.TensorType((2,), dtype), - relay.TensorType((2,), dtype) - ])) - - beta = relay.var("beta", relay.TensorType((3,), dtype)) - gamma = relay.var("gamma", relay.TensorType((3,), dtype)) - moving_mean = relay.var("moving_mean", relay.TensorType((3,), dtype)) - moving_var = relay.var("moving_var", relay.TensorType((3,), dtype)) - - y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var, - axis=0, center=False, scale=False) - yy = run_infer_type(y.astuple()) - assert yy.checked_type == relay.ty.TupleType(tvm.convert([ - relay.ty.TensorType((3, 2, 1), dtype), - relay.ty.TensorType((3,), dtype), - relay.ty.TensorType((3,), dtype) - ])) - - # axis=-1 - data = relay.var("data", relay.TensorType((1, 2, 3), dtype)) - beta = relay.var("beta", relay.TensorType((3,), dtype)) - gamma = relay.var("gamma", relay.TensorType((3,), dtype)) - moving_mean = relay.var("moving_mean", relay.TensorType((3,), dtype)) - moving_var = relay.var("moving_var", relay.TensorType((3,), dtype)) - y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var, - axis=-1, center=False, scale=False) - yy = run_infer_type(y.astuple()) - assert yy.checked_type == relay.ty.TupleType(tvm.convert([ - relay.ty.TensorType((1, 2, 3), dtype), - relay.ty.TensorType((3,), dtype), - relay.ty.TensorType((3,), dtype) - ])) - - -def test_dense(dtype): - # Dense accuracy for float16 is poor - if dtype == 'float16': - return - n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") - x = relay.var("x", relay.TensorType((n, c, h, w), dtype)) - w = relay.var("w", relay.TensorType((2, w), dtype)) - y = relay.nn.dense(x, w, units=2) - assert "units=2" in y.astext() - yy = run_infer_type(y) - assert yy.checked_type == relay.TensorType((n, c, h, 2), dtype) + func = relay.Function([x, bias], z) + x_data = np.random.uniform(size=xshape).astype(dtype) + y_data = np.random.uniform(size=bshape).astype(dtype) + ref_res = x_data + 
y_data.reshape((2, 1, 1)) + for target, ctx in ctx_list(): + intrp = relay.create_executor("graph", ctx=ctx, target=target) + op_res = intrp.evaluate(func)(x_data, y_data) + np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=rtol) - n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), 2 - x = relay.var("x", relay.TensorType((n, c, h, w), dtype)) - wh, ww = tvm.var("wh"), tvm.var("ww") - w = relay.var("w", relay.TensorType((ww, wh), dtype)) - y = relay.nn.dense(x, w) - yy = run_infer_type(y) - assert yy.checked_type == relay.TensorType((n, c, h, ww), dtype) - n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), 2 - x = relay.var("x", relay.TensorType((n, c, h, w), dtype)) - w = relay.var("w", relay.IncompleteType()) - y = relay.nn.dense(x, w, units=2) - yy = run_infer_type(y) - assert yy.checked_type == relay.TensorType((n, c, h, 2), dtype) +def test_expand_dims_infer_type(): + for dtype in ['float16', 'float32']: + n, t, d = tvm.var("n"), tvm.var("t"), 100 + x = relay.var("x", shape=(n, t, d), dtype=dtype) + y = relay.expand_dims(x, axis=2) + assert "axis=2" in y.astext() + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType((n, t, 1, 100), dtype) + + +def test_softmax(): + for dtype in ['float16', 'float32']: + # Softmax accuracy for float16 is poor + if dtype == 'float16': + return + shape = (10, 4) + x = relay.var("x", shape=shape, dtype=dtype) + y = relay.nn.softmax(x, axis=1) + assert "nn.softmax" in y.astext() + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType(shape, dtype) + func = relay.Function([x], y) + x_data = np.random.uniform(size=shape).astype(dtype) + ref_res = topi.testing.softmax_python(x_data) + for target, ctx in ctx_list(): + intrp = relay.create_executor("graph", ctx=ctx, target=target) + op_res = intrp.evaluate(func)(x_data) + np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) + + +def test_log_softmax(): + for dtype in ['float16', 'float32']: + # Softmax accuracy for float16 is poor + if dtype == 'float16': + return + shape = (10, 4) + x = relay.var("x", shape=shape, dtype=dtype) + y = relay.nn.log_softmax(x, axis=1) + assert "nn.log_softmax" in y.astext() + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType(shape, dtype) + func = relay.Function([x], y) + x_data = np.random.uniform(size=shape).astype(dtype) + ref_res = topi.testing.log_softmax_python(x_data) + for target, ctx in ctx_list(): + intrp = relay.create_executor("graph", ctx=ctx, target=target) + op_res = intrp.evaluate(func)(x_data) + np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) + + +def test_concatenate(): + for dtype in ['float16', 'float32']: + n, t, d = tvm.var("n"), tvm.var("t"), 100 + x = relay.var("x", shape=(n, t, d)) + y = relay.var("y", shape=(n, t, d)) + z = relay.concatenate((x, y), axis=-1) + assert "axis=" in z.astext() + zz = run_infer_type(z) + assert zz.checked_type == relay.TensorType((n, t, 200)) + + x = relay.exp(x) + z = relay.concatenate((x, y), axis=2) + zz = run_infer_type(z) + assert zz.checked_type == relay.TensorType((n, t, 200)) + + z = relay.concatenate((x, y), axis=1) + zz = run_infer_type(z) + assert zz.checked_type == relay.TensorType((n, t + t, 100)) + + # check shape mismatches (the following case is expected to raise tvm._ffi.base.TVMError. 
+ try: + x = relay.var('p1', shape=(2, 5)) + y = relay.var('p2', shape=(2, 3)) + c = relay.concatenate([x, y], axis=0) + func = relay.Function([x, y], c) + zz = run_infer_type(func) + except tvm._ffi.base.TVMError: + pass + else: + assert False + + x = relay.var("x", shape=(10, 5), dtype=dtype) + y = relay.var("y", shape=(10, 5), dtype=dtype) + t = relay.var("z", shape=(), dtype=dtype) + z = relay.concatenate((x, y), axis=1) + z = relay.add(z, t) + # Check result. + func = relay.Function([x, y, t], z) + x_data = np.random.rand(10, 5).astype(dtype) + y_data = np.random.rand(10, 5).astype(dtype) + t_data = np.random.uniform(size=()).astype(dtype) + ref_res = np.concatenate((x_data, y_data), axis=1) + t_data + + for target, ctx in ctx_list(): + intrp1 = relay.create_executor("graph", ctx=ctx, target=target) + intrp2 = relay.create_executor("debug", ctx=ctx, target=target) + op_res1 = intrp1.evaluate(func)(x_data, y_data, t_data) + tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=0.01) + op_res2 = intrp2.evaluate(func)(x_data, y_data, t_data) + tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=0.01) + +def test_dropout(): + for dtype in ['float16', 'float32']: + n, t, d = tvm.var("n"), tvm.var("t"), tvm.var("d") + input_ty = relay.TensorType((n, t, d), dtype) + x = relay.var("x", input_ty) + y = relay.nn.dropout(x, rate=0.75) + assert "rate=" in y.astext() + yy = run_infer_type(y) + assert yy.checked_type == input_ty + + +def test_batch_norm(): + for dtype in ['float16', 'float32']: + # beta and gamma ignored + data = relay.var("data", relay.TensorType((3, 2, 1), dtype)) + beta = relay.var("beta", relay.TensorType((2,), dtype)) + gamma = relay.var("gamma", relay.TensorType((2,), dtype)) + moving_mean = relay.var("moving_mean", relay.TensorType((2,), dtype)) + moving_var = relay.var("moving_var", relay.TensorType((2,), dtype)) + y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var, + center=False, scale=False) + yy = run_infer_type(y.astuple()) + assert "center=" in yy.astext() + assert yy.checked_type == relay.ty.TupleType(tvm.convert([ + relay.TensorType((3, 2, 1), dtype), + relay.TensorType((2,), dtype), + relay.TensorType((2,), dtype) + ])) + + beta = relay.var("beta", relay.TensorType((3,), dtype)) + gamma = relay.var("gamma", relay.TensorType((3,), dtype)) + moving_mean = relay.var("moving_mean", relay.TensorType((3,), dtype)) + moving_var = relay.var("moving_var", relay.TensorType((3,), dtype)) + + y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var, + axis=0, center=False, scale=False) + yy = run_infer_type(y.astuple()) + assert yy.checked_type == relay.ty.TupleType(tvm.convert([ + relay.ty.TensorType((3, 2, 1), dtype), + relay.ty.TensorType((3,), dtype), + relay.ty.TensorType((3,), dtype) + ])) + + # axis=-1 + data = relay.var("data", relay.TensorType((1, 2, 3), dtype)) + beta = relay.var("beta", relay.TensorType((3,), dtype)) + gamma = relay.var("gamma", relay.TensorType((3,), dtype)) + moving_mean = relay.var("moving_mean", relay.TensorType((3,), dtype)) + moving_var = relay.var("moving_var", relay.TensorType((3,), dtype)) + y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var, + axis=-1, center=False, scale=False) + yy = run_infer_type(y.astuple()) + assert yy.checked_type == relay.ty.TupleType(tvm.convert([ + relay.ty.TensorType((1, 2, 3), dtype), + relay.ty.TensorType((3,), dtype), + relay.ty.TensorType((3,), dtype) + ])) + + +def test_dense(): + for dtype in ['float16', 'float32']: + # Dense accuracy for float16 is 
poor + if dtype == 'float16': + return + n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") + x = relay.var("x", relay.TensorType((n, c, h, w), dtype)) + w = relay.var("w", relay.TensorType((2, w), dtype)) + y = relay.nn.dense(x, w, units=2) + assert "units=2" in y.astext() + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType((n, c, h, 2), dtype) - x = relay.var("x", shape=(10, 5), dtype=dtype) - w = relay.var("w", shape=(2, 5), dtype=dtype) - z = relay.nn.dense(x, w) + n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), 2 + x = relay.var("x", relay.TensorType((n, c, h, w), dtype)) + wh, ww = tvm.var("wh"), tvm.var("ww") + w = relay.var("w", relay.TensorType((ww, wh), dtype)) + y = relay.nn.dense(x, w) + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType((n, c, h, ww), dtype) + + n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), 2 + x = relay.var("x", relay.TensorType((n, c, h, w), dtype)) + w = relay.var("w", relay.IncompleteType()) + y = relay.nn.dense(x, w, units=2) + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType((n, c, h, 2), dtype) - # Check result. - func = relay.Function([x, w], z) - x_data = np.random.rand(10, 5).astype(dtype) - w_data = np.random.rand(2, 5).astype(dtype) - ref_res = np.dot(x_data, w_data.T) + x = relay.var("x", shape=(10, 5), dtype=dtype) + w = relay.var("w", shape=(2, 5), dtype=dtype) + z = relay.nn.dense(x, w) - for target, ctx in ctx_list(): - intrp1 = relay.create_executor("graph", ctx=ctx, target=target) - intrp2 = relay.create_executor("debug", ctx=ctx, target=target) - op_res1 = intrp1.evaluate(func)(x_data, w_data) - tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5) - op_res2 = intrp2.evaluate(func)(x_data, w_data) - tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5) + # Check result. 
+        func = relay.Function([x, w], z)
+        x_data = np.random.rand(10, 5).astype(dtype)
+        w_data = np.random.rand(2, 5).astype(dtype)
+        ref_res = np.dot(x_data, w_data.T)
+
+        for target, ctx in ctx_list():
+            intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
+            intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
+            op_res1 = intrp1.evaluate(func)(x_data, w_data)
+            tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5)
+            op_res2 = intrp2.evaluate(func)(x_data, w_data)
+            tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)


 def test_bitserial_dense():
@@ -373,16 +382,15 @@ def test_bitserial_dense():


 if __name__ == "__main__":
-    for dtype in ['float16', 'float32']:
-        test_concatenate(dtype)
-        test_bias_add(dtype)
-        test_unary_op(dtype)
-        test_binary_op(dtype)
-        test_expand_dims_infer_type(dtype)
-        test_expand_dims(dtype)
-        test_softmax(dtype)
-        test_log_softmax(dtype)
-        test_dropout(dtype)
-        test_batch_norm(dtype)
-        test_dense(dtype)
+    test_concatenate()
+    test_bias_add()
+    test_unary_op()
+    test_binary_op()
+    test_expand_dims_infer_type()
+    test_expand_dims()
+    test_softmax()
+    test_log_softmax()
+    test_dropout()
+    test_batch_norm()
+    test_dense()
     test_bitserial_dense()

From 47902c4976f2370dfd29599a245342640191b22c Mon Sep 17 00:00:00 2001
From: Xingyu Zhou
Date: Fri, 4 Oct 2019 19:35:10 +0000
Subject: [PATCH 4/7] fix lint

---
 src/codegen/codegen_cuda.cc | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc
index c93dffc859b0..241310fd00d4 100644
--- a/src/codegen/codegen_cuda.cc
+++ b/src/codegen/codegen_cuda.cc
@@ -50,15 +50,19 @@ void CodeGenCUDA::AddFunction(LoweredFunc f) {
 std::string CodeGenCUDA::Finish() {
   if (enable_fp16_) {
     decl_stream << "#include <cuda_fp16.h>\n";
-    decl_stream << "__device__ half max(const half a, const half b)\n"
+    decl_stream << "__device__ half max" \
+        "(const half a, const half b)\n"
         "{\n return __hgt(__half(a), __half(b)) ? a : b;\n}\n";
     decl_stream << "__device__ half min(const half a, const half b)\n"
         "{\n return __hlt(__half(a), __half(b)) ?
a : b;\n}\n"; - decl_stream << "__device__ half operator+(const volatile __half &a, const volatile __half &b)\n" + decl_stream << "__device__ half operator+" \ + "(const volatile __half &a, const volatile __half &b)\n" "{\n return __hadd(a, b);\n}\n"; - decl_stream << "__device__ half operator<=(const volatile __half &a, const volatile __half &b)\n" + decl_stream << "__device__ half operator<=" \ + "(const volatile __half &a, const volatile __half &b)\n" "{\n return __hlt(a, b);\n}\n"; - decl_stream << "__device__ half operator*(const volatile __half &a, const volatile __half &b)\n" + decl_stream << "__device__ half operator*" \ + "(const volatile __half &a, const volatile __half &b)\n" "{\n return __hmul(a, b);\n}\n"; } From 66014b4f66391a8cd6837be76620111a13a7889a Mon Sep 17 00:00:00 2001 From: Xingyu Zhou Date: Mon, 7 Oct 2019 19:26:02 +0000 Subject: [PATCH 5/7] disable fp16 test if gpu does not support --- tests/python/relay/test_op_level1.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index 4b375be0ddad..5dcf3b90196e 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -21,6 +21,7 @@ from tvm.relay import transform from tvm.relay.testing import ctx_list import topi.testing +from tvm.contrib.nvcc import have_fp16 def run_infer_type(expr): mod = relay.Module.from_expr(expr) @@ -78,6 +79,8 @@ def check_single_op(opfunc, ref, dtype): (tvm.relay.sin, np.sin), (tvm.relay.atan, np.arctan)]: for dtype in ['float16', 'float32']: + if dtype == 'float16' and not have_fp16(tvm.gpu(0).compute_version): + continue check_single_op(opfunc, ref, dtype) @@ -123,6 +126,8 @@ def check_binary_op(opfunc, ref, dtype): (relay.multiply, np.multiply), (relay.divide, np.divide)]: for dtype in ['float16', 'float32']: + if dtype == 'float16' and not have_fp16(tvm.gpu(0).compute_version): + continue check_binary_op(opfunc, ref, dtype) @@ -144,6 +149,8 @@ def verify_expand_dims(dshape, dtype, oshape, axis, num_newaxis): def test_bias_add(): for dtype in ['float16', 'float32']: + if dtype == 'float16' and not have_fp16(tvm.gpu(0).compute_version): + continue xshape=(10, 2, 3, 4) bshape=(2,) rtol = 1e-2 if dtype is 'float16' else 1e-5 @@ -166,6 +173,8 @@ def test_bias_add(): def test_expand_dims_infer_type(): for dtype in ['float16', 'float32']: + if dtype == 'float16' and not have_fp16(tvm.gpu(0).compute_version): + continue n, t, d = tvm.var("n"), tvm.var("t"), 100 x = relay.var("x", shape=(n, t, d), dtype=dtype) y = relay.expand_dims(x, axis=2) @@ -216,6 +225,8 @@ def test_log_softmax(): def test_concatenate(): for dtype in ['float16', 'float32']: + if dtype == 'float16' and not have_fp16(tvm.gpu(0).compute_version): + continue n, t, d = tvm.var("n"), tvm.var("t"), 100 x = relay.var("x", shape=(n, t, d)) y = relay.var("y", shape=(n, t, d)) @@ -267,6 +278,8 @@ def test_concatenate(): def test_dropout(): for dtype in ['float16', 'float32']: + if dtype == 'float16' and not have_fp16(tvm.gpu(0).compute_version): + continue n, t, d = tvm.var("n"), tvm.var("t"), tvm.var("d") input_ty = relay.TensorType((n, t, d), dtype) x = relay.var("x", input_ty) @@ -278,6 +291,8 @@ def test_dropout(): def test_batch_norm(): for dtype in ['float16', 'float32']: + if dtype == 'float16' and not have_fp16(tvm.gpu(0).compute_version): + continue # beta and gamma ignored data = relay.var("data", relay.TensorType((3, 2, 1), dtype)) beta = relay.var("beta", relay.TensorType((2,), dtype)) From 
c20a7cc84df0529a9359672f23e249d95403ab97 Mon Sep 17 00:00:00 2001 From: Xingyu Zhou Date: Mon, 7 Oct 2019 20:30:59 +0000 Subject: [PATCH 6/7] disable fp16 test if gpu does not support --- tests/python/relay/test_op_level1.py | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index 5dcf3b90196e..4a07662554b9 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -62,6 +62,8 @@ def check_single_op(opfunc, ref, dtype): for target, ctx in ctx_list(): # use graph by execuor default for testing, as we need # create function explicitly to avoid constant-folding. + if dtype == 'float16' and target == 'cuda' and not have_fp16(tvm.gpu(0).compute_version): + continue intrp = relay.create_executor("graph", ctx=ctx, target=target) op_res = intrp.evaluate(func)(data) np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01) @@ -79,8 +81,6 @@ def check_single_op(opfunc, ref, dtype): (tvm.relay.sin, np.sin), (tvm.relay.atan, np.arctan)]: for dtype in ['float16', 'float32']: - if dtype == 'float16' and not have_fp16(tvm.gpu(0).compute_version): - continue check_single_op(opfunc, ref, dtype) @@ -117,6 +117,8 @@ def check_binary_op(opfunc, ref, dtype): for target, ctx in ctx_list(): # use graph by execuor default for testing, as we need # create function explicitly to avoid constant-folding. + if dtype == 'float16' and target == 'cuda' and not have_fp16(tvm.gpu(0).compute_version): + continue intrp = relay.create_executor("graph", ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data, y_data) np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01) @@ -126,8 +128,6 @@ def check_binary_op(opfunc, ref, dtype): (relay.multiply, np.multiply), (relay.divide, np.divide)]: for dtype in ['float16', 'float32']: - if dtype == 'float16' and not have_fp16(tvm.gpu(0).compute_version): - continue check_binary_op(opfunc, ref, dtype) @@ -137,6 +137,8 @@ def verify_expand_dims(dshape, dtype, oshape, axis, num_newaxis): x = relay.Var("x", relay.TensorType(dshape, dtype)) func = relay.Function([x], relay.expand_dims(x, axis, num_newaxis)) for target, ctx in ctx_list(): + if dtype == 'float16' and target == 'cuda' and not have_fp16(tvm.gpu(0).compute_version): + continue data = np.random.uniform(size=dshape).astype(dtype) ref_res = data.reshape(oshape) intrp = relay.create_executor("graph", ctx=ctx, target=target) @@ -149,8 +151,6 @@ def verify_expand_dims(dshape, dtype, oshape, axis, num_newaxis): def test_bias_add(): for dtype in ['float16', 'float32']: - if dtype == 'float16' and not have_fp16(tvm.gpu(0).compute_version): - continue xshape=(10, 2, 3, 4) bshape=(2,) rtol = 1e-2 if dtype is 'float16' else 1e-5 @@ -166,6 +166,8 @@ def test_bias_add(): y_data = np.random.uniform(size=bshape).astype(dtype) ref_res = x_data + y_data.reshape((2, 1, 1)) for target, ctx in ctx_list(): + if dtype == 'float16' and target == 'cuda' and not have_fp16(tvm.gpu(0).compute_version): + continue intrp = relay.create_executor("graph", ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data, y_data) np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=rtol) @@ -173,8 +175,6 @@ def test_bias_add(): def test_expand_dims_infer_type(): for dtype in ['float16', 'float32']: - if dtype == 'float16' and not have_fp16(tvm.gpu(0).compute_version): - continue n, t, d = tvm.var("n"), tvm.var("t"), 100 x = relay.var("x", shape=(n, t, d), dtype=dtype) y = 
relay.expand_dims(x, axis=2) @@ -225,8 +225,6 @@ def test_log_softmax(): def test_concatenate(): for dtype in ['float16', 'float32']: - if dtype == 'float16' and not have_fp16(tvm.gpu(0).compute_version): - continue n, t, d = tvm.var("n"), tvm.var("t"), 100 x = relay.var("x", shape=(n, t, d)) y = relay.var("y", shape=(n, t, d)) @@ -269,6 +267,8 @@ def test_concatenate(): ref_res = np.concatenate((x_data, y_data), axis=1) + t_data for target, ctx in ctx_list(): + if dtype == 'float16' and target == 'cuda' and not have_fp16(tvm.gpu(0).compute_version): + continue intrp1 = relay.create_executor("graph", ctx=ctx, target=target) intrp2 = relay.create_executor("debug", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(x_data, y_data, t_data) @@ -278,8 +278,6 @@ def test_concatenate(): def test_dropout(): for dtype in ['float16', 'float32']: - if dtype == 'float16' and not have_fp16(tvm.gpu(0).compute_version): - continue n, t, d = tvm.var("n"), tvm.var("t"), tvm.var("d") input_ty = relay.TensorType((n, t, d), dtype) x = relay.var("x", input_ty) @@ -291,8 +289,6 @@ def test_dropout(): def test_batch_norm(): for dtype in ['float16', 'float32']: - if dtype == 'float16' and not have_fp16(tvm.gpu(0).compute_version): - continue # beta and gamma ignored data = relay.var("data", relay.TensorType((3, 2, 1), dtype)) beta = relay.var("beta", relay.TensorType((2,), dtype)) From dc378fe804cd143f74b1686f5fe228d22cabfc1a Mon Sep 17 00:00:00 2001 From: Xingyu Zhou Date: Tue, 8 Oct 2019 23:03:09 +0000 Subject: [PATCH 7/7] bypass float16 test if gpu does not support float16 --- topi/tests/python/test_topi_transform.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py index b1aa20ea07df..4a529f4a047f 100644 --- a/topi/tests/python/test_topi_transform.py +++ b/topi/tests/python/test_topi_transform.py @@ -19,6 +19,7 @@ import tvm import topi import topi.testing +from tvm.contrib.nvcc import have_fp16 from common import get_all_backend @@ -53,6 +54,9 @@ def check_device(device): if not ctx.exist: print("Skip because %s is not enabled" % device) return + if in_dtype == "float16" and device == 'cuda' and not have_fp16(ctx.compute_version): + print("Skip because %s does not have fp16 support" % device) + return print("Running on target: %s" % device) with tvm.target.create(device): s = topi.generic.schedule_elemwise(B)