@@ -5818,6 +5818,7 @@ static void emit_cfunc_invalidate(
58185818 prepare_call_in (gf_thunk->getParent (), jlapplygeneric_func));
58195819}
58205820
5821+ #include < iostream>
58215822static Function* gen_cfun_wrapper (
58225823 Module *into, jl_codegen_params_t ¶ms,
58235824 const function_sig_t &sig, jl_value_t *ff, const char *aliasname,
@@ -8704,6 +8705,58 @@ static JuliaVariable *julia_const_gv(jl_value_t *val)
87048705 return nullptr ;
87058706}
87068707
8708+ // Handle FLOAT16 ABI v2
8709+ #if JULIA_FLOAT16_ABI == 2
8710+ static void makeCastCall (Module &M, StringRef wrapperName, StringRef calledName, FunctionType *FTwrapper, FunctionType *FTcalled, bool external)
8711+ {
8712+ Function *calledFun = M.getFunction (calledName);
8713+ if (!calledFun) {
8714+ calledFun = Function::Create (FTcalled, Function::ExternalLinkage, calledName, M);
8715+ }
8716+ auto linkage = external ? Function::ExternalLinkage : Function::InternalLinkage;
8717+ auto wrapperFun = Function::Create (FTwrapper, linkage, wrapperName, M);
8718+ wrapperFun->addFnAttr (Attribute::AlwaysInline);
8719+ llvm::IRBuilder<> builder (BasicBlock::Create (M.getContext (), " top" , wrapperFun));
8720+ SmallVector<Value *, 4 > CallArgs;
8721+ if (wrapperFun->arg_size () != calledFun->arg_size ()){
8722+ llvm::errs () << " FATAL ERROR: Can't match wrapper to called function" ;
8723+ abort ();
8724+ }
8725+ for (auto wrapperArg = wrapperFun->arg_begin (), calledArg = calledFun->arg_begin ();
8726+ wrapperArg != wrapperFun->arg_end () && calledArg != calledFun->arg_end (); ++wrapperArg, ++calledArg)
8727+ {
8728+ CallArgs.push_back (builder.CreateBitCast (wrapperArg, calledArg->getType ()));
8729+ }
8730+ auto val = builder.CreateCall (calledFun, CallArgs);
8731+ auto retval = builder.CreateBitCast (val,wrapperFun->getReturnType ());
8732+ builder.CreateRet (retval);
8733+ }
8734+
8735+ void emitFloat16Wrappers (Module &M, bool external)
8736+ {
8737+ auto &ctx = M.getContext ();
8738+ makeCastCall (M, " __gnu_h2f_ieee" , " julia__gnu_h2f_ieee" , FunctionType::get (Type::getFloatTy (ctx), { Type::getHalfTy (ctx) }, false ),
8739+ FunctionType::get (Type::getFloatTy (ctx), { Type::getInt16Ty (ctx) }, false ), external);
8740+ makeCastCall (M, " __extendhfsf2" , " julia__gnu_h2f_ieee" , FunctionType::get (Type::getFloatTy (ctx), { Type::getHalfTy (ctx) }, false ),
8741+ FunctionType::get (Type::getFloatTy (ctx), { Type::getInt16Ty (ctx) }, false ), external);
8742+ makeCastCall (M, " __gnu_f2h_ieee" , " julia__gnu_f2h_ieee" , FunctionType::get (Type::getHalfTy (ctx), { Type::getFloatTy (ctx) }, false ),
8743+ FunctionType::get (Type::getInt16Ty (ctx), { Type::getFloatTy (ctx) }, false ), external);
8744+ makeCastCall (M, " __truncsfhf2" , " julia__gnu_f2h_ieee" , FunctionType::get (Type::getHalfTy (ctx), { Type::getFloatTy (ctx) }, false ),
8745+ FunctionType::get (Type::getInt16Ty (ctx), { Type::getFloatTy (ctx) }, false ), external);
8746+ makeCastCall (M, " __truncdfhf2" , " julia__truncdfhf2" , FunctionType::get (Type::getHalfTy (ctx), { Type::getDoubleTy (ctx) }, false ),
8747+ FunctionType::get (Type::getInt16Ty (ctx), { Type::getDoubleTy (ctx) }, false ), external);
8748+ }
8749+
8750+ static void init_f16_funcs (void )
8751+ {
8752+ auto ctx = jl_ExecutionEngine->acquireContext ();
8753+ auto TSM = jl_create_ts_module (" F16Wrappers" , ctx, imaging_default ());
8754+ auto aliasM = TSM.getModuleUnlocked ();
8755+ emitFloat16Wrappers (*aliasM, true );
8756+ jl_ExecutionEngine->addModule (std::move (TSM));
8757+ }
8758+ #endif
8759+
87078760static void init_jit_functions (void )
87088761{
87098762 add_named_global (jlstack_chk_guard_var, &__stack_chk_guard);
@@ -8942,6 +8995,9 @@ extern "C" JL_DLLEXPORT void jl_init_codegen_impl(void)
89428995 jl_init_llvm ();
89438996 // Now that the execution engine exists, initialize all modules
89448997 init_jit_functions ();
8998+ #if JULIA_FLOAT16_ABI == 2
8999+ init_f16_funcs ();
9000+ #endif
89459001}
89469002
89479003extern " C" JL_DLLEXPORT void jl_teardown_codegen_impl () JL_NOTSAFEPOINT
0 commit comments