diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index 38b39b5fc27c..84b631cf3823 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -1057,6 +1057,13 @@ def _has_cpu_feat(features): ) +requires_arm_fp16 = Feature( + "arm_fp16", + "Arm(R) Neon(TM) instructions for FP16", + run_time_check=lambda: _has_cpu_feat("fullfp16"), +) + + requires_aarch64_sve = Feature( "arm_sve", "AArch64 SVE", diff --git a/tests/python/topi/test_topi_conv2d_nhwc.py b/tests/python/topi/test_topi_conv2d_nhwc.py index 6ff844de088f..b5c9518d3419 100644 --- a/tests/python/topi/test_topi_conv2d_nhwc.py +++ b/tests/python/topi/test_topi_conv2d_nhwc.py @@ -53,7 +53,7 @@ topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack, ), ( - "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a", + "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+fullfp16", topi.arm_cpu.compute_conv2d_NHWC_hybrid, topi.arm_cpu.schedule_conv2d_NHWC_hybrid, ), @@ -64,7 +64,7 @@ ), ) -dtype = tvm.testing.parameter("float32") +dtype = tvm.testing.parameter("float16", "float32") batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation = tvm.testing.parameters( # Pad M, N, K @@ -104,14 +104,36 @@ def ref_data(dtype, batch, in_channel, in_size, num_filter, kernel, stride, padd a_shape = (batch, in_height, in_width, in_channel) w_shape = (kernel, kernel, in_channel, num_filter) + np.random.seed(0) a_np = np.random.uniform(size=a_shape).astype(dtype) w_np = np.random.uniform(size=w_shape).astype(dtype) dw_np = tvm.topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1)) - b_np = tvm.topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding) + + # scipy.signal.convolve2d does not support float16 data types, + # and the python fallback would be too slow for general use. + conv_dtype = "float32" if dtype == "float16" else dtype + b_np = tvm.topi.testing.conv2d_nhwc_python( + a_np.astype(conv_dtype), dw_np.astype(conv_dtype), stride, padding + ).astype(dtype) return a_np, w_np, b_np -def test_conv2d_nhwc_gemm_fp32(device, ref_data, dtype, stride, padding, dilation): +def get_tolerance(dtype, w_np, b_np): + if dtype == "float16": + # A summation in float16 with a single accumulator very + # quickly runs into large rounding errors. + # This tolerance is necessary to ensure no false negatives, + # but it may introduce false positives, depending on schedule behaviour. + num_values_summed = w_np.shape[0] * w_np.shape[1] * w_np.shape[2] + next_float_gap_size = np.nextafter(b_np.max(), np.inf, dtype=b_np.dtype) - b_np.max() + tol = {"rtol": 1e-5, "atol": num_values_summed * next_float_gap_size / 2} + else: + tol = {"rtol": 1e-5, "atol": 1e-7} + + return tol + + +def test_conv2d_nhwc_gemm(device, ref_data, dtype, stride, padding, dilation): a_np, w_np, b_np = ref_data A = te.placeholder(a_np.shape, name="A", dtype=dtype) @@ -130,14 +152,21 @@ def test_conv2d_nhwc_gemm_fp32(device, ref_data, dtype, stride, padding, dilatio # Run only on AArch64 devices # Do not run SVE schedules on non-SVE devices - build_only = platform.machine() != "aarch64" or ( - target.features.has_sve and not tvm.testing.requires_aarch64_sve.run_time_check() + build_only = ( + platform.machine() != "aarch64" + or (target.features.has_sve and not tvm.testing.requires_aarch64_sve.run_time_check()) + or ( + dtype == "float16" + and target.features.has_fp16_simd + and not tvm.testing.requires_arm_fp16.run_time_check() + ) ) if build_only: return func(a, w, b) - tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5) + tol = get_tolerance(dtype, w_np, b_np) + tvm.testing.assert_allclose(b.numpy(), b_np, rtol=tol["rtol"], atol=tol["atol"]) def test_conv2d_nhwc_hwio(target, dev, ref_data, dtype, stride, padding, dilation): @@ -155,7 +184,8 @@ def test_conv2d_nhwc_hwio(target, dev, ref_data, dtype, stride, padding, dilatio b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev) func = tvm.build(s, [A, W, B], target) func(a, w, b) - tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5) + tol = get_tolerance(dtype, w_np, b_np) + tvm.testing.assert_allclose(b.numpy(), b_np, rtol=tol["rtol"], atol=tol["atol"]) def test_conv2d_nhwc_ohwi(ref_data, dtype, stride, padding, dilation): @@ -184,7 +214,8 @@ def test_conv2d_nhwc_ohwi(ref_data, dtype, stride, padding, dilation): b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev) func = tvm.build(s, [A, W, B], target) func(a, w, b) - tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5) + tol = get_tolerance(dtype, w_np_hwio, b_np) + tvm.testing.assert_allclose(b.numpy(), b_np, rtol=tol["rtol"], atol=tol["atol"]) if __name__ == "__main__":