Skip to content

Commit 281c113

Browse files
committed
Fix vcvtph2ps codegen
1 parent 4ac64fc commit 281c113

File tree

2 files changed

+127
-0
lines changed

2 files changed

+127
-0
lines changed

src/codegen/llvm/codegen_x86_64.cc

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/*!
2+
* Copyright (c) 2017 by Contributors
3+
 * \file codegen_x86_64.cc
4+
* \brief X86-64 specific code generator
5+
*/
6+
#ifdef TVM_LLVM_VERSION
7+
#include "codegen_cpu.h"
8+
9+
namespace tvm {
10+
namespace codegen {
11+
12+
class CodeGenX86_64 final : public CodeGenCPU {
13+
public:
14+
llvm::Value* VisitExpr_(const Cast* op) override;
15+
};
16+
17+
llvm::Value* CodeGenX86_64::VisitExpr_(const Cast* op) {
18+
// LLVM does not automatically generate the correct instruction sequences for
19+
// half -> float conversion (using AVX2/AVX512 variants of vcvtph2ps).
20+
const auto from = op->value.type();
21+
const auto to = op->type;
22+
if (from.is_float() && to.is_float() && from.bits() == 16 && to.bits() == 32) {
23+
CHECK_EQ(from.lanes(), to.lanes());
24+
CHECK_NOTNULL(target_machine_);
25+
26+
const auto has_f16c =
27+
target_machine_->getTargetFeatureString().find("f16c") != llvm::StringRef::npos;
28+
const auto has_avx512f =
29+
target_machine_->getTargetFeatureString().find("avx512f") != llvm::StringRef::npos;
30+
31+
// TODO: implement version generic over lanes.
32+
if (from.lanes() == 8 && (has_f16c || has_avx512f)) {
33+
Array<Expr> vcvt_args;
34+
::llvm::Intrinsic::ID vcvtph2ps_id = ::llvm::Intrinsic::x86_vcvtph2ps_256;
35+
vcvt_args.push_back(ir::UIntImm::make(UInt(32), vcvtph2ps_id));
36+
vcvt_args.push_back(ir::UIntImm::make(UInt(32), 0));
37+
vcvt_args.push_back(
38+
ir::Call::make(Int(16, 8), ir::Call::reinterpret, {op->value}, ir::Call::PureIntrinsic));
39+
return MakeValue(ir::Call::make(to, "llvm_intrin", vcvt_args, ir::Call::PureIntrinsic));
40+
}
41+
42+
// TODO: implement version generic over lanes.
43+
if (from.lanes() == 16 && has_avx512f) {
44+
Array<Expr> vcvt_args;
45+
::llvm::Intrinsic::ID vcvtph2ps_id = ::llvm::Intrinsic::x86_avx512_mask_vcvtph2ps_512;
46+
vcvt_args.push_back(ir::UIntImm::make(UInt(32), vcvtph2ps_id));
47+
vcvt_args.push_back(ir::UIntImm::make(UInt(32), 0));
48+
vcvt_args.push_back(
49+
ir::Call::make(Int(16, 16), ir::Call::reinterpret, {op->value}, ir::Call::PureIntrinsic));
50+
vcvt_args.push_back(ir::Broadcast::make(ir::FloatImm::make(Float(32), 0), 16));
51+
vcvt_args.push_back(ir::IntImm::make(Int(16), -1));
52+
vcvt_args.push_back(ir::IntImm::make(Int(32), 4));
53+
return MakeValue(ir::Call::make(to, "llvm_intrin", vcvt_args, ir::Call::PureIntrinsic));
54+
}
55+
}
56+
57+
return CodeGenCPU::VisitExpr_(op);
58+
}
59+
60+
// Registers a factory under "tvm.codegen.llvm.target_x86-64" that creates a
// CodeGenX86_64 and returns it as an opaque pointer.
// NOTE(review): the object is allocated with new and never freed here;
// presumably the caller takes ownership - confirm against the lookup site.
TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_x86-64")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
    // Convert to the CodeGenLLVM base first so the void* handed out refers to
    // the base-class pointer.
    CodeGenLLVM* generator = new CodeGenX86_64();
    void* opaque_handle = static_cast<void*>(generator);
    *rv = opaque_handle;
  });
