2121#include  " llvm/ADT/SmallVector.h" 
2222#include  " llvm/ADT/Statistic.h" 
2323#include  " llvm/IR/CallingConv.h" 
24+ #include  " llvm/IR/GlobalAlias.h" 
2425#include  " llvm/IR/IRBuilder.h" 
2526#include  " llvm/IR/Instruction.h" 
2627#include  " llvm/IR/Mangler.h" 
@@ -70,15 +71,21 @@ class AArch64Arm64ECCallLowering : public ModulePass {
7071  Function *buildEntryThunk (Function *F);
7172  void  lowerCall (CallBase *CB);
7273  Function *buildGuestExitThunk (Function *F);
73-   bool  processFunction (Function &F, SetVector<Function *> &DirectCalledFns);
74+   Function *buildPatchableThunk (GlobalAlias *UnmangledAlias,
75+                                 GlobalAlias *MangledAlias);
76+   bool  processFunction (Function &F, SetVector<GlobalValue *> &DirectCalledFns,
77+                        DenseMap<GlobalAlias *, GlobalAlias *> &FnsMap);
7478  bool  runOnModule (Module &M) override ;
7579
7680private: 
7781  int  cfguard_module_flag = 0 ;
7882  FunctionType *GuardFnType = nullptr ;
7983  PointerType *GuardFnPtrType = nullptr ;
84+   FunctionType *DispatchFnType = nullptr ;
85+   PointerType *DispatchFnPtrType = nullptr ;
8086  Constant *GuardFnCFGlobal = nullptr ;
8187  Constant *GuardFnGlobal = nullptr ;
88+   Constant *DispatchFnGlobal = nullptr ;
8289  Module *M = nullptr ;
8390
8491  Type *PtrTy;
@@ -672,6 +679,66 @@ Function *AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) {
672679  return  GuestExit;
673680}
674681
682+ Function *
683+ AArch64Arm64ECCallLowering::buildPatchableThunk (GlobalAlias *UnmangledAlias,
684+                                                 GlobalAlias *MangledAlias) {
685+   llvm::raw_null_ostream NullThunkName;
686+   FunctionType *Arm64Ty, *X64Ty;
687+   Function *F = cast<Function>(MangledAlias->getAliasee ());
688+   SmallVector<ThunkArgTranslation> ArgTranslations;
689+   getThunkType (F->getFunctionType (), F->getAttributes (),
690+                Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty,
691+                ArgTranslations);
692+   std::string ThunkName (MangledAlias->getName ());
693+   if  (ThunkName[0 ] == ' ?' find (" @" 
694+     ThunkName.insert (ThunkName.find (" @" " $hybpatch_thunk" 
695+   } else  {
696+     ThunkName.append (" $hybpatch_thunk" 
697+   }
698+ 
699+   Function *GuestExit =
700+       Function::Create (Arm64Ty, GlobalValue::WeakODRLinkage, 0 , ThunkName, M);
701+   GuestExit->setComdat (M->getOrInsertComdat (ThunkName));
702+   GuestExit->setSection (" .wowthk$aa" 
703+   BasicBlock *BB = BasicBlock::Create (M->getContext (), " " 
704+   IRBuilder<> B (BB);
705+ 
706+   //  Load the global symbol as a pointer to the check function.
707+   LoadInst *DispatchLoad = B.CreateLoad (DispatchFnPtrType, DispatchFnGlobal);
708+ 
709+   //  Create new dispatch call instruction.
710+   Function *ExitThunk =
711+       buildExitThunk (F->getFunctionType (), F->getAttributes ());
712+   CallInst *Dispatch =
713+       B.CreateCall (DispatchFnType, DispatchLoad,
714+                    {UnmangledAlias, ExitThunk, UnmangledAlias->getAliasee ()});
715+ 
716+   //  Ensure that the first arguments are passed in the correct registers.
717+   Dispatch->setCallingConv (CallingConv::CFGuard_Check);
718+ 
719+   Value *DispatchRetVal = B.CreateBitCast (Dispatch, PtrTy);
720+   SmallVector<Value *> Args;
721+   for  (Argument &Arg : GuestExit->args ())
722+     Args.push_back (&Arg);
723+   CallInst *Call = B.CreateCall (Arm64Ty, DispatchRetVal, Args);
724+   Call->setTailCallKind (llvm::CallInst::TCK_MustTail);
725+ 
726+   if  (Call->getType ()->isVoidTy ())
727+     B.CreateRetVoid ();
728+   else 
729+     B.CreateRet (Call);
730+ 
731+   auto  SRetAttr = F->getAttributes ().getParamAttr (0 , Attribute::StructRet);
732+   auto  InRegAttr = F->getAttributes ().getParamAttr (0 , Attribute::InReg);
733+   if  (SRetAttr.isValid () && !InRegAttr.isValid ()) {
734+     GuestExit->addParamAttr (0 , SRetAttr);
735+     Call->addParamAttr (0 , SRetAttr);
736+   }
737+ 
738+   MangledAlias->setAliasee (GuestExit);
739+   return  GuestExit;
740+ }
741+ 
675742//  Lower an indirect call with inline code.
676743void  AArch64Arm64ECCallLowering::lowerCall (CallBase *CB) {
677744  assert (Triple (CB->getModule ()->getTargetTriple ()).isOSWindows () &&
@@ -727,17 +794,57 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
727794
728795  GuardFnType = FunctionType::get (PtrTy, {PtrTy, PtrTy}, false );
729796  GuardFnPtrType = PointerType::get (GuardFnType, 0 );
797+   DispatchFnType = FunctionType::get (PtrTy, {PtrTy, PtrTy, PtrTy}, false );
798+   DispatchFnPtrType = PointerType::get (DispatchFnType, 0 );
730799  GuardFnCFGlobal =
731800      M->getOrInsertGlobal (" __os_arm64x_check_icall_cfg" 
732801  GuardFnGlobal =
733802      M->getOrInsertGlobal (" __os_arm64x_check_icall" 
803+   DispatchFnGlobal =
804+       M->getOrInsertGlobal (" __os_arm64x_dispatch_call" 
805+ 
806+   DenseMap<GlobalAlias *, GlobalAlias *> FnsMap;
807+   SetVector<GlobalAlias *> PatchableFns;
734808
735-   SetVector<Function *> DirectCalledFns;
809+   for  (Function &F : Mod) {
810+     if  (!F.hasFnAttribute (Attribute::HybridPatchable) || F.isDeclaration () ||
811+         F.hasLocalLinkage () || F.getName ().ends_with (" $hp_target" 
812+       continue ;
813+ 
814+     //  Rename hybrid patchable functions and change callers to use a global
815+     //  alias instead.
816+     if  (std::optional<std::string> MangledName =
817+             getArm64ECMangledFunctionName (F.getName ().str ())) {
818+       std::string OrigName (F.getName ());
819+       F.setName (MangledName.value () + " $hp_target" 
820+ 
821+       //  The unmangled symbol is a weak alias to an undefined symbol with the
822+       //  "EXP+" prefix. This undefined symbol is resolved by the linker by
823+       //  creating an x86 thunk that jumps back to the actual EC target. Since we
824+       //  can't represent that in IR, we create an alias to the target instead.
825+       //  The "EXP+" symbol is set as metadata, which is then used by
826+       //  emitGlobalAlias to emit the right alias.
827+       auto  *A =
828+           GlobalAlias::create (GlobalValue::LinkOnceODRLinkage, OrigName, &F);
829+       F.replaceAllUsesWith (A);
830+       F.setMetadata (" arm64ec_exp_name" 
831+                     MDNode::get (M->getContext (),
832+                                 MDString::get (M->getContext (),
833+                                               " EXP+" value ())));
834+       A->setAliasee (&F);
835+ 
836+       FnsMap[A] = GlobalAlias::create (GlobalValue::LinkOnceODRLinkage,
837+                                       MangledName.value (), &F);
838+       PatchableFns.insert (A);
839+     }
840+   }
841+ 
842+   SetVector<GlobalValue *> DirectCalledFns;
736843  for  (Function &F : Mod)
737844    if  (!F.isDeclaration () &&
738845        F.getCallingConv () != CallingConv::ARM64EC_Thunk_Native &&
739846        F.getCallingConv () != CallingConv::ARM64EC_Thunk_X64)
740-       processFunction (F, DirectCalledFns);
847+       processFunction (F, DirectCalledFns, FnsMap );
741848
742849  struct  ThunkInfo  {
743850    Constant *Src;
@@ -755,14 +862,20 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
755862          {&F, buildEntryThunk (&F), Arm64ECThunkType::Entry});
756863    }
757864  }
758-   for  (Function *F : DirectCalledFns) {
865+   for  (GlobalValue *O : DirectCalledFns) {
866+     auto  GA = dyn_cast<GlobalAlias>(O);
867+     auto  F = dyn_cast<Function>(GA ? GA->getAliasee () : O);
759868    ThunkMapping.push_back (
760-         {F , buildExitThunk (F->getFunctionType (), F->getAttributes ()),
869+         {O , buildExitThunk (F->getFunctionType (), F->getAttributes ()),
761870         Arm64ECThunkType::Exit});
762-     if  (!F->hasDLLImportStorageClass ())
871+     if  (!GA && ! F->hasDLLImportStorageClass ())
763872      ThunkMapping.push_back (
764873          {buildGuestExitThunk (F), F, Arm64ECThunkType::GuestExit});
765874  }
875+   for  (GlobalAlias *A : PatchableFns) {
876+     Function *Thunk = buildPatchableThunk (A, FnsMap[A]);
877+     ThunkMapping.push_back ({Thunk, A, Arm64ECThunkType::GuestExit});
878+   }
766879
767880  if  (!ThunkMapping.empty ()) {
768881    SmallVector<Constant *> ThunkMappingArrayElems;
@@ -785,7 +898,8 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
785898}
786899
787900bool  AArch64Arm64ECCallLowering::processFunction (
788-     Function &F, SetVector<Function *> &DirectCalledFns) {
901+     Function &F, SetVector<GlobalValue *> &DirectCalledFns,
902+     DenseMap<GlobalAlias *, GlobalAlias *> &FnsMap) {
789903  SmallVector<CallBase *, 8 > IndirectCalls;
790904
791905  //  For ARM64EC targets, a function definition's name is mangled differently
@@ -837,6 +951,16 @@ bool AArch64Arm64ECCallLowering::processFunction(
837951        continue ;
838952      }
839953
954+       //  Use mangled global alias for direct calls to patchable functions.
955+       if  (GlobalAlias *A = dyn_cast<GlobalAlias>(CB->getCalledOperand ())) {
956+         auto  I = FnsMap.find (A);
957+         if  (I != FnsMap.end ()) {
958+           CB->setCalledOperand (I->second );
959+           DirectCalledFns.insert (I->first );
960+           continue ;
961+         }
962+       }
963+ 
840964      IndirectCalls.push_back (CB);
841965      ++Arm64ECCallsLowered;
842966    }
0 commit comments