Skip to content

Commit 4c99c3a

Browse files
committed
Friends don't let freinds break ABIs
1 parent 6b00d83 commit 4c99c3a

File tree

1 file changed

+48
-41
lines changed

1 file changed

+48
-41
lines changed

src/jitlayers.cpp

Lines changed: 48 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ using namespace llvm;
4343
#include "jitlayers.h"
4444
#include "julia_assert.h"
4545
#include "processor.h"
46-
46+
#include <iostream>
4747
#ifdef JL_USE_JITLINK
4848
# if JL_LLVM_VERSION >= 140000
4949
# include <llvm/ExecutionEngine/Orc/DebuggerSupportPlugin.h>
@@ -1216,6 +1216,46 @@ JuliaOJIT::PipelineT::PipelineT(orc::ObjectLayer &BaseLayer, TargetMachine &TM,
12161216
OptimizeLayer(CompileLayer.getExecutionSession(), CompileLayer,
12171217
llvm::orc::IRTransformLayer::TransformFunction(OptimizerT(TM, optlevel))) {}
12181218

1219+
static void makeCastCall(Module &M, StringRef wrapperName, StringRef calledName, FunctionType *FTwrapper, FunctionType *FTcalled)
1220+
{
1221+
Function *calledFun = M.getFunction(calledName);
1222+
if (!calledFun) {
1223+
calledFun = Function::Create(FTcalled, Function::ExternalLinkage, calledName, M);
1224+
}
1225+
auto wrapperFun = Function::Create(FTwrapper, Function::InternalLinkage, wrapperName, M);
1226+
appendToCompilerUsed(M, {wrapperFun});
1227+
1228+
llvm::IRBuilder<> builder(BasicBlock::Create(M.getContext(), "top", wrapperFun));
1229+
SmallVector<Value *, 4> CallArgs;
1230+
if (wrapperFun->arg_size() != calledFun->arg_size()){
1231+
llvm::errs() << "FATAL ERROR: Can't match wrapper to called function";
1232+
abort();
1233+
}
1234+
for (auto wrapperArg = wrapperFun->arg_begin(), calledArg = calledFun->arg_begin();
1235+
wrapperArg != wrapperFun->arg_end() && calledArg != calledFun->arg_end(); ++wrapperArg, ++calledArg)
1236+
{
1237+
CallArgs.push_back(builder.CreateBitCast(wrapperArg, calledArg->getType()));
1238+
}
1239+
auto val = builder.CreateCall(calledFun, CallArgs);
1240+
auto retval = builder.CreateBitCast(val,wrapperFun->getReturnType());
1241+
builder.CreateRet(retval);
1242+
}
1243+
1244+
static void emitFloat16Wrappers(Module &M)
1245+
{
1246+
auto &ctx = M.getContext();
1247+
makeCastCall(M, "__gnu_h2f_ieee", "julia__gnu_h2f_ieee", FunctionType::get(Type::getFloatTy(ctx), { Type::getHalfTy(ctx) }, false),
1248+
FunctionType::get(Type::getFloatTy(ctx), { Type::getInt16Ty(ctx) }, false));
1249+
makeCastCall(M, "__extendhfsf2", "julia__gnu_h2f_ieee", FunctionType::get(Type::getFloatTy(ctx), { Type::getHalfTy(ctx) }, false),
1250+
FunctionType::get(Type::getFloatTy(ctx), { Type::getInt16Ty(ctx) }, false));
1251+
makeCastCall(M, "__gnu_f2h_ieee", "julia__gnu_f2h_ieee", FunctionType::get(Type::getHalfTy(ctx), { Type::getFloatTy(ctx) }, false),
1252+
FunctionType::get(Type::getInt16Ty(ctx), { Type::getFloatTy(ctx) }, false));
1253+
makeCastCall(M, "__truncsfhf2", "julia__gnu_f2h_ieee", FunctionType::get(Type::getHalfTy(ctx), { Type::getFloatTy(ctx) }, false),
1254+
FunctionType::get(Type::getInt16Ty(ctx), { Type::getFloatTy(ctx) }, false));
1255+
makeCastCall(M, "__truncdfhf2", "julia__truncdfhf2", FunctionType::get(Type::getHalfTy(ctx), { Type::getDoubleTy(ctx) }, false),
1256+
FunctionType::get(Type::getInt16Ty(ctx), { Type::getDoubleTy(ctx) }, false));
1257+
1258+
}
12191259
JuliaOJIT::JuliaOJIT()
12201260
: TM(createTargetMachine()),
12211261
DL(jl_create_datalayout(*TM)),
@@ -1283,49 +1323,16 @@ JuliaOJIT::JuliaOJIT()
12831323
// tells DynamicLibrary to load the program, not a library.
12841324
std::string ErrorStr;
12851325

1286-
const char *aliasIR = R"V0G0N"(
1287-
declare external float @julia__gnu_h2f_ieee(i16);
1288-
declare external i16 @julia__gnu_f2hieee(float);
1289-
declare external i16 @julia__truncdfhf2(double);
1290-
1291-
define external float @__gnu_h2f_ieee(half %0) unnamed_addr {
1292-
top:
1293-
%1 = bitcast half %0 to i16
1294-
%2 = call float @julia__gnu_h2f_ieee(i16 %1)
1295-
ret float %2
1296-
}
1297-
define external float @__extendhfsf2(half %0) unnamed_addr {
1298-
top:
1299-
%1 = bitcast half %0 to i16
1300-
%2 = call float @julia__gnu_h2f_ieee(i16 %1)
1301-
ret float %2
1302-
}
1303-
define external half @__gnu_f2h_ieee(float %0) unnamed_addr {
1304-
top:
1305-
%1 = call i16 @julia__gnu_f2hieee(float %0)
1306-
%2 = bitcast i16 %1 to half
1307-
ret half %2
1308-
}
1309-
define external half @__truncsfhf2(float %0) unnamed_addr {
1310-
top:
1311-
%1 = call i16 @julia__gnu_f2hieee(float %0)
1312-
%2 = bitcast i16 %1 to half
1313-
ret half %2
1314-
}
1315-
define external half @__truncdfhf2(double %0) unnamed_addr {
1316-
top:
1317-
%1 = call i16 @julia__truncdfhf2(double %0)
1318-
%2 = bitcast i16 %1 to half
1319-
ret half %2
1320-
}
1321-
)V0G0N"";
1322-
13231326
auto ctx = ContextPool.acquire();
1324-
SMDiagnostic Err = SMDiagnostic();
1325-
auto aliasM = parseAssemblyString(aliasIR, Err, *ctx.getContext());
1327+
std::unique_ptr<Module> aliasM = jl_create_llvm_module("F16Wrappers", *ctx.getContext(), imaging_default(), DL, TM->getTargetTriple());
1328+
emitFloat16Wrappers(*aliasM);
13261329
jl_decorate_module(*aliasM);
13271330
shareStrings(*aliasM);
1328-
aliasM->dump();
1331+
std::string Str;
1332+
raw_string_ostream OS(Str);
1333+
OS << *aliasM;
1334+
OS.flush();
1335+
std::cout<<Str;
13291336
cantFail(OptSelLayer.add(JD, orc::ThreadSafeModule(std::move(aliasM), ctx)));
13301337
releaseContext(std::move(ctx));
13311338
// JD.addToLinkOrder(f16InterposerJD, orc::JITDylibLookupFlags::MatchAllSymbols);

0 commit comments

Comments
 (0)