Skip to content

Commit 959902f

Browse files
vchuravy and gbaraldi authored
Support both Float16 ABIs depending on LLVM and platform (#49527)
There are two Float16 ABIs in the wild: one for platforms that have a defined register for Float16 values, and the original one where we used i16. LLVM 15 follows GCC and uses the new ABI on x86/ARM but not PPC. Co-authored-by: Gabriel Baraldi <[email protected]>
1 parent 09a0f34 commit 959902f

File tree

4 files changed

+77
-2
lines changed

4 files changed

+77
-2
lines changed

src/aotcompile.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,7 @@ static void reportWriterError(const ErrorInfoBase &E)
494494
jl_safe_printf("ERROR: failed to emit output file %s\n", err.c_str());
495495
}
496496

497+
#if JULIA_FLOAT16_ABI == 1
497498
static void injectCRTAlias(Module &M, StringRef name, StringRef alias, FunctionType *FT)
498499
{
499500
Function *target = M.getFunction(alias);
@@ -510,7 +511,7 @@ static void injectCRTAlias(Module &M, StringRef name, StringRef alias, FunctionT
510511
auto val = builder.CreateCall(target, CallArgs);
511512
builder.CreateRet(val);
512513
}
513-
514+
#endif
514515
void multiversioning_preannotate(Module &M);
515516

516517
// See src/processor.h for documentation about this table. Corresponds to jl_image_shard_t.
@@ -943,6 +944,8 @@ struct ShardTimers {
943944
}
944945
};
945946

947+
// Emits the Float16 conversion wrappers (__gnu_h2f_ieee, __extendhfsf2, ...) into M.
// Defined in src/codegen.cpp, and only when JULIA_FLOAT16_ABI == 2.
// `external` selects ExternalLinkage (true) or InternalLinkage (false) for the wrappers.
void emitFloat16Wrappers(Module &M, bool external);
948+
946949
// Perform the actual optimization and emission of the output files
947950
static void add_output_impl(Module &M, TargetMachine &SourceTM, std::string *outputs, const std::string *names,
948951
NewArchiveMember *unopt, NewArchiveMember *opt, NewArchiveMember *obj, NewArchiveMember *asm_,
@@ -1003,7 +1006,9 @@ static void add_output_impl(Module &M, TargetMachine &SourceTM, std::string *out
10031006
}
10041007
}
10051008
// no need to inject aliases if we have no functions
1009+
10061010
if (inject_aliases) {
1011+
#if JULIA_FLOAT16_ABI == 1
10071012
// We would like to emit an alias or an weakref alias to redirect these symbols
10081013
// but LLVM doesn't let us emit a GlobalAlias to a declaration...
10091014
// So for now we inject a definition of these functions that calls our runtime
@@ -1018,8 +1023,10 @@ static void add_output_impl(Module &M, TargetMachine &SourceTM, std::string *out
10181023
FunctionType::get(Type::getHalfTy(M.getContext()), { Type::getFloatTy(M.getContext()) }, false));
10191024
injectCRTAlias(M, "__truncdfhf2", "julia__truncdfhf2",
10201025
FunctionType::get(Type::getHalfTy(M.getContext()), { Type::getDoubleTy(M.getContext()) }, false));
1026+
#else
1027+
emitFloat16Wrappers(M, false);
1028+
#endif
10211029
}
1022-
10231030
timers.optimize.stopTimer();
10241031

10251032
if (opt) {

src/codegen.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5818,6 +5818,7 @@ static void emit_cfunc_invalidate(
58185818
prepare_call_in(gf_thunk->getParent(), jlapplygeneric_func));
58195819
}
58205820

5821+
#include <iostream>
58215822
static Function* gen_cfun_wrapper(
58225823
Module *into, jl_codegen_params_t &params,
58235824
const function_sig_t &sig, jl_value_t *ff, const char *aliasname,
@@ -8704,6 +8705,58 @@ static JuliaVariable *julia_const_gv(jl_value_t *val)
87048705
return nullptr;
87058706
}
87068707

