|
| 1 | +import tvm |
| 2 | +import re |
| 3 | +import os |
| 4 | +import ctypes |
| 5 | + |
def test_fp16_to_fp32_with_f16c():
    """Check that a vectorized fp16 -> fp32 cast uses F16C on an AVX2 target.

    Builds a float16 -> float32 ``astype`` whose inner axis (extent 8) is
    vectorized, compiles it for ``llvm -mcpu=core-avx2 -mattr=+f16c`` and
    asserts that the generated assembly contains ``vcvtph2ps`` instructions
    that operate on ymm registers.
    """
    target = 'llvm -mcpu=core-avx2 -mattr=+f16c'
    elements = 64
    n = tvm.convert(elements)
    A = tvm.placeholder((n, 8), dtype="float16", name='A')
    B = tvm.compute(A.shape, lambda *i: A(*i).astype("float32"), name='B')
    s = tvm.create_schedule(B.op)
    s[B].vectorize(s[B].op.axis[1])
    f = tvm.build(s, [A, B], target)

    # Verify the conversion was lowered to vcvtph2ps on ymm registers.
    assembly = f.get_source('asm').splitlines()
    matches = [line for line in assembly if re.search(r"vcvtph2ps.*ymm", line)]
    assert len(matches) > 1, (
        "expected more than one 'vcvtph2ps ... ymm' instruction, found %d"
        % len(matches)
    )
| 21 | + |
def test_fp16_to_fp32_with_avx512():
    """Check that a vectorized fp16 -> fp32 cast uses zmm registers on AVX-512.

    Builds a float16 -> float32 ``astype`` whose inner axis (extent 16) is
    vectorized, compiles it for
    ``llvm -mcpu=skylake-avx512 -mattr=+avx512f,+f16c`` and asserts that the
    generated assembly contains ``vcvtph2ps`` instructions that operate on
    zmm registers.
    """
    target = 'llvm -mcpu=skylake-avx512 -mattr=+avx512f,+f16c'
    elements = 64
    n = tvm.convert(elements)
    A = tvm.placeholder((n, 16), dtype="float16", name='A')
    B = tvm.compute(A.shape, lambda *i: A(*i).astype("float32"), name='B')
    s = tvm.create_schedule(B.op)
    s[B].vectorize(s[B].op.axis[1])
    f = tvm.build(s, [A, B], target)

    # Verify the conversion was lowered to vcvtph2ps on zmm registers.
    assembly = f.get_source('asm').splitlines()
    matches = [line for line in assembly if re.search(r"vcvtph2ps.*zmm", line)]
    assert len(matches) > 1, (
        "expected more than one 'vcvtph2ps ... zmm' instruction, found %d"
        % len(matches)
    )
| 37 | + |
def test_fp16_to_fp32_without_f16c():
    """Check that no F16C conversion is emitted for the plain ``llvm`` target.

    Builds the same vectorized float16 -> float32 ``astype`` (inner axis of
    extent 8) for the bare ``llvm`` target string, which requests neither
    ``+f16c`` nor an AVX-capable ``-mcpu``, and asserts that the generated
    assembly contains no ``vcvtph2ps`` instruction on ymm or zmm registers.
    """
    target = 'llvm'
    elements = 64
    n = tvm.convert(elements)
    A = tvm.placeholder((n, 8), dtype="float16", name='A')
    B = tvm.compute(A.shape, lambda *i: A(*i).astype("float32"), name='B')
    s = tvm.create_schedule(B.op)
    s[B].vectorize(s[B].op.axis[1])
    f = tvm.build(s, [A, B], target)

    # Verify that vcvtph2ps does NOT appear on either ymm or zmm registers.
    assembly = f.get_source('asm').splitlines()
    matches = [line for line in assembly if re.search(r"vcvtph2ps.*ymm", line)]
    assert len(matches) == 0, (
        "unexpected 'vcvtph2ps ... ymm' instructions: %r" % matches
    )
    matches = [line for line in assembly if re.search(r"vcvtph2ps.*zmm", line)]
    assert len(matches) == 0, (
        "unexpected 'vcvtph2ps ... zmm' instructions: %r" % matches
    )
| 55 | + |
if __name__ == "__main__":
    # Run each test in turn when this file is executed directly as a script.
    for run_test in (
        test_fp16_to_fp32_with_f16c,
        test_fp16_to_fp32_without_f16c,
        test_fp16_to_fp32_with_avx512,
    ):
        run_test()
0 commit comments