Skip to content

Commit ed77914

Browse files
vchuravygbaraldi
andcommitted
Support both Float16 ABIs depending on LLVM and platform
There are two Float16 ABIs in the wild: one for platforms that have a defined register for Float16, and the original one where we used i16. LLVM 15 follows GCC and uses the new ABI on x86/ARM but not PPC. Co-authored-by: Gabriel Baraldi <[email protected]>
1 parent b12ddca commit ed77914

File tree

4 files changed

+76
-2
lines changed

4 files changed

+76
-2
lines changed

src/aotcompile.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,7 @@ static void reportWriterError(const ErrorInfoBase &E)
493493
jl_safe_printf("ERROR: failed to emit output file %s\n", err.c_str());
494494
}
495495

496+
#if JULIA_FLOAT16_ABI == 1
496497
static void injectCRTAlias(Module &M, StringRef name, StringRef alias, FunctionType *FT)
497498
{
498499
Function *target = M.getFunction(alias);
@@ -509,7 +510,7 @@ static void injectCRTAlias(Module &M, StringRef name, StringRef alias, FunctionT
509510
auto val = builder.CreateCall(target, CallArgs);
510511
builder.CreateRet(val);
511512
}
512-
513+
#endif
513514
void multiversioning_preannotate(Module &M);
514515

515516
// See src/processor.h for documentation about this table. Corresponds to jl_image_shard_t.
@@ -942,6 +943,8 @@ struct ShardTimers {
942943
}
943944
};
944945

946+
void emitFloat16Wrappers(Module &M, bool external);
947+
945948
// Perform the actual optimization and emission of the output files
946949
static void add_output_impl(Module &M, TargetMachine &SourceTM, std::string *outputs, const std::string *names,
947950
NewArchiveMember *unopt, NewArchiveMember *opt, NewArchiveMember *obj, NewArchiveMember *asm_,
@@ -1002,7 +1005,9 @@ static void add_output_impl(Module &M, TargetMachine &SourceTM, std::string *out
10021005
}
10031006
}
10041007
// no need to inject aliases if we have no functions
1008+
10051009
if (inject_aliases) {
1010+
#if JULIA_FLOAT16_ABI == 1
10061011
// We would like to emit an alias or an weakref alias to redirect these symbols
10071012
// but LLVM doesn't let us emit a GlobalAlias to a declaration...
10081013
// So for now we inject a definition of these functions that calls our runtime
@@ -1017,8 +1022,10 @@ static void add_output_impl(Module &M, TargetMachine &SourceTM, std::string *out
10171022
FunctionType::get(Type::getHalfTy(M.getContext()), { Type::getFloatTy(M.getContext()) }, false));
10181023
injectCRTAlias(M, "__truncdfhf2", "julia__truncdfhf2",
10191024
FunctionType::get(Type::getHalfTy(M.getContext()), { Type::getDoubleTy(M.getContext()) }, false));
1025+
#else
1026+
emitFloat16Wrappers(M, false);
1027+
#endif
10201028
}
1021-
10221029
timers.optimize.stopTimer();
10231030

10241031
if (opt) {

src/codegen.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5809,6 +5809,7 @@ static void emit_cfunc_invalidate(
58095809
prepare_call_in(gf_thunk->getParent(), jlapplygeneric_func));
58105810
}
58115811

