|
| 1 | +"""Testing if we can generate code in topi style""" |
| 2 | + |
| 3 | +import tvm |
| 4 | +from tvm.contrib import util |
| 5 | +from tvm.contrib.pickle_memoize import memoize |
| 6 | +import topi |
| 7 | +import topi.testing |
| 8 | +import vta |
| 9 | +import vta.testing |
| 10 | +import numpy as np |
| 11 | + |
# Short alias for the Workload type exposed by vta.top.vta_conv2d; the
# ResNet18 layer table in test_vta_conv2d builds its entries with it.
Workload = vta.top.vta_conv2d.Workload
| 13 | + |
@tvm.tag_scope(tag=topi.tag.ELEMWISE)
def my_clip(x, a_min, a_max):
    """Clamp ``x`` to ``[a_min, a_max]`` using two separate compute stages.

    Unlike topi's current clip, the upper bound (stage "clipA") and the
    lower bound (stage "clipB") are applied in two stages instead of one.

    Parameters
    ----------
    x : tensor to clip; its dtype is used for both bound constants.
    a_min : lower bound.
    a_max : upper bound.

    Returns
    -------
    The output tensor of the second ("clipB") stage.
    """
    lower = tvm.const(a_min, x.dtype)
    upper = tvm.const(a_max, x.dtype)
    # Stage 1: cap values from above.
    capped = tvm.compute(
        x.shape, lambda *idx: tvm.min(x(*idx), upper), name="clipA")
    # Stage 2: cap the result of stage 1 from below.
    clipped = tvm.compute(
        capped.shape, lambda *idx: tvm.max(capped(*idx), lower), name="clipB")
    return clipped
