@@ -5809,6 +5809,7 @@ static void emit_cfunc_invalidate(
58095809 prepare_call_in (gf_thunk->getParent (), jlapplygeneric_func));
58105810}
58115811
5812+ #include < iostream>
58125813static Function* gen_cfun_wrapper (
58135814 Module *into, jl_codegen_params_t ¶ms,
58145815 const function_sig_t &sig, jl_value_t *ff, const char *aliasname,
@@ -8696,6 +8697,57 @@ static JuliaVariable *julia_const_gv(jl_value_t *val)
86968697 }
86978698 return nullptr ;
86988699}
8700+ // Float16 fun
8701+ static void makeCastCall (Module &M, StringRef wrapperName, StringRef calledName, FunctionType *FTwrapper, FunctionType *FTcalled, bool external)
8702+ {
8703+ Function *calledFun = M.getFunction (calledName);
8704+ if (!calledFun) {
8705+ calledFun = Function::Create (FTcalled, Function::ExternalLinkage, calledName, M);
8706+ }
8707+ auto linkage = external ? Function::ExternalLinkage : Function::InternalLinkage;
8708+ auto wrapperFun = Function::Create (FTwrapper, linkage, wrapperName, M);
8709+ wrapperFun->addFnAttr (Attribute::AlwaysInline);
8710+ llvm::IRBuilder<> builder (BasicBlock::Create (M.getContext (), " top" , wrapperFun));
8711+ SmallVector<Value *, 4 > CallArgs;
8712+ if (wrapperFun->arg_size () != calledFun->arg_size ()){
8713+ llvm::errs () << " FATAL ERROR: Can't match wrapper to called function" ;
8714+ abort ();
8715+ }
8716+ for (auto wrapperArg = wrapperFun->arg_begin (), calledArg = calledFun->arg_begin ();
8717+ wrapperArg != wrapperFun->arg_end () && calledArg != calledFun->arg_end (); ++wrapperArg, ++calledArg)
8718+ {
8719+ CallArgs.push_back (builder.CreateBitCast (wrapperArg, calledArg->getType ()));
8720+ }
8721+ auto val = builder.CreateCall (calledFun, CallArgs);
8722+ auto retval = builder.CreateBitCast (val,wrapperFun->getReturnType ());
8723+ builder.CreateRet (retval);
8724+ }
8725+
8726+ #if JULIA_FLOAT16_ABI == 2
8727+ void emitFloat16Wrappers (Module &M, bool external)
8728+ {
8729+ auto &ctx = M.getContext ();
8730+ makeCastCall (M, " __gnu_h2f_ieee" , " julia__gnu_h2f_ieee" , FunctionType::get (Type::getFloatTy (ctx), { Type::getHalfTy (ctx) }, false ),
8731+ FunctionType::get (Type::getFloatTy (ctx), { Type::getInt16Ty (ctx) }, false ), external);
8732+ makeCastCall (M, " __extendhfsf2" , " julia__gnu_h2f_ieee" , FunctionType::get (Type::getFloatTy (ctx), { Type::getHalfTy (ctx) }, false ),
8733+ FunctionType::get (Type::getFloatTy (ctx), { Type::getInt16Ty (ctx) }, false ), external);
8734+ makeCastCall (M, " __gnu_f2h_ieee" , " julia__gnu_f2h_ieee" , FunctionType::get (Type::getHalfTy (ctx), { Type::getFloatTy (ctx) }, false ),
8735+ FunctionType::get (Type::getInt16Ty (ctx), { Type::getFloatTy (ctx) }, false ), external);
8736+ makeCastCall (M, " __truncsfhf2" , " julia__gnu_f2h_ieee" , FunctionType::get (Type::getHalfTy (ctx), { Type::getFloatTy (ctx) }, false ),
8737+ FunctionType::get (Type::getInt16Ty (ctx), { Type::getFloatTy (ctx) }, false ), external);
8738+ makeCastCall (M, " __truncdfhf2" , " julia__truncdfhf2" , FunctionType::get (Type::getHalfTy (ctx), { Type::getDoubleTy (ctx) }, false ),
8739+ FunctionType::get (Type::getInt16Ty (ctx), { Type::getDoubleTy (ctx) }, false ), external);
8740+ }
8741+
8742+ static void init_f16_funcs (void )
8743+ {
8744+ auto ctx = jl_ExecutionEngine->acquireContext ();
8745+ auto TSM = jl_create_ts_module (" F16Wrappers" , ctx, imaging_default ());
8746+ auto aliasM = TSM.getModuleUnlocked ();
8747+ emitFloat16Wrappers (*aliasM, true );
8748+ jl_ExecutionEngine->addModule (std::move (TSM));
8749+ }
8750+ #endif
86998751
87008752static void init_jit_functions (void )
87018753{
@@ -8935,6 +8987,9 @@ extern "C" JL_DLLEXPORT void jl_init_codegen_impl(void)
89358987 jl_init_llvm ();
89368988 // Now that the execution engine exists, initialize all modules
89378989 init_jit_functions ();
8990+ #if JULIA_FLOAT16_ABI == 2
8991+ init_f16_funcs ();
8992+ #endif
89388993}
89398994
89408995extern " C" JL_DLLEXPORT void jl_teardown_codegen_impl () JL_NOTSAFEPOINT
0 commit comments