5812+
#include <iostream>
58125813
static Function* gen_cfun_wrapper(
58135814
Module *into, jl_codegen_params_t &params,
58145815
const function_sig_t &sig, jl_value_t *ff, const char *aliasname,
@@ -8696,6 +8697,57 @@ static JuliaVariable *julia_const_gv(jl_value_t *val)
86968697
}
86978698
return nullptr;
86988699
}
8700+
// Float16 ABI support: wrappers that bitcast between the `half`-typed and `i16`-typed conversion routines
8701+
// Emit into `M` a forwarding function named `wrapperName` with signature
// `FTwrapper` that calls `calledName` (signature `FTcalled`).
//
// Every wrapper argument is bitcast to the type of the corresponding callee
// parameter, and the callee's result is bitcast to the wrapper's return type.
//
//   M            module that receives the wrapper (and the callee declaration, if needed)
//   wrapperName  name of the function that is created
//   calledName   name of the function being forwarded to; looked up in `M` first
//                and declared with external linkage only when it is absent
//   FTwrapper    function type of the wrapper
//   FTcalled     function type used when the callee has to be declared
//   external     true => wrapper gets external linkage, false => internal linkage
//
// The wrapper is marked always-inline. The process is aborted when the wrapper
// and the callee do not take the same number of arguments.
static void makeCastCall(Module &M, StringRef wrapperName, StringRef calledName, FunctionType *FTwrapper, FunctionType *FTcalled, bool external)
{
    // Reuse the callee if the module already knows it, otherwise declare it.
    Function *callee = M.getFunction(calledName);
    if (callee == nullptr)
        callee = Function::Create(FTcalled, Function::ExternalLinkage, calledName, M);

    Function *wrapper = Function::Create(
        FTwrapper,
        external ? Function::ExternalLinkage : Function::InternalLinkage,
        wrapperName, M);
    wrapper->addFnAttr(Attribute::AlwaysInline);

    BasicBlock *entry = BasicBlock::Create(M.getContext(), "top", wrapper);
    llvm::IRBuilder<> builder(entry);

    // The arguments are forwarded one-to-one, so the arities have to agree.
    if (wrapper->arg_size() != callee->arg_size()) {
        llvm::errs() << "FATAL ERROR: Can't match wrapper to called function";
        abort();
    }

    // Walk both argument lists in lock-step; the arity check above guarantees
    // the callee iterator stays in range.
    SmallVector<Value *, 4> forwarded;
    auto calleeArg = callee->arg_begin();
    for (auto &wrapperArg : wrapper->args()) {
        forwarded.push_back(builder.CreateBitCast(&wrapperArg, calleeArg->getType()));
        ++calleeArg;
    }

    Value *result = builder.CreateCall(callee, forwarded);
    builder.CreateRet(builder.CreateBitCast(result, wrapper->getReturnType()));
}
8725+
8726+
#if JULIA_FLOAT16_ABI == 2
8727+
// Add the Float16 conversion entry points to `M` as wrappers around the
// corresponding `julia_*` routines.
//
// Each wrapper uses `half` for its Float16 operand/result while the routine it
// forwards to is declared with `i16` in the same position; makeCastCall()
// inserts the bitcasts between the two.
//
//   M         module that receives the wrappers
//   external  forwarded to makeCastCall(): true => external linkage for the
//             wrappers, false => internal linkage
void emitFloat16Wrappers(Module &M, bool external)
{
    auto &ctx = M.getContext();

    Type *halfTy = Type::getHalfTy(ctx);
    Type *i16Ty = Type::getInt16Ty(ctx);
    Type *floatTy = Type::getFloatTy(ctx);
    Type *doubleTy = Type::getDoubleTy(ctx);

    // Wrapper-side signatures (Float16 expressed as `half`).
    FunctionType *halfToFloat = FunctionType::get(floatTy, { halfTy }, false);
    FunctionType *floatToHalf = FunctionType::get(halfTy, { floatTy }, false);
    FunctionType *doubleToHalf = FunctionType::get(halfTy, { doubleTy }, false);

    // Callee-side signatures (Float16 expressed as `i16`).
    FunctionType *i16ToFloat = FunctionType::get(floatTy, { i16Ty }, false);
    FunctionType *floatToI16 = FunctionType::get(i16Ty, { floatTy }, false);
    FunctionType *doubleToI16 = FunctionType::get(i16Ty, { doubleTy }, false);

    struct WrapperSpec {
        const char *wrapper;
        const char *callee;
        FunctionType *wrapperTy;
        FunctionType *calleeTy;
    };
    const WrapperSpec specs[] = {
        { "__gnu_h2f_ieee", "julia__gnu_h2f_ieee", halfToFloat,  i16ToFloat  },
        { "__extendhfsf2",  "julia__gnu_h2f_ieee", halfToFloat,  i16ToFloat  },
        { "__gnu_f2h_ieee", "julia__gnu_f2h_ieee", floatToHalf,  floatToI16  },
        { "__truncsfhf2",   "julia__gnu_f2h_ieee", floatToHalf,  floatToI16  },
        { "__truncdfhf2",   "julia__truncdfhf2",   doubleToHalf, doubleToI16 },
    };

    for (const WrapperSpec &spec : specs)
        makeCastCall(M, spec.wrapper, spec.callee, spec.wrapperTy, spec.calleeTy, external);
}
8741+
8742+
static void init_f16_funcs(void)
8743+
{
8744+
auto ctx = jl_ExecutionEngine->acquireContext();
8745+
auto TSM = jl_create_ts_module("F16Wrappers", ctx, imaging_default());
8746+
auto aliasM = TSM.getModuleUnlocked();
8747+
emitFloat16Wrappers(*aliasM, true);
8748+
jl_ExecutionEngine->addModule(std::move(TSM));
8749+
}
8750+
#endif
86998751