8708+
// Handle FLOAT16 ABI v2
8709+
#if JULIA_FLOAT16_ABI == 2
8710+
// Emit into `M` a wrapper function named `wrapperName` with signature `FTwrapper`
// that forwards to `calledName`, bitcasting every argument to the callee's
// parameter type and bitcasting the result back to the wrapper's return type.
//
// - If `calledName` is not yet present in `M`, it is declared with signature
//   `FTcalled` and external linkage.
// - `external` selects the wrapper's linkage: ExternalLinkage when true,
//   InternalLinkage otherwise.
// - The wrapper is marked AlwaysInline.
//
// Aborts the process if the wrapper and the callee disagree on arity.
static void makeCastCall(Module &M, StringRef wrapperName, StringRef calledName, FunctionType *FTwrapper, FunctionType *FTcalled, bool external)
{
    Function *calledFun = M.getFunction(calledName);
    if (!calledFun) {
        calledFun = Function::Create(FTcalled, Function::ExternalLinkage, calledName, M);
    }
    auto linkage = external ? Function::ExternalLinkage : Function::InternalLinkage;
    auto wrapperFun = Function::Create(FTwrapper, linkage, wrapperName, M);
    wrapperFun->addFnAttr(Attribute::AlwaysInline);
    if (wrapperFun->arg_size() != calledFun->arg_size()) {
        // Name the offending pair and terminate the line so the diagnostic
        // is readable before abort() tears the process down.
        llvm::errs() << "FATAL ERROR: Can't match wrapper to called function ("
                     << wrapperName << " -> " << calledName << ")\n";
        abort();
    }
    llvm::IRBuilder<> builder(BasicBlock::Create(M.getContext(), "top", wrapperFun));
    SmallVector<Value *, 4> CallArgs;
    // Arity was verified above, so walking the wrapper's arguments is enough
    // to keep both argument lists in lock step.
    auto calledArg = calledFun->arg_begin();
    for (auto &wrapperArg : wrapperFun->args()) {
        CallArgs.push_back(builder.CreateBitCast(&wrapperArg, calledArg->getType()));
        ++calledArg;
    }
    auto val = builder.CreateCall(calledFun, CallArgs);
    auto retval = builder.CreateBitCast(val, wrapperFun->getReturnType());
    builder.CreateRet(retval);
}
8734+
8735+
// Emit the five Float16 conversion entry points into `M`. Each one is a
// wrapper using `half` in its signature that forwards to the corresponding
// `julia_*` function, whose signature uses `i16` in place of `half`.
// `external` is passed through to makeCastCall to choose the wrappers' linkage.
void emitFloat16Wrappers(Module &M, bool external)
{
    auto &ctx = M.getContext();
    Type *T_half = Type::getHalfTy(ctx);
    Type *T_int16 = Type::getInt16Ty(ctx);
    Type *T_float = Type::getFloatTy(ctx);
    Type *T_double = Type::getDoubleTy(ctx);

    // half -> float: the wrapper takes half, the callee takes i16
    FunctionType *extendWrapperTy = FunctionType::get(T_float, { T_half }, false);
    FunctionType *extendCalledTy = FunctionType::get(T_float, { T_int16 }, false);
    // float -> half: the wrapper returns half, the callee returns i16
    FunctionType *truncFloatWrapperTy = FunctionType::get(T_half, { T_float }, false);
    FunctionType *truncFloatCalledTy = FunctionType::get(T_int16, { T_float }, false);
    // double -> half: the wrapper returns half, the callee returns i16
    FunctionType *truncDoubleWrapperTy = FunctionType::get(T_half, { T_double }, false);
    FunctionType *truncDoubleCalledTy = FunctionType::get(T_int16, { T_double }, false);

    makeCastCall(M, "__gnu_h2f_ieee", "julia__gnu_h2f_ieee", extendWrapperTy, extendCalledTy, external);
    makeCastCall(M, "__extendhfsf2", "julia__gnu_h2f_ieee", extendWrapperTy, extendCalledTy, external);
    makeCastCall(M, "__gnu_f2h_ieee", "julia__gnu_f2h_ieee", truncFloatWrapperTy, truncFloatCalledTy, external);
    makeCastCall(M, "__truncsfhf2", "julia__gnu_f2h_ieee", truncFloatWrapperTy, truncFloatCalledTy, external);
    makeCastCall(M, "__truncdfhf2", "julia__truncdfhf2", truncDoubleWrapperTy, truncDoubleCalledTy, external);
}
8749+
8750+
static void init_f16_funcs(void)
8751+
{
8752+
auto ctx = jl_ExecutionEngine->acquireContext();
8753+
auto TSM = jl_create_ts_module("F16Wrappers", ctx, imaging_default());
8754+
auto aliasM = TSM.getModuleUnlocked();
8755+
emitFloat16Wrappers(*aliasM, true);
8756+
jl_ExecutionEngine->addModule(std::move(TSM));
8757+
}
8758+
#endif
8759+
87078760
static void init_jit_functions(void)
87088761
{
87098762
add_named_global(jlstack_chk_guard_var, &__stack_chk_guard);
@@ -8942,6 +8995,9 @@ extern "C" JL_DLLEXPORT void jl_init_codegen_impl(void)
89428995
jl_init_llvm();
89438996
// Now that the execution engine exists, initialize all modules
89448997
init_jit_functions();
8998+
#if JULIA_FLOAT16_ABI == 2
8999+
init_f16_funcs();
9000+
#endif
89459001
}
89469002

