@@ -46,6 +46,18 @@ static cl::opt<bool> GenerateThunks("arm64ec-generate-thunks", cl::Hidden,
4646
4747namespace {
4848
49+ enum ThunkArgTranslation : uint8_t { // How a thunk converts one argument between its Arm64 and X64 types.
50+ Direct, // Same type on both sides; the thunk forwards the value unchanged.
51+ Bitcast, // X64 side is an integer of the same byte size; the thunk reinterprets the value through a stack slot.
52+ PointerIndirection, // X64 side is a pointer; the thunk stores to / loads from memory to match the Arm64 by-value form.
53+ };
54+
55+ struct ThunkArgInfo { // Result of canonicalizeThunkType for one argument or return type.
56+ Type *Arm64Ty; // Type to use in the Arm64 side of the thunk signature.
57+ Type *X64Ty; // Type to use in the X64 side of the thunk signature.
58+ ThunkArgTranslation Translation; // Conversion the thunk must apply between Arm64Ty and X64Ty.
59+ };
60+
4961class AArch64Arm64ECCallLowering : public ModulePass {
5062public:
5163 static char ID;
@@ -74,25 +86,30 @@ class AArch64Arm64ECCallLowering : public ModulePass {
7486
7587 void getThunkType (FunctionType *FT, AttributeList AttrList,
7688 Arm64ECThunkType TT, raw_ostream &Out,
77- FunctionType *&Arm64Ty, FunctionType *&X64Ty);
89+ FunctionType *&Arm64Ty, FunctionType *&X64Ty,
90+ SmallVector<ThunkArgTranslation> &ArgTranslations);
7891 void getThunkRetType (FunctionType *FT, AttributeList AttrList,
7992 raw_ostream &Out, Type *&Arm64RetTy, Type *&X64RetTy,
8093 SmallVectorImpl<Type *> &Arm64ArgTypes,
81- SmallVectorImpl<Type *> &X64ArgTypes, bool &HasSretPtr);
94+ SmallVectorImpl<Type *> &X64ArgTypes,
95+ SmallVector<ThunkArgTranslation> &ArgTranslations,
96+ bool &HasSretPtr);
8297 void getThunkArgTypes (FunctionType *FT, AttributeList AttrList,
8398 Arm64ECThunkType TT, raw_ostream &Out,
8499 SmallVectorImpl<Type *> &Arm64ArgTypes,
85- SmallVectorImpl<Type *> &X64ArgTypes, bool HasSretPtr);
86- void canonicalizeThunkType (Type *T, Align Alignment, bool Ret,
87- uint64_t ArgSizeBytes, raw_ostream &Out,
88- Type *&Arm64Ty, Type *&X64Ty);
100+ SmallVectorImpl<Type *> &X64ArgTypes,
101+ SmallVectorImpl<ThunkArgTranslation> &ArgTranslations,
102+ bool HasSretPtr);
103+ ThunkArgInfo canonicalizeThunkType (Type *T, Align Alignment, bool Ret,
104+ uint64_t ArgSizeBytes, raw_ostream &Out);
89105};
90106
91107} // end anonymous namespace
92108
93109void AArch64Arm64ECCallLowering::getThunkType (
94110 FunctionType *FT, AttributeList AttrList, Arm64ECThunkType TT,
95- raw_ostream &Out, FunctionType *&Arm64Ty, FunctionType *&X64Ty) {
111+ raw_ostream &Out, FunctionType *&Arm64Ty, FunctionType *&X64Ty,
112+ SmallVector<ThunkArgTranslation> &ArgTranslations) {
96113 Out << (TT == Arm64ECThunkType::Entry ? " $ientry_thunk$cdecl$"
97114 : " $iexit_thunk$cdecl$" );
98115
@@ -111,10 +128,10 @@ void AArch64Arm64ECCallLowering::getThunkType(
111128
112129 bool HasSretPtr = false ;
113130 getThunkRetType (FT, AttrList, Out, Arm64RetTy, X64RetTy, Arm64ArgTypes,
114- X64ArgTypes, HasSretPtr);
131+ X64ArgTypes, ArgTranslations, HasSretPtr);
115132
116133 getThunkArgTypes (FT, AttrList, TT, Out, Arm64ArgTypes, X64ArgTypes,
117- HasSretPtr);
134+ ArgTranslations, HasSretPtr);
118135
119136 Arm64Ty = FunctionType::get (Arm64RetTy, Arm64ArgTypes, false );
120137
@@ -124,7 +141,8 @@ void AArch64Arm64ECCallLowering::getThunkType(
124141void AArch64Arm64ECCallLowering::getThunkArgTypes (
125142 FunctionType *FT, AttributeList AttrList, Arm64ECThunkType TT,
126143 raw_ostream &Out, SmallVectorImpl<Type *> &Arm64ArgTypes,
127- SmallVectorImpl<Type *> &X64ArgTypes, bool HasSretPtr) {
144+ SmallVectorImpl<Type *> &X64ArgTypes,
145+ SmallVectorImpl<ThunkArgTranslation> &ArgTranslations, bool HasSretPtr) {
128146
129147 Out << " $" ;
130148 if (FT->isVarArg ()) {
@@ -153,17 +171,20 @@ void AArch64Arm64ECCallLowering::getThunkArgTypes(
153171 for (int i = HasSretPtr ? 1 : 0 ; i < 4 ; i++) {
154172 Arm64ArgTypes.push_back (I64Ty);
155173 X64ArgTypes.push_back (I64Ty);
174+ ArgTranslations.push_back (ThunkArgTranslation::Direct);
156175 }
157176
158177 // x4
159178 Arm64ArgTypes.push_back (PtrTy);
160179 X64ArgTypes.push_back (PtrTy);
180+ ArgTranslations.push_back (ThunkArgTranslation::Direct);
161181 // x5
162182 Arm64ArgTypes.push_back (I64Ty);
163183 if (TT != Arm64ECThunkType::Entry) {
164184 // FIXME: x5 isn't actually used by the x64 side; revisit once we
165185 // have proper isel for varargs
166186 X64ArgTypes.push_back (I64Ty);
187+ ArgTranslations.push_back (ThunkArgTranslation::Direct);
167188 }
168189 return ;
169190 }
@@ -187,18 +208,20 @@ void AArch64Arm64ECCallLowering::getThunkArgTypes(
187208 uint64_t ArgSizeBytes = 0 ;
188209 Align ParamAlign = Align ();
189210#endif
190- Type * Arm64Ty, * X64Ty;
191- canonicalizeThunkType (FT->getParamType (I), ParamAlign,
192- /* Ret*/ false , ArgSizeBytes, Out, Arm64Ty, X64Ty );
211+ auto [ Arm64Ty, X64Ty, ArgTranslation] =
212+ canonicalizeThunkType (FT->getParamType (I), ParamAlign,
213+ /* Ret*/ false , ArgSizeBytes, Out);
193214 Arm64ArgTypes.push_back (Arm64Ty);
194215 X64ArgTypes.push_back (X64Ty);
216+ ArgTranslations.push_back (ArgTranslation);
195217 }
196218}
197219
198220void AArch64Arm64ECCallLowering::getThunkRetType (
199221 FunctionType *FT, AttributeList AttrList, raw_ostream &Out,
200222 Type *&Arm64RetTy, Type *&X64RetTy, SmallVectorImpl<Type *> &Arm64ArgTypes,
201- SmallVectorImpl<Type *> &X64ArgTypes, bool &HasSretPtr) {
223+ SmallVectorImpl<Type *> &X64ArgTypes,
224+ SmallVector<ThunkArgTranslation> &ArgTranslations, bool &HasSretPtr) {
202225 Type *T = FT->getReturnType ();
203226#if 0
204227 // FIXME: Need more information about argument size; see
@@ -240,13 +263,13 @@ void AArch64Arm64ECCallLowering::getThunkRetType(
240263 // that's a miscompile.)
241264 Type *SRetType = SRetAttr0.getValueAsType ();
242265 Align SRetAlign = AttrList.getParamAlignment (0 ).valueOrOne ();
243- Type *Arm64Ty, *X64Ty;
244266 canonicalizeThunkType (SRetType, SRetAlign, /* Ret*/ true , ArgSizeBytes,
245- Out, Arm64Ty, X64Ty );
267+ Out);
246268 Arm64RetTy = VoidTy;
247269 X64RetTy = VoidTy;
248270 Arm64ArgTypes.push_back (FT->getParamType (0 ));
249271 X64ArgTypes.push_back (FT->getParamType (0 ));
272+ ArgTranslations.push_back (ThunkArgTranslation::Direct);
250273 HasSretPtr = true ;
251274 return ;
252275 }
@@ -258,8 +281,10 @@ void AArch64Arm64ECCallLowering::getThunkRetType(
258281 return ;
259282 }
260283
261- canonicalizeThunkType (T, Align (), /* Ret*/ true , ArgSizeBytes, Out, Arm64RetTy,
262- X64RetTy);
284+ auto info =
285+ canonicalizeThunkType (T, Align (), /* Ret*/ true , ArgSizeBytes, Out);
286+ Arm64RetTy = info.Arm64Ty ;
287+ X64RetTy = info.X64Ty ;
263288 if (X64RetTy->isPointerTy ()) {
264289 // If the X64 type is canonicalized to a pointer, that means it's
265290 // passed/returned indirectly. For a return value, that means it's an
@@ -269,21 +294,33 @@ void AArch64Arm64ECCallLowering::getThunkRetType(
269294 }
270295}
271296
272- void AArch64Arm64ECCallLowering::canonicalizeThunkType (
273- Type *T, Align Alignment, bool Ret, uint64_t ArgSizeBytes, raw_ostream &Out,
274- Type *&Arm64Ty, Type *&X64Ty) {
297+ ThunkArgInfo AArch64Arm64ECCallLowering::canonicalizeThunkType (
298+ Type *T, Align Alignment, bool Ret, uint64_t ArgSizeBytes,
299+ raw_ostream &Out) {
300+
301+ auto direct = [](Type *T) {
302+ return ThunkArgInfo{T, T, ThunkArgTranslation::Direct};
303+ };
304+
305+ auto bitcast = [this ](Type *Arm64Ty, uint64_t SizeInBytes) {
306+ return ThunkArgInfo{Arm64Ty,
307+ llvm::Type::getIntNTy (M->getContext (), SizeInBytes * 8 ),
308+ ThunkArgTranslation::Bitcast};
309+ };
310+
311+ auto pointerIndirection = [this ](Type *Arm64Ty) {
312+ return ThunkArgInfo{Arm64Ty, PtrTy,
313+ ThunkArgTranslation::PointerIndirection};
314+ };
315+
275316 if (T->isFloatTy ()) {
276317 Out << " f" ;
277- Arm64Ty = T;
278- X64Ty = T;
279- return ;
318+ return direct (T);
280319 }
281320
282321 if (T->isDoubleTy ()) {
283322 Out << " d" ;
284- Arm64Ty = T;
285- X64Ty = T;
286- return ;
323+ return direct (T);
287324 }
288325
289326 if (T->isFloatingPointTy ()) {
@@ -306,16 +343,14 @@ void AArch64Arm64ECCallLowering::canonicalizeThunkType(
306343 Out << (ElementTy->isFloatTy () ? " F" : " D" ) << TotalSizeBytes;
307344 if (Alignment.value () >= 16 && !Ret)
308345 Out << " a" << Alignment.value ();
309- Arm64Ty = T;
310346 if (TotalSizeBytes <= 8 ) {
311347 // Arm64 returns small structs of float/double in float registers;
312348 // X64 uses RAX.
313- X64Ty = llvm::Type::getIntNTy (M-> getContext () , TotalSizeBytes * 8 );
349+ return bitcast (T , TotalSizeBytes);
314350 } else {
315351 // Struct is passed directly on Arm64, but indirectly on X64.
316- X64Ty = PtrTy ;
352+ return pointerIndirection (T) ;
317353 }
318- return ;
319354 } else if (T->isFloatingPointTy ()) {
320355 report_fatal_error (" Only 32 and 64 bit floating points are supported for "
321356 " ARM64EC thunks" );
@@ -324,9 +359,7 @@ void AArch64Arm64ECCallLowering::canonicalizeThunkType(
324359
325360 if ((T->isIntegerTy () || T->isPointerTy ()) && DL.getTypeSizeInBits (T) <= 64 ) {
326361 Out << " i8" ;
327- Arm64Ty = I64Ty;
328- X64Ty = I64Ty;
329- return ;
362+ return direct (I64Ty);
330363 }
331364
332365 unsigned TypeSize = ArgSizeBytes;
@@ -338,13 +371,12 @@ void AArch64Arm64ECCallLowering::canonicalizeThunkType(
338371 if (Alignment.value () >= 16 && !Ret)
339372 Out << " a" << Alignment.value ();
340373 // FIXME: Try to canonicalize Arm64Ty more thoroughly?
341- Arm64Ty = T;
342374 if (TypeSize == 1 || TypeSize == 2 || TypeSize == 4 || TypeSize == 8 ) {
343375 // Pass directly in an integer register
344- X64Ty = llvm::Type::getIntNTy (M-> getContext () , TypeSize * 8 );
376+ return bitcast (T , TypeSize);
345377 } else {
346378 // Passed directly on Arm64, but indirectly on X64.
347- X64Ty = PtrTy ;
379+ return pointerIndirection (T) ;
348380 }
349381}
350382
@@ -355,8 +387,9 @@ Function *AArch64Arm64ECCallLowering::buildExitThunk(FunctionType *FT,
355387 SmallString<256 > ExitThunkName;
356388 llvm::raw_svector_ostream ExitThunkStream (ExitThunkName);
357389 FunctionType *Arm64Ty, *X64Ty;
390+ SmallVector<ThunkArgTranslation> ArgTranslations;
358391 getThunkType (FT, Attrs, Arm64ECThunkType::Exit, ExitThunkStream, Arm64Ty,
359- X64Ty);
392+ X64Ty, ArgTranslations );
360393 if (Function *F = M->getFunction (ExitThunkName))
361394 return F;
362395
@@ -387,6 +420,7 @@ Function *AArch64Arm64ECCallLowering::buildExitThunk(FunctionType *FT,
387420 SmallVector<Value *> Args;
388421
389422 // Pass the called function in x9.
423+ auto X64TyOffset = 1 ;
390424 Args.push_back (F->arg_begin ());
391425
392426 Type *RetTy = Arm64Ty->getReturnType ();
@@ -396,10 +430,14 @@ Function *AArch64Arm64ECCallLowering::buildExitThunk(FunctionType *FT,
396430 // pointer.
397431 if (DL.getTypeStoreSize (RetTy) > 8 ) {
398432 Args.push_back (IRB.CreateAlloca (RetTy));
433+ X64TyOffset++;
399434 }
400435 }
401436
402- for (auto &Arg : make_range (F->arg_begin () + 1 , F->arg_end ())) {
437+ for (auto [Arg, X64ArgType, ArgTranslation] : llvm::zip_equal (
438+ make_range (F->arg_begin () + 1 , F->arg_end ()),
439+ make_range (X64Ty->param_begin () + X64TyOffset, X64Ty->param_end ()),
440+ ArgTranslations)) {
403441 // Translate arguments from AArch64 calling convention to x86 calling
404442 // convention.
405443 //
@@ -414,18 +452,20 @@ Function *AArch64Arm64ECCallLowering::buildExitThunk(FunctionType *FT,
414452 // with an attribute.)
415453 //
416454 // The first argument is the called function, stored in x9.
417- if (Arg.getType ()->isArrayTy () || Arg.getType ()->isStructTy () ||
418- DL.getTypeStoreSize (Arg.getType ()) > 8 ) {
455+ if (ArgTranslation != ThunkArgTranslation::Direct) {
419456 Value *Mem = IRB.CreateAlloca (Arg.getType ());
420457 IRB.CreateStore (&Arg, Mem);
421- if (DL. getTypeStoreSize (Arg. getType ()) <= 8 ) {
458+ if (ArgTranslation == ThunkArgTranslation::Bitcast ) {
422459 Type *IntTy = IRB.getIntNTy (DL.getTypeStoreSizeInBits (Arg.getType ()));
423460 Args.push_back (IRB.CreateLoad (IntTy, IRB.CreateBitCast (Mem, PtrTy)));
424- } else
461+ } else {
462+ assert (ArgTranslation == ThunkArgTranslation::PointerIndirection);
425463 Args.push_back (Mem);
464+ }
426465 } else {
427466 Args.push_back (&Arg);
428467 }
468+ assert (Args.back ()->getType () == X64ArgType);
429469 }
430470 // FIXME: Transfer necessary attributes? sret? anything else?
431471
@@ -459,8 +499,10 @@ Function *AArch64Arm64ECCallLowering::buildEntryThunk(Function *F) {
459499 SmallString<256 > EntryThunkName;
460500 llvm::raw_svector_ostream EntryThunkStream (EntryThunkName);
461501 FunctionType *Arm64Ty, *X64Ty;
502+ SmallVector<ThunkArgTranslation> ArgTranslations;
462503 getThunkType (F->getFunctionType (), F->getAttributes (),
463- Arm64ECThunkType::Entry, EntryThunkStream, Arm64Ty, X64Ty);
504+ Arm64ECThunkType::Entry, EntryThunkStream, Arm64Ty, X64Ty,
505+ ArgTranslations);
464506 if (Function *F = M->getFunction (EntryThunkName))
465507 return F;
466508
@@ -472,7 +514,6 @@ Function *AArch64Arm64ECCallLowering::buildEntryThunk(Function *F) {
472514 // Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.)
473515 Thunk->addFnAttr (" frame-pointer" , " all" );
474516
475- auto &DL = M->getDataLayout ();
476517 BasicBlock *BB = BasicBlock::Create (M->getContext (), " " , Thunk);
477518 IRBuilder<> IRB (BB);
478519
@@ -481,24 +522,28 @@ Function *AArch64Arm64ECCallLowering::buildEntryThunk(Function *F) {
481522
482523 bool TransformDirectToSRet = X64RetType->isVoidTy () && !RetTy->isVoidTy ();
483524 unsigned ThunkArgOffset = TransformDirectToSRet ? 2 : 1 ;
484- unsigned PassthroughArgSize = F->isVarArg () ? 5 : Thunk->arg_size ();
525+ unsigned PassthroughArgSize =
526+ (F->isVarArg () ? 5 : Thunk->arg_size ()) - ThunkArgOffset; // Count of thunk args translated by the loop below.
527+ assert (ArgTranslations.size () == (F->isVarArg () ? 5 : PassthroughArgSize)); // Parenthesized: `==` binds tighter than `?:`, so the unparenthesized form never compared the size.
485528
486529 // Translate arguments to call.
487530 SmallVector<Value *> Args;
488- for (unsigned i = ThunkArgOffset, e = PassthroughArgSize ; i != e ; ++i) {
489- Value *Arg = Thunk->getArg (i);
490- Type *ArgTy = Arm64Ty->getParamType (i - ThunkArgOffset );
491- if (ArgTy-> isArrayTy () || ArgTy-> isStructTy () ||
492- DL. getTypeStoreSize (ArgTy) > 8 ) {
531+ for (unsigned i = 0 ; i != PassthroughArgSize ; ++i) {
532+ Value *Arg = Thunk->getArg (i + ThunkArgOffset );
533+ Type *ArgTy = Arm64Ty->getParamType (i);
534+ ThunkArgTranslation ArgTranslation = ArgTranslations[i];
535+ if (ArgTranslation != ThunkArgTranslation::Direct ) {
493536 // Translate array/struct arguments to the expected type.
494- if (DL. getTypeStoreSize (ArgTy) <= 8 ) {
537+ if (ArgTranslation == ThunkArgTranslation::Bitcast ) {
495538 Value *CastAlloca = IRB.CreateAlloca (ArgTy);
496539 IRB.CreateStore (Arg, IRB.CreateBitCast (CastAlloca, PtrTy));
497540 Arg = IRB.CreateLoad (ArgTy, CastAlloca);
498541 } else {
542+ assert (ArgTranslation == ThunkArgTranslation::PointerIndirection);
499543 Arg = IRB.CreateLoad (ArgTy, IRB.CreateBitCast (Arg, PtrTy));
500544 }
501545 }
546+ assert (Arg->getType () == ArgTy);
502547 Args.push_back (Arg);
503548 }
504549
@@ -558,8 +603,10 @@ Function *AArch64Arm64ECCallLowering::buildEntryThunk(Function *F) {
558603Function *AArch64Arm64ECCallLowering::buildGuestExitThunk (Function *F) {
559604 llvm::raw_null_ostream NullThunkName;
560605 FunctionType *Arm64Ty, *X64Ty;
606+ SmallVector<ThunkArgTranslation> ArgTranslations;
561607 getThunkType (F->getFunctionType (), F->getAttributes (),
562- Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty);
608+ Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty,
609+ ArgTranslations);
563610 auto MangledName = getArm64ECMangledFunctionName (F->getName ().str ());
564611 assert (MangledName && " Can't guest exit to function that's already native" );
565612 std::string ThunkName = *MangledName;
0 commit comments