87008752
static void init_jit_functions(void)
87018753
{
@@ -8935,6 +8987,9 @@ extern "C" JL_DLLEXPORT void jl_init_codegen_impl(void)
89358987
jl_init_llvm();
89368988
// Now that the execution engine exists, initialize all modules
89378989
init_jit_functions();
8990+
#if JULIA_FLOAT16_ABI == 2
8991+
init_f16_funcs();
8992+
#endif
89388993
}
89398994

89408995
extern "C" JL_DLLEXPORT void jl_teardown_codegen_impl() JL_NOTSAFEPOINT

src/jitlayers.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1383,6 +1383,7 @@ JuliaOJIT::JuliaOJIT()
13831383

13841384
JD.addToLinkOrder(GlobalJD, orc::JITDylibLookupFlags::MatchExportedSymbolsOnly);
13851385

1386+
#if JULIA_FLOAT16_ABI == 1
13861387
orc::SymbolAliasMap jl_crt = {
13871388
{ mangle("__gnu_h2f_ieee"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } },
13881389
{ mangle("__extendhfsf2"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } },
@@ -1391,6 +1392,7 @@ JuliaOJIT::JuliaOJIT()
13911392
{ mangle("__truncdfhf2"), { mangle("julia__truncdfhf2"), JITSymbolFlags::Exported } }
13921393
};
13931394
cantFail(GlobalJD.define(orc::symbolAliases(jl_crt)));
1395+
#endif
13941396

13951397
#ifdef MSAN_EMUTLS_WORKAROUND
13961398
orc::SymbolMap msan_crt;

src/llvm-version.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <llvm/Config/llvm-config.h>
44
#include "julia_assert.h"
5+
#include "platform.h"
56

67
// The LLVM version used, JL_LLVM_VERSION, is represented as a 5-digit integer
78
// of the form ABBCC, where A is the major version, B is minor, and C is patch.
@@ -17,6 +18,15 @@
1718
#define JL_LLVM_OPAQUE_POINTERS 1
1819
#endif
1920

21+
// Before GCC 12, libgcc defined the ABI for Float16->Float32
22+
// to take an i16. GCC 12 silently changed the ABI to now pass
23+
// Float16 in Float32 registers.
24+
// JULIA_FLOAT16_ABI selects the Float16 calling convention to target:
//   1 - used with LLVM older than 15 and on every PowerPC target
//   2 - used everywhere else
#if JL_LLVM_VERSION >= 150000 && !defined(_CPU_PPC64_) && !defined(_CPU_PPC_)
#define JULIA_FLOAT16_ABI 2
#else
#define JULIA_FLOAT16_ABI 1
#endif
29+
2030
#ifdef __cplusplus
2131
#if defined(__GNUC__) && (__GNUC__ >= 9)
2232
// Added in GCC 9, this warning is annoying

0 commit comments

Comments
 (0)