Skip to content

Commit ab8a1de

Browse files
committed
[LLVM][PM] Make MultiVersioning NewPM compatible
1 parent cc66c03 commit ab8a1de

File tree

3 files changed

+69
-31
lines changed

3 files changed

+69
-31
lines changed

src/aotcompile.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -934,6 +934,10 @@ static void registerCallbacks(PassBuilder &PB) {
934934
PM.addPass(RemoveJuliaAddrspacesPass());
935935
return true;
936936
}
937+
if (Name == "MultiVersioning") {
938+
PM.addPass(MultiVersioning());
939+
return true;
940+
}
937941
return false;
938942
});
939943

src/llvm-multiversioning.cpp

Lines changed: 60 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
// LLVM pass to clone function for different archs
88

99
#include "llvm-version.h"
10+
#include "passes.h"
1011

1112
#include <llvm-c/Core.h>
1213
#include <llvm-c/Types.h>
@@ -46,8 +47,6 @@ namespace {
4647
constexpr uint32_t clone_mask =
4748
JL_TARGET_CLONE_LOOP | JL_TARGET_CLONE_SIMD | JL_TARGET_CLONE_MATH | JL_TARGET_CLONE_CPU;
4849

49-
struct MultiVersioning;
50-
5150
// Treat identical mapping as missing and return `def` in that case.
5251
// We mainly need this to identify cloned function using value map after LLVM cloning
5352
// functions fills the map with identity entries.
@@ -243,7 +242,7 @@ struct CloneCtx {
243242
return cast<Function>(vmap->lookup(orig_f));
244243
}
245244
};
246-
CloneCtx(MultiVersioning *pass, Module &M);
245+
CloneCtx(Module &M, function_ref<LoopInfo&()> GetLI, function_ref<CallGraph&()> GetCG);
247246
void clone_bases();
248247
void collect_func_infos();
249248
void clone_all_partials();
@@ -277,12 +276,14 @@ struct CloneCtx {
277276
Type *T_void;
278277
PointerType *T_psize;
279278
MDNode *tbaa_const;
280-
MultiVersioning *pass;
281279
std::vector<jl_target_spec_t> specs;
282280
std::vector<Group> groups{};
283281
std::vector<Function*> fvars;
284282
std::vector<Constant*> gvars;
285283
Module &M;
284+
function_ref<LoopInfo&()> GetLI;
285+
function_ref<CallGraph&()> GetCG;
286+
286287
// Map from original functiton to one based index in `fvars`
287288
std::map<const Function*,uint32_t> func_ids{};
288289
std::vector<Function*> orig_funcs{};
@@ -298,23 +299,6 @@ struct CloneCtx {
298299
bool has_cloneall{false};
299300
};
300301

301-
struct MultiVersioning: public ModulePass {
302-
static char ID;
303-
MultiVersioning()
304-
: ModulePass(ID)
305-
{}
306-
307-
private:
308-
bool runOnModule(Module &M) override;
309-
void getAnalysisUsage(AnalysisUsage &AU) const override
310-
{
311-
AU.addRequired<LoopInfoWrapperPass>();
312-
AU.addRequired<CallGraphWrapperPass>();
313-
AU.addPreserved<LoopInfoWrapperPass>();
314-
}
315-
friend struct CloneCtx;
316-
};
317-
318302
template<typename T>
319303
static inline std::vector<T*> consume_gv(Module &M, const char *name)
320304
{
@@ -335,18 +319,19 @@ static inline std::vector<T*> consume_gv(Module &M, const char *name)
335319
}
336320

337321
// Collect basic information about targets and functions.
338-
CloneCtx::CloneCtx(MultiVersioning *pass, Module &M)
322+
CloneCtx::CloneCtx(Module &M, function_ref<LoopInfo&()> GetLI, function_ref<CallGraph&()> GetCG)
339323
: ctx(M.getContext()),
340324
T_size(M.getDataLayout().getIntPtrType(ctx, 0)),
341325
T_int32(Type::getInt32Ty(ctx)),
342326
T_void(Type::getVoidTy(ctx)),
343327
T_psize(PointerType::get(T_size, 0)),
344328
tbaa_const(tbaa_make_child_with_context(ctx, "jtbaa_const", nullptr, true).first),
345-
pass(pass),
346329
specs(jl_get_llvm_clone_targets()),
347330
fvars(consume_gv<Function>(M, "jl_sysimg_fvars")),
348331
gvars(consume_gv<Constant>(M, "jl_sysimg_gvars")),
349-
M(M)
332+
M(M),
333+
GetLI(GetLI),
334+
GetCG(GetCG)
350335
{
351336
groups.emplace_back(0, specs[0]);
352337
uint32_t ntargets = specs.size();
@@ -449,7 +434,7 @@ bool CloneCtx::is_vector(FunctionType *ty) const
449434
uint32_t CloneCtx::collect_func_info(Function &F)
450435
{
451436
uint32_t flag = 0;
452-
if (!pass->getAnalysis<LoopInfoWrapperPass>(F).getLoopInfo().empty())
437+
if (!GetLI().empty())
453438
flag |= JL_TARGET_CLONE_LOOP;
454439
if (is_vector(F.getFunctionType())) {
455440
flag |= JL_TARGET_CLONE_SIMD;
@@ -563,7 +548,7 @@ void CloneCtx::check_partial(Group &grp, Target &tgt)
563548
auto *next_set = &sets[1];
564549
// Reduce dispatch by expand the cloning set to functions that are directly called by
565550
// and calling cloned functions.
566-
auto &graph = pass->getAnalysis<CallGraphWrapperPass>().getCallGraph();
551+
auto &graph = GetCG();
567552
while (!cur_set->empty()) {
568553
for (auto orig_f: *cur_set) {
569554
// Use the uncloned function since it's already in the call graph
@@ -1079,7 +1064,7 @@ void CloneCtx::emit_metadata()
10791064
}
10801065
}
10811066

1082-
bool MultiVersioning::runOnModule(Module &M)
1067+
static bool runMultiVersioning(Module &M, function_ref<LoopInfo&()> GetLI, function_ref<CallGraph&()> GetCG)
10831068
{
10841069
// Group targets and identify cloning bases.
10851070
// Also initialize function info maps (we'll update these maps as we go)
@@ -1092,7 +1077,7 @@ bool MultiVersioning::runOnModule(Module &M)
10921077
if (M.getName() == "sysimage")
10931078
return false;
10941079

1095-
CloneCtx clone(this, M);
1080+
CloneCtx clone(M, GetLI, GetCG);
10961081

10971082
// Collect a list of original functions and clone base functions
10981083
clone.clone_bases();
@@ -1130,16 +1115,60 @@ bool MultiVersioning::runOnModule(Module &M)
11301115
return true;
11311116
}
11321117

1133-
char MultiVersioning::ID = 0;
1134-
static RegisterPass<MultiVersioning> X("JuliaMultiVersioning", "JuliaMultiVersioning Pass",
1118+
struct MultiVersioningLegacy: public ModulePass {
1119+
static char ID;
1120+
MultiVersioningLegacy()
1121+
: ModulePass(ID)
1122+
{}
1123+
1124+
private:
1125+
bool runOnModule(Module &M) override;
1126+
void getAnalysisUsage(AnalysisUsage &AU) const override
1127+
{
1128+
AU.addRequired<LoopInfoWrapperPass>();
1129+
AU.addRequired<CallGraphWrapperPass>();
1130+
AU.addPreserved<LoopInfoWrapperPass>();
1131+
}
1132+
};
1133+
1134+
bool MultiVersioningLegacy::runOnModule(Module &M)
1135+
{
1136+
auto GetLI = [this]() -> LoopInfo & {
1137+
return getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
1138+
};
1139+
auto GetCG = [this]() -> CallGraph & {
1140+
return getAnalysis<CallGraphWrapperPass>().getCallGraph();
1141+
};
1142+
return runMultiVersioning(M, GetLI, GetCG);
1143+
}
1144+
1145+
1146+
char MultiVersioningLegacy::ID = 0;
1147+
static RegisterPass<MultiVersioningLegacy> X("JuliaMultiVersioning", "JuliaMultiVersioning Pass",
11351148
false /* Only looks at CFG */,
11361149
false /* Analysis Pass */);
11371150

1151+
} // anonymous namespace
1152+
1153+
PreservedAnalyses MultiVersioning::run(Module &M, ModuleAnalysisManager &AM)
1154+
{
1155+
auto GetLI = [&]() -> LoopInfo & {
1156+
return AM.getResult<LoopAnalysis>(M);;
1157+
};
1158+
auto GetCG = [&]() -> CallGraph & {
1159+
return AM.getResult<CallGraphAnalysis>(M);
1160+
};
1161+
if (runMultiVersioning(M, GetLI, GetCG)) {
1162+
auto preserved = PreservedAnalyses::allInSet<CFGAnalyses>();
1163+
preserved.preserve<LoopAnalysis>();
1164+
return preserved;
1165+
}
1166+
return PreservedAnalyses::all();
11381167
}
11391168

11401169
Pass *createMultiVersioningPass()
11411170
{
1142-
return new MultiVersioning();
1171+
return new MultiVersioningLegacy();
11431172
}
11441173

11451174
extern "C" JL_DLLEXPORT void LLVMExtraAddMultiVersioningPass_impl(LLVMPassManagerRef PM)

src/passes.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,11 @@ struct FinalLowerGCPass : PassInfoMixin<LateLowerGC> {
6464
static bool isRequired() { return true; }
6565
};
6666

67+
struct MultiVersioning : PassInfoMixin<MultiVersioning> {
68+
PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
69+
static bool isRequired() { return true; }
70+
};
71+
6772
struct RemoveJuliaAddrspacesPass : PassInfoMixin<RemoveJuliaAddrspacesPass> {
6873
PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
6974
static bool isRequired() { return true; }

0 commit comments

Comments
 (0)