2121#include  " llvm/ADT/SmallVector.h" 
2222#include  " llvm/ADT/Statistic.h" 
2323#include  " llvm/IR/CallingConv.h" 
24+ #include  " llvm/IR/GlobalAlias.h" 
2425#include  " llvm/IR/IRBuilder.h" 
2526#include  " llvm/IR/Instruction.h" 
2627#include  " llvm/IR/Mangler.h" 
@@ -69,15 +70,21 @@ class AArch64Arm64ECCallLowering : public ModulePass {
6970  Function *buildEntryThunk (Function *F);
7071  void  lowerCall (CallBase *CB);
7172  Function *buildGuestExitThunk (Function *F);
72-   bool  processFunction (Function &F, SetVector<Function *> &DirectCalledFns);
73+   Function *buildPatchableThunk (GlobalAlias *UnmangledAlias,
74+                                 GlobalAlias *MangledAlias);
75+   bool  processFunction (Function &F, SetVector<GlobalValue *> &DirectCalledFns,
76+                        DenseMap<GlobalAlias *, GlobalAlias *> &FnsMap);
7377  bool  runOnModule (Module &M) override ;
7478
7579private: 
7680  int  cfguard_module_flag = 0 ;
7781  FunctionType *GuardFnType = nullptr ;
7882  PointerType *GuardFnPtrType = nullptr ;
83+   FunctionType *DispatchFnType = nullptr ;
84+   PointerType *DispatchFnPtrType = nullptr ;
7985  Constant *GuardFnCFGlobal = nullptr ;
8086  Constant *GuardFnGlobal = nullptr ;
87+   Constant *DispatchFnGlobal = nullptr ;
8188  Module *M = nullptr ;
8289
8390  Type *PtrTy;
@@ -671,6 +678,66 @@ Function *AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) {
671678  return  GuestExit;
672679}
673680
681+ Function *
682+ AArch64Arm64ECCallLowering::buildPatchableThunk (GlobalAlias *UnmangledAlias,
683+                                                 GlobalAlias *MangledAlias) {
684+   llvm::raw_null_ostream NullThunkName;
685+   FunctionType *Arm64Ty, *X64Ty;
686+   Function *F = cast<Function>(MangledAlias->getAliasee ());
687+   SmallVector<ThunkArgTranslation> ArgTranslations;
688+   getThunkType (F->getFunctionType (), F->getAttributes (),
689+                Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty,
690+                ArgTranslations);
691+   std::string ThunkName (MangledAlias->getName ());
692+   if  (ThunkName[0 ] == ' ?' find (" @" 
693+     ThunkName.insert (ThunkName.find (" @" " $hybpatch_thunk" 
694+   } else  {
695+     ThunkName.append (" $hybpatch_thunk" 
696+   }
697+ 
698+   Function *GuestExit =
699+       Function::Create (Arm64Ty, GlobalValue::WeakODRLinkage, 0 , ThunkName, M);
700+   GuestExit->setComdat (M->getOrInsertComdat (ThunkName));
701+   GuestExit->setSection (" .wowthk$aa" 
702+   BasicBlock *BB = BasicBlock::Create (M->getContext (), " " 
703+   IRBuilder<> B (BB);
704+ 
705+   //  Load the global symbol as a pointer to the check function.
706+   LoadInst *DispatchLoad = B.CreateLoad (DispatchFnPtrType, DispatchFnGlobal);
707+ 
708+   //  Create new dispatch call instruction.
709+   Function *ExitThunk =
710+       buildExitThunk (F->getFunctionType (), F->getAttributes ());
711+   CallInst *Dispatch =
712+       B.CreateCall (DispatchFnType, DispatchLoad,
713+                    {UnmangledAlias, ExitThunk, UnmangledAlias->getAliasee ()});
714+ 
715+   //  Ensure that the first arguments are passed in the correct registers.
716+   Dispatch->setCallingConv (CallingConv::CFGuard_Check);
717+ 
718+   Value *DispatchRetVal = B.CreateBitCast (Dispatch, PtrTy);
719+   SmallVector<Value *> Args;
720+   for  (Argument &Arg : GuestExit->args ())
721+     Args.push_back (&Arg);
722+   CallInst *Call = B.CreateCall (Arm64Ty, DispatchRetVal, Args);
723+   Call->setTailCallKind (llvm::CallInst::TCK_MustTail);
724+ 
725+   if  (Call->getType ()->isVoidTy ())
726+     B.CreateRetVoid ();
727+   else 
728+     B.CreateRet (Call);
729+ 
730+   auto  SRetAttr = F->getAttributes ().getParamAttr (0 , Attribute::StructRet);
731+   auto  InRegAttr = F->getAttributes ().getParamAttr (0 , Attribute::InReg);
732+   if  (SRetAttr.isValid () && !InRegAttr.isValid ()) {
733+     GuestExit->addParamAttr (0 , SRetAttr);
734+     Call->addParamAttr (0 , SRetAttr);
735+   }
736+ 
737+   MangledAlias->setAliasee (GuestExit);
738+   return  GuestExit;
739+ }
740+ 
674741//  Lower an indirect call with inline code.
675742void  AArch64Arm64ECCallLowering::lowerCall (CallBase *CB) {
676743  assert (Triple (CB->getModule ()->getTargetTriple ()).isOSWindows () &&
@@ -726,17 +793,57 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
726793
727794  GuardFnType = FunctionType::get (PtrTy, {PtrTy, PtrTy}, false );
728795  GuardFnPtrType = PointerType::get (GuardFnType, 0 );
796+   DispatchFnType = FunctionType::get (PtrTy, {PtrTy, PtrTy, PtrTy}, false );
797+   DispatchFnPtrType = PointerType::get (DispatchFnType, 0 );
729798  GuardFnCFGlobal =
730799      M->getOrInsertGlobal (" __os_arm64x_check_icall_cfg" 
731800  GuardFnGlobal =
732801      M->getOrInsertGlobal (" __os_arm64x_check_icall" 
802+   DispatchFnGlobal =
803+       M->getOrInsertGlobal (" __os_arm64x_dispatch_call" 
804+ 
805+   DenseMap<GlobalAlias *, GlobalAlias *> FnsMap;
806+   SetVector<GlobalAlias *> PatchableFns;
733807
734-   SetVector<Function *> DirectCalledFns;
808+   for  (Function &F : Mod) {
809+     if  (!F.hasFnAttribute (Attribute::HybridPatchable) || F.isDeclaration () ||
810+         F.hasLocalLinkage () || F.getName ().ends_with (" $hp_target" 
811+       continue ;
812+ 
813+     //  Rename hybrid patchable functions and change callers to use a global
814+     //  alias instead.
815+     if  (std::optional<std::string> MangledName =
816+             getArm64ECMangledFunctionName (F.getName ().str ())) {
817+       std::string OrigName (F.getName ());
818+       F.setName (MangledName.value () + " $hp_target" 
819+ 
820+       //  The unmangled symbol is a weak alias to an undefined symbol with the
821+       //  "EXP+" prefix. This undefined symbol is resolved by the linker by
822+       //  creating an x86 thunk that jumps back to the actual EC target. Since we
823+       //  can't represent that in IR, we create an alias to the target instead.
824+       //  The "EXP+" symbol is set as metadata, which is then used by
825+       //  emitGlobalAlias to emit the right alias.
826+       auto  *A =
827+           GlobalAlias::create (GlobalValue::LinkOnceODRLinkage, OrigName, &F);
828+       F.replaceAllUsesWith (A);
829+       F.setMetadata (" arm64ec_exp_name" 
830+                     MDNode::get (M->getContext (),
831+                                 MDString::get (M->getContext (),
832+                                               " EXP+" value ())));
833+       A->setAliasee (&F);
834+ 
835+       FnsMap[A] = GlobalAlias::create (GlobalValue::LinkOnceODRLinkage,
836+                                       MangledName.value (), &F);
837+       PatchableFns.insert (A);
838+     }
839+   }
840+ 
841+   SetVector<GlobalValue *> DirectCalledFns;
735842  for  (Function &F : Mod)
736843    if  (!F.isDeclaration () &&
737844        F.getCallingConv () != CallingConv::ARM64EC_Thunk_Native &&
738845        F.getCallingConv () != CallingConv::ARM64EC_Thunk_X64)
739-       processFunction (F, DirectCalledFns);
846+       processFunction (F, DirectCalledFns, FnsMap );
740847
741848  struct  ThunkInfo  {
742849    Constant *Src;
@@ -754,14 +861,20 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
754861          {&F, buildEntryThunk (&F), Arm64ECThunkType::Entry});
755862    }
756863  }
757-   for  (Function *F : DirectCalledFns) {
864+   for  (GlobalValue *O : DirectCalledFns) {
865+     auto  GA = dyn_cast<GlobalAlias>(O);
866+     auto  F = dyn_cast<Function>(GA ? GA->getAliasee () : O);
758867    ThunkMapping.push_back (
759-         {F , buildExitThunk (F->getFunctionType (), F->getAttributes ()),
868+         {O , buildExitThunk (F->getFunctionType (), F->getAttributes ()),
760869         Arm64ECThunkType::Exit});
761-     if  (!F->hasDLLImportStorageClass ())
870+     if  (!GA && ! F->hasDLLImportStorageClass ())
762871      ThunkMapping.push_back (
763872          {buildGuestExitThunk (F), F, Arm64ECThunkType::GuestExit});
764873  }
874+   for  (GlobalAlias *A : PatchableFns) {
875+     Function *Thunk = buildPatchableThunk (A, FnsMap[A]);
876+     ThunkMapping.push_back ({Thunk, A, Arm64ECThunkType::GuestExit});
877+   }
765878
766879  if  (!ThunkMapping.empty ()) {
767880    SmallVector<Constant *> ThunkMappingArrayElems;
@@ -784,7 +897,8 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
784897}
785898
786899bool  AArch64Arm64ECCallLowering::processFunction (
787-     Function &F, SetVector<Function *> &DirectCalledFns) {
900+     Function &F, SetVector<GlobalValue *> &DirectCalledFns,
901+     DenseMap<GlobalAlias *, GlobalAlias *> &FnsMap) {
788902  SmallVector<CallBase *, 8 > IndirectCalls;
789903
790904  //  For ARM64EC targets, a function definition's name is mangled differently
@@ -836,6 +950,16 @@ bool AArch64Arm64ECCallLowering::processFunction(
836950        continue ;
837951      }
838952
953+       //  Use mangled global alias for direct calls to patchable functions.
954+       if  (GlobalAlias *A = dyn_cast<GlobalAlias>(CB->getCalledOperand ())) {
955+         auto  I = FnsMap.find (A);
956+         if  (I != FnsMap.end ()) {
957+           CB->setCalledOperand (I->second );
958+           DirectCalledFns.insert (I->first );
959+           continue ;
960+         }
961+       }
962+ 
839963      IndirectCalls.push_back (CB);
840964      ++Arm64ECCallsLowered;
841965    }
0 commit comments