89479003
extern "C" JL_DLLEXPORT void jl_teardown_codegen_impl() JL_NOTSAFEPOINT

src/jitlayers.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1383,6 +1383,7 @@ JuliaOJIT::JuliaOJIT()
13831383

13841384
JD.addToLinkOrder(GlobalJD, orc::JITDylibLookupFlags::MatchExportedSymbolsOnly);
13851385

1386+
#if JULIA_FLOAT16_ABI == 1
13861387
orc::SymbolAliasMap jl_crt = {
13871388
{ mangle("__gnu_h2f_ieee"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } },
13881389
{ mangle("__extendhfsf2"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } },
@@ -1391,6 +1392,7 @@ JuliaOJIT::JuliaOJIT()
13911392
{ mangle("__truncdfhf2"), { mangle("julia__truncdfhf2"), JITSymbolFlags::Exported } }
13921393
};
13931394
cantFail(GlobalJD.define(orc::symbolAliases(jl_crt)));
1395+
#endif
13941396

13951397
#ifdef MSAN_EMUTLS_WORKAROUND
13961398
orc::SymbolMap msan_crt;

src/llvm-version.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <llvm/Config/llvm-config.h>
44
#include "julia_assert.h"
5+
#include "platform.h"
56

67
// The LLVM version used, JL_LLVM_VERSION, is represented as a 5-digit integer
78
// of the form ABBCC, where A is the major version, B is minor, and C is patch.
@@ -17,6 +18,15 @@
1718
#define JL_LLVM_OPAQUE_POINTERS 1
1819
#endif
1920

21+
// Pre GCC 12 libgcc defined the ABI for Float16->Float32
22+
// to take an i16. GCC 12 silently changed the ABI to now pass
23+
// Float16 in Float32 registers.
24+
#if JL_LLVM_VERSION < 150000 || defined(_CPU_PPC64_) || defined(_CPU_PPC_)
25+
#define JULIA_FLOAT16_ABI 1
26+
#else
27+
#define JULIA_FLOAT16_ABI 2
28+
#endif
29+
2030
#ifdef __cplusplus
2131
#if defined(__GNUC__) && (__GNUC__ >= 9)
2232
// Added in GCC 9, this warning is annoying

0 commit comments

Comments
 (0)