@@ -43,7 +43,7 @@ using namespace llvm;
4343#include " jitlayers.h"
4444#include " julia_assert.h"
4545#include " processor.h"
46-
46+ # include < iostream >
4747#ifdef JL_USE_JITLINK
4848# if JL_LLVM_VERSION >= 140000
4949# include < llvm/ExecutionEngine/Orc/DebuggerSupportPlugin.h>
@@ -1216,6 +1216,46 @@ JuliaOJIT::PipelineT::PipelineT(orc::ObjectLayer &BaseLayer, TargetMachine &TM,
12161216 OptimizeLayer(CompileLayer.getExecutionSession(), CompileLayer,
12171217 llvm::orc::IRTransformLayer::TransformFunction(OptimizerT(TM, optlevel))) {}
12181218
1219+ static void makeCastCall (Module &M, StringRef wrapperName, StringRef calledName, FunctionType *FTwrapper, FunctionType *FTcalled)
1220+ {
1221+ Function *calledFun = M.getFunction (calledName);
1222+ if (!calledFun) {
1223+ calledFun = Function::Create (FTcalled, Function::ExternalLinkage, calledName, M);
1224+ }
1225+ auto wrapperFun = Function::Create (FTwrapper, Function::InternalLinkage, wrapperName, M);
1226+ appendToCompilerUsed (M, {wrapperFun});
1227+
1228+ llvm::IRBuilder<> builder (BasicBlock::Create (M.getContext (), " top" , wrapperFun));
1229+ SmallVector<Value *, 4 > CallArgs;
1230+ if (wrapperFun->arg_size () != calledFun->arg_size ()){
1231+ llvm::errs () << " FATAL ERROR: Can't match wrapper to called function" ;
1232+ abort ();
1233+ }
1234+ for (auto wrapperArg = wrapperFun->arg_begin (), calledArg = calledFun->arg_begin ();
1235+ wrapperArg != wrapperFun->arg_end () && calledArg != calledFun->arg_end (); ++wrapperArg, ++calledArg)
1236+ {
1237+ CallArgs.push_back (builder.CreateBitCast (wrapperArg, calledArg->getType ()));
1238+ }
1239+ auto val = builder.CreateCall (calledFun, CallArgs);
1240+ auto retval = builder.CreateBitCast (val,wrapperFun->getReturnType ());
1241+ builder.CreateRet (retval);
1242+ }
1243+
1244+ static void emitFloat16Wrappers (Module &M)
1245+ {
1246+ auto &ctx = M.getContext ();
1247+ makeCastCall (M, " __gnu_h2f_ieee" , " julia__gnu_h2f_ieee" , FunctionType::get (Type::getFloatTy (ctx), { Type::getHalfTy (ctx) }, false ),
1248+ FunctionType::get (Type::getFloatTy (ctx), { Type::getInt16Ty (ctx) }, false ));
1249+ makeCastCall (M, " __extendhfsf2" , " julia__gnu_h2f_ieee" , FunctionType::get (Type::getFloatTy (ctx), { Type::getHalfTy (ctx) }, false ),
1250+ FunctionType::get (Type::getFloatTy (ctx), { Type::getInt16Ty (ctx) }, false ));
1251+ makeCastCall (M, " __gnu_f2h_ieee" , " julia__gnu_f2h_ieee" , FunctionType::get (Type::getHalfTy (ctx), { Type::getFloatTy (ctx) }, false ),
1252+ FunctionType::get (Type::getInt16Ty (ctx), { Type::getFloatTy (ctx) }, false ));
1253+ makeCastCall (M, " __truncsfhf2" , " julia__gnu_f2h_ieee" , FunctionType::get (Type::getHalfTy (ctx), { Type::getFloatTy (ctx) }, false ),
1254+ FunctionType::get (Type::getInt16Ty (ctx), { Type::getFloatTy (ctx) }, false ));
1255+ makeCastCall (M, " __truncdfhf2" , " julia__truncdfhf2" , FunctionType::get (Type::getHalfTy (ctx), { Type::getDoubleTy (ctx) }, false ),
1256+ FunctionType::get (Type::getInt16Ty (ctx), { Type::getDoubleTy (ctx) }, false ));
1257+
1258+ }
12191259JuliaOJIT::JuliaOJIT ()
12201260 : TM(createTargetMachine()),
12211261 DL(jl_create_datalayout(*TM)),
@@ -1283,49 +1323,16 @@ JuliaOJIT::JuliaOJIT()
12831323 // tells DynamicLibrary to load the program, not a library.
12841324 std::string ErrorStr;
12851325
1286- const char *aliasIR = R"V0G0N"(
1287- declare external float @julia__gnu_h2f_ieee(i16);
1288- declare external i16 @julia__gnu_f2hieee(float);
1289- declare external i16 @julia__truncdfhf2(double);
1290-
1291- define external float @__gnu_h2f_ieee(half %0) unnamed_addr {
1292- top:
1293- %1 = bitcast half %0 to i16
1294- %2 = call float @julia__gnu_h2f_ieee(i16 %1)
1295- ret float %2
1296- }
1297- define external float @__extendhfsf2(half %0) unnamed_addr {
1298- top:
1299- %1 = bitcast half %0 to i16
1300- %2 = call float @julia__gnu_h2f_ieee(i16 %1)
1301- ret float %2
1302- }
1303- define external half @__gnu_f2h_ieee(float %0) unnamed_addr {
1304- top:
1305- %1 = call i16 @julia__gnu_f2hieee(float %0)
1306- %2 = bitcast i16 %1 to half
1307- ret half %2
1308- }
1309- define external half @__truncsfhf2(float %0) unnamed_addr {
1310- top:
1311- %1 = call i16 @julia__gnu_f2hieee(float %0)
1312- %2 = bitcast i16 %1 to half
1313- ret half %2
1314- }
1315- define external half @__truncdfhf2(double %0) unnamed_addr {
1316- top:
1317- %1 = call i16 @julia__truncdfhf2(double %0)
1318- %2 = bitcast i16 %1 to half
1319- ret half %2
1320- }
1321- )V0G0N"" ;
1322-
13231326 auto ctx = ContextPool.acquire ();
1324- SMDiagnostic Err = SMDiagnostic ( );
1325- auto aliasM = parseAssemblyString (aliasIR, Err, *ctx. getContext () );
1327+ std::unique_ptr<Module> aliasM = jl_create_llvm_module ( " F16Wrappers " , *ctx. getContext (), imaging_default (), DL, TM-> getTargetTriple () );
1328+ emitFloat16Wrappers (*aliasM );
13261329 jl_decorate_module (*aliasM);
13271330 shareStrings (*aliasM);
1328- aliasM->dump ();
1331+ std::string Str;
1332+ raw_string_ostream OS (Str);
1333+ OS << *aliasM;
1334+ OS.flush ();
1335+ std::cout<<Str;
13291336 cantFail (OptSelLayer.add (JD, orc::ThreadSafeModule (std::move (aliasM), ctx)));
13301337 releaseContext (std::move (ctx));
13311338 // JD.addToLinkOrder(f16InterposerJD, orc::JITDylibLookupFlags::MatchAllSymbols);
0 commit comments