| 22 | + |
| 23 | + |
def test_vta_conv2d():
    """Build, run and check a VTA packed conv2d for a table of ResNet18 layers.

    For every workload the test builds conv2d -> right shift by 8 -> bias add
    -> clip to [0, 127] -> cast to int8, runs it on the remote device handed
    in by ``vta.testing.run``, compares against a NumPy reference and prints
    the measured time.
    """
    def run_vta_conv2d(env, remote, key, batch_size, wl, profile=True):
        """Build, run, verify and time one conv2d workload.

        Parameters
        ----------
        env : object providing BLOCK_IN, BLOCK_OUT, inp_dtype, wgt_dtype,
            acc_dtype and target_host.
        remote : session used to upload and load the built module and to
            obtain the ``ext_dev`` context.
        key : label of the workload; not used inside this function.
        batch_size : batch dimension of the input.
        wl : Workload describing the layer (height, width, in_filter,
            out_filter, hkernel, wkernel, hpad, wpad, hstride, wstride).
        profile : not used inside this function; kept for interface
            compatibility.
        """
        # Packed (blocked) layouts: channels are split into blocks of
        # BLOCK_IN / BLOCK_OUT and the block dimension is innermost.
        data_shape = (batch_size, wl.in_filter // env.BLOCK_IN,
                      wl.height, wl.width, env.BLOCK_IN)
        kernel_shape = (wl.out_filter // env.BLOCK_OUT,
                        wl.in_filter // env.BLOCK_IN,
                        wl.hkernel, wl.wkernel,
                        env.BLOCK_OUT, env.BLOCK_IN)
        bias_shape = (wl.out_filter // env.BLOCK_OUT, 1, 1, env.BLOCK_OUT)

        fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1
        fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1
        data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)
        kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
        # Was named "kernel" (copy-paste slip), which made the bias
        # indistinguishable from the weights in printed IR.
        bias = tvm.placeholder(bias_shape, name="bias", dtype=env.acc_dtype)

        res_conv = vta.top.packed_conv2d(
            data, kernel, padding=(wl.hpad, wl.wpad), strides=(wl.hstride, wl.wstride))
        res = topi.right_shift(res_conv, 8)
        res = topi.broadcast_add(res, bias)
        res = my_clip(res, 0, 127)
        res = topi.cast(res, "int8")

        # One op per multiply-accumulate; the batch dimension is included so
        # the reported throughput stays correct for batch_size != 1.
        num_ops = (batch_size * fout_height * fout_width *
                   wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter)

        # Unpacked NCHW shapes used to generate the reference data.
        a_shape = (batch_size, wl.in_filter, wl.height, wl.width)
        w_shape = (wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel)
        stride = (wl.hstride, wl.wstride)
        data_dtype = data.dtype
        acc_dtype = env.acc_dtype
        # The reference conv2d is called with a single padding value.
        assert wl.hpad == wl.wpad
        padding = wl.hpad

        # NOTE: the variables this function closes over are left exactly as
        # they were so that previously memoized results remain addressable
        # (presumably memoize keys on them -- confirm in pickle_memoize).
        @memoize("vta.tests.test_benchmark_topi.conv2d,verify_nhwc")
        def get_ref_data():
            a_np = (np.random.uniform(size=a_shape) * 4).astype(data_dtype)
            w_np = (np.random.uniform(size=w_shape) * 4).astype(data_dtype)
            a_np = np.abs(a_np)
            w_np = np.abs(w_np)
            b_np = topi.testing.conv2d_nchw_python(
                a_np.astype(acc_dtype), w_np.astype(acc_dtype), stride, padding).astype(acc_dtype)
            return a_np, w_np, b_np

        def verify(s, check_correctness):
            """Build schedule ``s``, run it remotely and return the timing result.

            When ``check_correctness`` is true the device output is compared
            against the NumPy reference with ``np.testing.assert_allclose``.
            """
            mod = vta.build(s, [data, kernel, bias, res], "ext_dev",
                            env.target_host, name="conv2d")
            temp = util.tempdir()

            obj_path = temp.relpath("conv2d.o")
            mod.save(obj_path)
            remote.upload(obj_path)
            f = remote.load_module("conv2d.o")
            # verify
            ctx = remote.ext_dev(0)
            # Data in original format
            data_orig, kernel_orig, res_ref = get_ref_data()
            bias_orig = (np.random.uniform(size=(wl.out_filter,)) * 4).astype("int32")
            bias_orig = np.abs(bias_orig)

            # Repack the NCHW / OIHW reference arrays into the blocked layouts
            # declared by the placeholders above.
            data_packed = data_orig.reshape(
                batch_size, wl.in_filter // env.BLOCK_IN, env.BLOCK_IN,
                wl.height, wl.width).transpose((0, 1, 3, 4, 2))
            kernel_packed = kernel_orig.reshape(
                wl.out_filter // env.BLOCK_OUT, env.BLOCK_OUT,
                wl.in_filter // env.BLOCK_IN, env.BLOCK_IN,
                wl.hkernel, wl.wkernel).transpose((0, 2, 4, 5, 1, 3))
            bias_packed = bias_orig.reshape(
                wl.out_filter // env.BLOCK_OUT, 1, 1, env.BLOCK_OUT)
            res_shape = topi.util.get_const_tuple(res.shape)

            res_np = np.zeros(res_shape).astype(res.dtype)
            data_arr = tvm.nd.array(data_packed, ctx)
            kernel_arr = tvm.nd.array(kernel_packed, ctx)
            bias_arr = tvm.nd.array(bias_packed, ctx)
            res_arr = tvm.nd.array(res_np, ctx)
            time_f = f.time_evaluator("conv2d", ctx, number=5)
            cost = time_f(data_arr, kernel_arr, bias_arr, res_arr)
            # Undo the output blocking to get back to NCHW.
            res_unpack = res_arr.asnumpy().transpose(
                (0, 1, 4, 2, 3)).reshape(batch_size, wl.out_filter, fout_height, fout_width)
            if check_correctness:
                # `>>` yields a new array, so the in-place add below does not
                # modify the array returned by get_ref_data().
                res_ref = res_ref >> 8
                res_ref += bias_orig.reshape(wl.out_filter, 1, 1)
                res_ref = np.clip(res_ref, 0, 127).astype("int8")
                np.testing.assert_allclose(res_unpack, res_ref)
            return cost

        def conv_normal(print_ir):
            """Schedule, optionally print the lowered IR, verify and report time."""
            print("----- CONV2D End-to-End Test-------")
            with vta.build_config():
                s = vta.top.schedule_packed_conv2d([res])
                if print_ir:
                    print(vta.lower(s, [data, kernel, bias, res], simple_mode=True))
            cost = verify(s, True)
            gops = (num_ops / cost.mean) / float(10 ** 9)
            print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))

        conv_normal(False)

    def _run(env, remote):
        """Run every ResNet18 workload below with batch size 1."""
        # ResNet18 workloads
        resnet = {
            # Workloads of resnet18 on imagenet
            0: Workload(224, 224, 16, 64, 7, 7, 3, 3, 2, 2),
            1: Workload(56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
            2: Workload(56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
            3: Workload(56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
            4: Workload(56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
            5: Workload(28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
            6: Workload(28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
            7: Workload(28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
            8: Workload(14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
            9: Workload(14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
            10: Workload(14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
            11: Workload(7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
        }

        batch_size = 1
        # sorted() keeps the numeric order regardless of dict ordering.
        for i, wl in sorted(resnet.items()):
            key = "resnet-cfg[%d]" % i
            print("key=%s" % key)
            print(wl)
            run_vta_conv2d(env, remote, key, batch_size, wl)
    vta.testing.run(_run)
| 152 | + |
| 153 | + |
# Allow running this test file directly as a script.
if __name__ == "__main__":
    test_vta_conv2d()
0 commit comments