65+
66+
} // namespace codegen
67+
} // namespace tvm
68+
#endif // TVM_LLVM_VERSION
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import tvm
2+
import re
3+
import os
4+
import ctypes
5+
6+
def test_fp16_to_fp32_with_f16c():
    """Build an 8-lane float16 -> float32 cast with +f16c and check the assembly.

    The test passes when more than one line of the generated assembly matches
    ``vcvtph2ps.*ymm``.
    """
    target = 'llvm -mcpu=core-avx2 -mattr=+f16c'
    elements = 64
    n = tvm.convert(elements)
    A = tvm.placeholder((n, 8), dtype="float16", name='A')
    B = tvm.compute(A.shape, lambda *i: A(*i).astype("float32"), name='B')
    s = tvm.create_schedule(B.op)
    # Vectorize the inner axis of extent 8 so the cast is done on 8 lanes.
    s[B].vectorize(s[B].op.axis[1])
    f = tvm.build(s, [A, B], target)

    # Verify that vcvtph2ps operating on ymm registers shows up in the assembly.
    assembly = f.get_source('asm').splitlines()
    matches = [l for l in assembly if re.search("vcvtph2ps.*ymm", l)]
    assert len(matches) > 1
21+
22+
def test_fp16_to_fp32_with_avx512():
    """Build a 16-lane float16 -> float32 cast with +avx512f and check the assembly.

    The test passes when more than one line of the generated assembly matches
    ``vcvtph2ps.*zmm``.
    """
    target = 'llvm -mcpu=skylake-avx512 -mattr=+avx512f,+f16c'
    elements = 64
    n = tvm.convert(elements)
    A = tvm.placeholder((n, 16), dtype="float16", name='A')
    B = tvm.compute(A.shape, lambda *i: A(*i).astype("float32"), name='B')
    s = tvm.create_schedule(B.op)
    # Vectorize the inner axis of extent 16 so the cast is done on 16 lanes.
    s[B].vectorize(s[B].op.axis[1])
    f = tvm.build(s, [A, B], target)

    # Verify that vcvtph2ps operating on zmm registers shows up in the assembly.
    assembly = f.get_source('asm').splitlines()
    matches = [l for l in assembly if re.search("vcvtph2ps.*zmm", l)]
    assert len(matches) > 1
37+
38+
def test_fp16_to_fp32_without_f16c():
    """Build an 8-lane float16 -> float32 cast for plain 'llvm' and check the assembly.

    No extra target attributes are requested, and the test passes when no line
    of the generated assembly matches ``vcvtph2ps`` on ymm or zmm registers.
    """
    target = 'llvm'
    elements = 64
    n = tvm.convert(elements)
    A = tvm.placeholder((n, 8), dtype="float16", name='A')
    B = tvm.compute(A.shape, lambda *i: A(*i).astype("float32"), name='B')
    s = tvm.create_schedule(B.op)
    s[B].vectorize(s[B].op.axis[1])
    f = tvm.build(s, [A, B], target)

    # Verify that no vcvtph2ps instruction (ymm or zmm form) was emitted.
    assembly = f.get_source('asm').splitlines()
    matches = [l for l in assembly if re.search("vcvtph2ps.*ymm", l)]
    assert len(matches) == 0
    matches = [l for l in assembly if re.search("vcvtph2ps.*zmm", l)]
    assert len(matches) == 0
55+
56+
if __name__ == "__main__":
    # Script entry point: run each test once, in the order listed.
    for _test_case in (
            test_fp16_to_fp32_with_f16c,
            test_fp16_to_fp32_without_f16c,
            test_fp16_to_fp32_with_avx512,
    ):
        _test_case()

0 commit comments

Comments
 (0)