Skip to content

Conversation

@pcc
Copy link
Contributor

@pcc pcc commented Mar 28, 2025

Deactivation symbol operands are supported in the code generator by
building on the previously added support for IRELATIVE relocations.

Created using spr 1.3.6-beta.1
@pcc pcc requested a review from nikic as a code owner March 28, 2025 22:34
@llvmbot llvmbot added clang Clang issues not falling into any other category backend:AArch64 clang:codegen IR generation bugs: mangling, exceptions, etc. llvm:instcombine Covers the InstCombine, InstSimplify and AggressiveInstCombine passes llvm:ir llvm:transforms llvm:SandboxIR labels Mar 28, 2025
@llvmbot
Copy link
Member

llvmbot commented Mar 28, 2025

@llvm/pr-subscribers-backend-aarch64
@llvm/pr-subscribers-clang

@llvm/pr-subscribers-llvm-transforms

Author: Peter Collingbourne (pcc)

Changes

Deactivation symbol operands are supported in the code generator by
building on the previously added support for IRELATIVE relocations.

TODO:

  • Fix broken test.
  • Add bitcode and IR writer support.
  • Add tests.

Patch is 22.34 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/133537.diff

16 Files Affected:

  • (modified) clang/lib/CodeGen/CGPointerAuth.cpp (+3-3)
  • (modified) llvm/include/llvm/Bitcode/LLVMBitCodes.h (+1)
  • (modified) llvm/include/llvm/IR/Constants.h (+9-4)
  • (modified) llvm/include/llvm/SandboxIR/Constant.h (+4-1)
  • (modified) llvm/lib/AsmParser/LLParser.cpp (+21-8)
  • (modified) llvm/lib/Bitcode/Reader/BitcodeReader.cpp (+17-1)
  • (modified) llvm/lib/IR/AsmWriter.cpp (+3-1)
  • (modified) llvm/lib/IR/Constants.cpp (+8-4)
  • (modified) llvm/lib/IR/ConstantsContext.h (+2-1)
  • (modified) llvm/lib/IR/Core.cpp (+3-1)
  • (modified) llvm/lib/SandboxIR/Constant.cpp (+9-2)
  • (modified) llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp (+31-6)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp (+2-2)
  • (modified) llvm/lib/Transforms/Utils/ValueMapper.cpp (+3-2)
  • (modified) llvm/unittests/SandboxIR/SandboxIRTest.cpp (+1-1)
  • (modified) llvm/unittests/Transforms/Utils/ValueMapperTest.cpp (+9-4)
diff --git a/clang/lib/CodeGen/CGPointerAuth.cpp b/clang/lib/CodeGen/CGPointerAuth.cpp
index 4b032306ead72..2d72fef470af6 100644
--- a/clang/lib/CodeGen/CGPointerAuth.cpp
+++ b/clang/lib/CodeGen/CGPointerAuth.cpp
@@ -308,9 +308,9 @@ CodeGenModule::getConstantSignedPointer(llvm::Constant *Pointer, unsigned Key,
     IntegerDiscriminator = llvm::ConstantInt::get(Int64Ty, 0);
   }
 
-  return llvm::ConstantPtrAuth::get(Pointer,
-                                    llvm::ConstantInt::get(Int32Ty, Key),
-                                    IntegerDiscriminator, AddressDiscriminator);
+  return llvm::ConstantPtrAuth::get(
+      Pointer, llvm::ConstantInt::get(Int32Ty, Key), IntegerDiscriminator,
+      AddressDiscriminator, llvm::Constant::getNullValue(UnqualPtrTy));
 }
 
 /// Does a given PointerAuthScheme require us to sign a value
diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
index ec2535ac85966..13521ba6cd00f 100644
--- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -431,6 +431,7 @@ enum ConstantsCodes {
   CST_CODE_CE_GEP_WITH_INRANGE = 31,  // [opty, flags, range, n x operands]
   CST_CODE_CE_GEP = 32,               // [opty, flags, n x operands]
   CST_CODE_PTRAUTH = 33,              // [ptr, key, disc, addrdisc]
+  CST_CODE_PTRAUTH2 = 34,             // [ptr, key, disc, addrdisc, DeactivationSymbol]
 };
 
 /// CastOpcodes - These are values used in the bitcode files to encode which
diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index a50217078d0ed..45d5352bf06a6 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -1022,10 +1022,10 @@ class ConstantPtrAuth final : public Constant {
   friend struct ConstantPtrAuthKeyType;
   friend class Constant;
 
-  constexpr static IntrusiveOperandsAllocMarker AllocMarker{4};
+  constexpr static IntrusiveOperandsAllocMarker AllocMarker{5};
 
   ConstantPtrAuth(Constant *Ptr, ConstantInt *Key, ConstantInt *Disc,
-                  Constant *AddrDisc);
+                  Constant *AddrDisc, Constant *DeactivationSymbol);
 
   void *operator new(size_t s) { return User::operator new(s, AllocMarker); }
 
@@ -1035,7 +1035,8 @@ class ConstantPtrAuth final : public Constant {
 public:
   /// Return a pointer signed with the specified parameters.
   static ConstantPtrAuth *get(Constant *Ptr, ConstantInt *Key,
-                              ConstantInt *Disc, Constant *AddrDisc);
+                              ConstantInt *Disc, Constant *AddrDisc,
+                              Constant *DeactivationSymbol);
 
   /// Produce a new ptrauth expression signing the given value using
   /// the same schema as is stored in one.
@@ -1067,6 +1068,10 @@ class ConstantPtrAuth final : public Constant {
     return !getAddrDiscriminator()->isNullValue();
   }
 
+  Constant *getDeactivationSymbol() const {
+    return cast<Constant>(Op<4>().get());
+  }
+
   /// A constant value for the address discriminator which has special
   /// significance to ctors/dtors lowering. Regular address discrimination can't
   /// be applied for them since uses of llvm.global_{c|d}tors are disallowed
@@ -1094,7 +1099,7 @@ class ConstantPtrAuth final : public Constant {
 
 template <>
 struct OperandTraits<ConstantPtrAuth>
-    : public FixedNumOperandTraits<ConstantPtrAuth, 4> {};
+    : public FixedNumOperandTraits<ConstantPtrAuth, 5> {};
 
 DEFINE_TRANSPARENT_OPERAND_ACCESSORS(ConstantPtrAuth, Constant)
 
diff --git a/llvm/include/llvm/SandboxIR/Constant.h b/llvm/include/llvm/SandboxIR/Constant.h
index 17f55e973cd76..5243a9476ac64 100644
--- a/llvm/include/llvm/SandboxIR/Constant.h
+++ b/llvm/include/llvm/SandboxIR/Constant.h
@@ -1096,7 +1096,8 @@ class ConstantPtrAuth final : public Constant {
 public:
   /// Return a pointer signed with the specified parameters.
   static ConstantPtrAuth *get(Constant *Ptr, ConstantInt *Key,
-                              ConstantInt *Disc, Constant *AddrDisc);
+                              ConstantInt *Disc, Constant *AddrDisc,
+                              Constant *DeactivationSymbol);
   /// The pointer that is signed in this ptrauth signed pointer.
   Constant *getPointer() const;
 
@@ -1111,6 +1112,8 @@ class ConstantPtrAuth final : public Constant {
   /// the only global-initializer user of the ptrauth signed pointer.
   Constant *getAddrDiscriminator() const;
 
+  Constant *getDeactivationSymbol() const;
+
   /// Whether there is any non-null address discriminator.
   bool hasAddressDiscriminator() const {
     return cast<llvm::ConstantPtrAuth>(Val)->hasAddressDiscriminator();
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index 960119bab0933..dfa014aa0bd7d 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -4226,11 +4226,12 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
   }
   case lltok::kw_ptrauth: {
     // ValID ::= 'ptrauth' '(' ptr @foo ',' i32 <key>
-    //                         (',' i64 <disc> (',' ptr addrdisc)? )? ')'
+    //                         (',' i64 <disc> (',' ptr addrdisc (',' ptr ds)? )? )? ')'
     Lex.Lex();
 
     Constant *Ptr, *Key;
-    Constant *Disc = nullptr, *AddrDisc = nullptr;
+    Constant *Disc = nullptr, *AddrDisc = nullptr,
+             *DeactivationSymbol = nullptr;
 
     if (parseToken(lltok::lparen,
                    "expected '(' in constant ptrauth expression") ||
@@ -4239,11 +4240,14 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
                    "expected comma in constant ptrauth expression") ||
         parseGlobalTypeAndValue(Key))
       return true;
-    // If present, parse the optional disc/addrdisc.
-    if (EatIfPresent(lltok::comma))
-      if (parseGlobalTypeAndValue(Disc) ||
-          (EatIfPresent(lltok::comma) && parseGlobalTypeAndValue(AddrDisc)))
-        return true;
+    // If present, parse the optional disc/addrdisc/ds.
+    if (EatIfPresent(lltok::comma) && parseGlobalTypeAndValue(Disc))
+      return true;
+    if (EatIfPresent(lltok::comma) && parseGlobalTypeAndValue(AddrDisc))
+      return true;
+    if (EatIfPresent(lltok::comma) &&
+        parseGlobalTypeAndValue(DeactivationSymbol))
+      return true;
     if (parseToken(lltok::rparen,
                    "expected ')' in constant ptrauth expression"))
       return true;
@@ -4274,7 +4278,16 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
       AddrDisc = ConstantPointerNull::get(PointerType::get(Context, 0));
     }
 
-    ID.ConstantVal = ConstantPtrAuth::get(Ptr, KeyC, DiscC, AddrDisc);
+    if (DeactivationSymbol) {
+      if (!DeactivationSymbol->getType()->isPointerTy())
+        return error(
+            ID.Loc, "constant ptrauth deactivation symbol must be a pointer");
+    } else {
+      DeactivationSymbol = ConstantPointerNull::get(PointerType::get(Context, 0));
+    }
+
+    ID.ConstantVal =
+        ConstantPtrAuth::get(Ptr, KeyC, DiscC, AddrDisc, DeactivationSymbol);
     ID.Kind = ValID::t_Constant;
     return false;
   }
diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index 40e755902b724..c09c3b4f7d38c 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -1611,7 +1611,13 @@ Expected<Value *> BitcodeReader::materializeValue(unsigned StartValID,
           if (!Disc)
             return error("ptrauth disc operand must be ConstantInt");
 
-          C = ConstantPtrAuth::get(ConstOps[0], Key, Disc, ConstOps[3]);
+          auto *DeactivationSymbol =
+              ConstOps.size() > 4 ? ConstOps[4]
+                                  : ConstantPointerNull::get(cast<PointerType>(
+                                        ConstOps[3]->getType()));
+
+          C = ConstantPtrAuth::get(ConstOps[0], Key, Disc, ConstOps[3],
+                                   DeactivationSymbol);
           break;
         }
         case BitcodeConstant::NoCFIOpcode: {
@@ -3811,6 +3817,16 @@ Error BitcodeReader::parseConstants() {
                                    (unsigned)Record[2], (unsigned)Record[3]});
       break;
     }
+    case bitc::CST_CODE_PTRAUTH2: {
+      if (Record.size() < 4)
+        return error("Invalid ptrauth record");
+      // Ptr, Key, Disc, AddrDisc, DeactivationSymbol
+      V = BitcodeConstant::create(
+          Alloc, CurTy, BitcodeConstant::ConstantPtrAuthOpcode,
+          {(unsigned)Record[0], (unsigned)Record[1], (unsigned)Record[2],
+           (unsigned)Record[3], (unsigned)Record[4]});
+      break;
+    }
     }
 
     assert(V->getType() == getTypeByID(CurTyID) && "Incorrect result type ID");
diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
index 79547b299a903..5efb321967008 100644
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -1630,12 +1630,14 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
   if (const ConstantPtrAuth *CPA = dyn_cast<ConstantPtrAuth>(CV)) {
     Out << "ptrauth (";
 
-    // ptrauth (ptr CST, i32 KEY[, i64 DISC[, ptr ADDRDISC]?]?)
+    // ptrauth (ptr CST, i32 KEY[, i64 DISC[, ptr ADDRDISC[, ptr DS]?]?]?)
     unsigned NumOpsToWrite = 2;
     if (!CPA->getOperand(2)->isNullValue())
       NumOpsToWrite = 3;
     if (!CPA->getOperand(3)->isNullValue())
       NumOpsToWrite = 4;
+    if (!CPA->getOperand(4)->isNullValue())
+      NumOpsToWrite = 5;
 
     ListSeparator LS;
     for (unsigned i = 0, e = NumOpsToWrite; i != e; ++i) {
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index fb659450bfeeb..007d36d19f373 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -2072,19 +2072,22 @@ Value *NoCFIValue::handleOperandChangeImpl(Value *From, Value *To) {
 //
 
 ConstantPtrAuth *ConstantPtrAuth::get(Constant *Ptr, ConstantInt *Key,
-                                      ConstantInt *Disc, Constant *AddrDisc) {
-  Constant *ArgVec[] = {Ptr, Key, Disc, AddrDisc};
+                                      ConstantInt *Disc, Constant *AddrDisc,
+                                      Constant *DeactivationSymbol) {
+  Constant *ArgVec[] = {Ptr, Key, Disc, AddrDisc, DeactivationSymbol};
   ConstantPtrAuthKeyType MapKey(ArgVec);
   LLVMContextImpl *pImpl = Ptr->getContext().pImpl;
   return pImpl->ConstantPtrAuths.getOrCreate(Ptr->getType(), MapKey);
 }
 
 ConstantPtrAuth *ConstantPtrAuth::getWithSameSchema(Constant *Pointer) const {
-  return get(Pointer, getKey(), getDiscriminator(), getAddrDiscriminator());
+  return get(Pointer, getKey(), getDiscriminator(), getAddrDiscriminator(),
+             getDeactivationSymbol());
 }
 
 ConstantPtrAuth::ConstantPtrAuth(Constant *Ptr, ConstantInt *Key,
-                                 ConstantInt *Disc, Constant *AddrDisc)
+                                 ConstantInt *Disc, Constant *AddrDisc,
+                                 Constant *DeactivationSymbol)
     : Constant(Ptr->getType(), Value::ConstantPtrAuthVal, AllocMarker) {
   assert(Ptr->getType()->isPointerTy());
   assert(Key->getBitWidth() == 32);
@@ -2094,6 +2097,7 @@ ConstantPtrAuth::ConstantPtrAuth(Constant *Ptr, ConstantInt *Key,
   setOperand(1, Key);
   setOperand(2, Disc);
   setOperand(3, AddrDisc);
+  setOperand(4, DeactivationSymbol);
 }
 
 /// Remove the constant from the constant table.
diff --git a/llvm/lib/IR/ConstantsContext.h b/llvm/lib/IR/ConstantsContext.h
index e5c9622e09927..bf9d8ab952271 100644
--- a/llvm/lib/IR/ConstantsContext.h
+++ b/llvm/lib/IR/ConstantsContext.h
@@ -545,7 +545,8 @@ struct ConstantPtrAuthKeyType {
 
   ConstantPtrAuth *create(TypeClass *Ty) const {
     return new ConstantPtrAuth(Operands[0], cast<ConstantInt>(Operands[1]),
-                               cast<ConstantInt>(Operands[2]), Operands[3]);
+                               cast<ConstantInt>(Operands[2]), Operands[3],
+                               Operands[4]);
   }
 };
 
diff --git a/llvm/lib/IR/Core.cpp b/llvm/lib/IR/Core.cpp
index f4b03e8cb8aa3..6190ebdac16d4 100644
--- a/llvm/lib/IR/Core.cpp
+++ b/llvm/lib/IR/Core.cpp
@@ -1687,7 +1687,9 @@ LLVMValueRef LLVMConstantPtrAuth(LLVMValueRef Ptr, LLVMValueRef Key,
                                  LLVMValueRef Disc, LLVMValueRef AddrDisc) {
   return wrap(ConstantPtrAuth::get(
       unwrap<Constant>(Ptr), unwrap<ConstantInt>(Key),
-      unwrap<ConstantInt>(Disc), unwrap<Constant>(AddrDisc)));
+      unwrap<ConstantInt>(Disc), unwrap<Constant>(AddrDisc),
+      ConstantPointerNull::get(
+          cast<PointerType>(unwrap<Constant>(AddrDisc)->getType()))));
 }
 
 /*-- Opcode mapping */
diff --git a/llvm/lib/SandboxIR/Constant.cpp b/llvm/lib/SandboxIR/Constant.cpp
index 3e13c935c4281..0a28cf9feeb4d 100644
--- a/llvm/lib/SandboxIR/Constant.cpp
+++ b/llvm/lib/SandboxIR/Constant.cpp
@@ -421,10 +421,12 @@ PointerType *NoCFIValue::getType() const {
 }
 
 ConstantPtrAuth *ConstantPtrAuth::get(Constant *Ptr, ConstantInt *Key,
-                                      ConstantInt *Disc, Constant *AddrDisc) {
+                                      ConstantInt *Disc, Constant *AddrDisc,
+                                      Constant *DeactivationSymbol) {
   auto *LLVMC = llvm::ConstantPtrAuth::get(
       cast<llvm::Constant>(Ptr->Val), cast<llvm::ConstantInt>(Key->Val),
-      cast<llvm::ConstantInt>(Disc->Val), cast<llvm::Constant>(AddrDisc->Val));
+      cast<llvm::ConstantInt>(Disc->Val), cast<llvm::Constant>(AddrDisc->Val),
+      cast<llvm::Constant>(DeactivationSymbol->Val));
   return cast<ConstantPtrAuth>(Ptr->getContext().getOrCreateConstant(LLVMC));
 }
 
@@ -448,6 +450,11 @@ Constant *ConstantPtrAuth::getAddrDiscriminator() const {
       cast<llvm::ConstantPtrAuth>(Val)->getAddrDiscriminator());
 }
 
+Constant *ConstantPtrAuth::getDeactivationSymbol() const {
+  return Ctx.getOrCreateConstant(
+      cast<llvm::ConstantPtrAuth>(Val)->getDeactivationSymbol());
+}
+
 ConstantPtrAuth *ConstantPtrAuth::getWithSameSchema(Constant *Pointer) const {
   auto *LLVMC = cast<llvm::ConstantPtrAuth>(Val)->getWithSameSchema(
       cast<llvm::Constant>(Pointer->Val));
diff --git a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
index 135f6cff0f78b..283493408699e 100644
--- a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
+++ b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
@@ -195,7 +195,7 @@ class AArch64AsmPrinter : public AsmPrinter {
 
   const MCExpr *emitPAuthRelocationAsIRelative(
       const MCExpr *Target, uint16_t Disc, AArch64PACKey::ID KeyID,
-      bool HasAddressDiversity, bool IsDSOLocal);
+      bool HasAddressDiversity, bool IsDSOLocal, const MCExpr *DSExpr);
 
   /// tblgen'erated driver function for lowering simple MI->MC
   /// pseudo instructions.
@@ -2270,15 +2270,17 @@ static void emitAddress(MCStreamer &Streamer, MCRegister Reg,
 }
 
 static bool targetSupportsPAuthRelocation(const Triple &TT,
-                                          const MCExpr *Target) {
+                                          const MCExpr *Target,
+                                          const MCExpr *DSExpr) {
   // No released version of glibc supports PAuth relocations.
   if (TT.isOSGlibc())
     return false;
 
   // We emit PAuth constants as IRELATIVE relocations in cases where the
   // constant cannot be represented as a PAuth relocation:
-  // 1) The signed value is not a symbol.
-  return !isa<MCConstantExpr>(Target);
+  // 1) There is a deactivation symbol.
+  // 2) The signed value is not a symbol.
+  return !DSExpr && !isa<MCConstantExpr>(Target);
 }
 
 static bool targetSupportsIRelativeRelocation(const Triple &TT) {
@@ -2295,7 +2297,7 @@ static bool targetSupportsIRelativeRelocation(const Triple &TT) {
 
 const MCExpr *AArch64AsmPrinter::emitPAuthRelocationAsIRelative(
     const MCExpr *Target, uint16_t Disc, AArch64PACKey::ID KeyID,
-    bool HasAddressDiversity, bool IsDSOLocal) {
+    bool HasAddressDiversity, bool IsDSOLocal, const MCExpr *DSExpr) {
   const Triple &TT = TM.getTargetTriple();
 
   // We only emit an IRELATIVE relocation if the target supports IRELATIVE and
@@ -2358,6 +2360,18 @@ const MCExpr *AArch64AsmPrinter::emitPAuthRelocationAsIRelative(
       MCSymbolRefExpr::create(EmuPAC, OutStreamer->getContext());
   OutStreamer->emitInstruction(MCInstBuilder(AArch64::B).addExpr(EmuPACRef),
                                *STI);
+
+  if (DSExpr) {
+    auto *PrePACInstExpr =
+        MCSymbolRefExpr::create(PrePACInst, OutStreamer->getContext());
+    OutStreamer->emitRelocDirective(*PrePACInstExpr, "R_AARCH64_INST32", DSExpr,
+                                    SMLoc(), *STI);
+  }
+
+  // We need a RET despite the above tail call because the deactivation symbol
+  // may replace it with a NOP.
+  OutStreamer->emitInstruction(MCInstBuilder(AArch64::RET).addReg(AArch64::LR),
+                               *STI);
   OutStreamer->popSection();
 
   return MCSymbolRefExpr::create(IFuncSym, OutStreamer->getContext());
@@ -2388,6 +2402,13 @@ AArch64AsmPrinter::lowerConstantPtrAuth(const ConstantPtrAuth &CPA) {
     Sym = MCConstantExpr::create(Offset.getSExtValue(), Ctx);
   }
 
+  const MCExpr *DSExpr = nullptr;
+  if (auto *DS = dyn_cast<GlobalValue>(CPA.getDeactivationSymbol())) {
+    if (isa<GlobalAlias>(DS))
+      return Sym;
+    DSExpr = MCSymbolRefExpr::create(getSymbol(DS), Ctx);
+  }
+
   uint64_t KeyID = CPA.getKey()->getZExtValue();
   // We later rely on valid KeyID value in AArch64PACKeyIDToString call from
   // AArch64AuthMCExpr::printImpl, so fail fast.
@@ -2404,9 +2425,13 @@ AArch64AsmPrinter::lowerConstantPtrAuth(const ConstantPtrAuth &CPA) {
   // Check if we need to represent this with an IRELATIVE and emit it if so.
   if (auto *IFuncSym = emitPAuthRelocationAsIRelative(
           Sym, Disc, AArch64PACKey::ID(KeyID), CPA.hasAddressDiscriminator(),
-          BaseGVB && BaseGVB->isDSOLocal()))
+          BaseGVB && BaseGVB->isDSOLocal(), DSExpr))
     return IFuncSym;
 
+  if (DSExpr)
+    report_fatal_error("deactivation symbols unsupported in constant "
+                       "expressions on this target");
+
   // Finally build the complete @AUTH expr.
   return AArch64AuthMCExpr::create(Sym, Disc, AArch64PACKey::ID(KeyID),
                                    CPA.hasAddressDiscriminator(), Ctx);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 12dd4cec85f59..58b98d8d93464 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -2946,9 +2946,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       if (NeedSign && isa<ConstantInt>(II->getArgOperand(4))) {
         auto *SignKey = cast<ConstantInt>(II->getArgOperand(3));
         auto *SignDisc = cast<ConstantInt>(II->getArgOperand(4));
-        auto *SignAddrDisc = ConstantPointerNull::get(Builder.getPtrTy());
+        auto *Null = ConstantPointerNull::get(Builder.getPtrTy());
         auto *NewCPA = ConstantPtrAuth::get(CPA->getPointer(), SignKey,
-                                            SignDisc, SignAddrDisc);
+                                            SignDisc, Null, Null);
         replaceInstUsesWith(
             *II, ConstantExpr::getPointerCast(NewCPA, II->getType()));
         return eraseInstFromFunction(*II);
diff --git a/llvm/lib/Transforms/Utils/ValueMapper.cpp b/llvm/lib/Transforms/Utils/ValueMapper.cpp
index 5e50536a99206..320bef6c8f240 100644
--- a/llvm/lib/Transforms/Utils/ValueMapper.cpp
+++ b/llvm/lib/Transforms/Utils/ValueMapper.cpp
@@ -526,8 +526,9 @@ Value *Mapper::mapValue(const Value *V) {
   if (isa<ConstantVector>(C))
     return getVM()[V] = ConstantVector::get(Ops);
   if (isa<ConstantPtrAuth>(C))
-    return getVM()[V] = ConstantPtrAuth::get(Ops[0], cast<ConstantInt>(Ops[1]),
-                                             cast<ConstantInt>(Ops[2]), Ops[3]);
+    return getVM()[V] =
+               ConstantPtrAuth::get(Ops[0], cast<ConstantInt>(Ops[1]),
+                                    cast<ConstantInt>(Ops[2]), Ops[3], Ops[4]);
   // If this is a no-operand constant, it must be because the type was remapped.
   if (isa<PoisonValue>(C))
     return getVM()[V] = PoisonValue::get(NewTy);
diff --git a/llvm/unittests/SandboxIR/SandboxIR...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Mar 28, 2025

@llvm/pr-subscribers-llvm-ir

Author: Peter Collingbourne (pcc)

Changes

Deactivation symbol operands are supported in the code generator by
building on the previously added support for IRELATIVE relocations.

TODO:

  • Fix broken test.
  • Add bitcode and IR writer support.
  • Add tests.

Patch is 22.34 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/133537.diff

16 Files Affected:

  • (modified) clang/lib/CodeGen/CGPointerAuth.cpp (+3-3)
  • (modified) llvm/include/llvm/Bitcode/LLVMBitCodes.h (+1)
  • (modified) llvm/include/llvm/IR/Constants.h (+9-4)
  • (modified) llvm/include/llvm/SandboxIR/Constant.h (+4-1)
  • (modified) llvm/lib/AsmParser/LLParser.cpp (+21-8)
  • (modified) llvm/lib/Bitcode/Reader/BitcodeReader.cpp (+17-1)
  • (modified) llvm/lib/IR/AsmWriter.cpp (+3-1)
  • (modified) llvm/lib/IR/Constants.cpp (+8-4)
  • (modified) llvm/lib/IR/ConstantsContext.h (+2-1)
  • (modified) llvm/lib/IR/Core.cpp (+3-1)
  • (modified) llvm/lib/SandboxIR/Constant.cpp (+9-2)
  • (modified) llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp (+31-6)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp (+2-2)
  • (modified) llvm/lib/Transforms/Utils/ValueMapper.cpp (+3-2)
  • (modified) llvm/unittests/SandboxIR/SandboxIRTest.cpp (+1-1)
  • (modified) llvm/unittests/Transforms/Utils/ValueMapperTest.cpp (+9-4)
diff --git a/clang/lib/CodeGen/CGPointerAuth.cpp b/clang/lib/CodeGen/CGPointerAuth.cpp
index 4b032306ead72..2d72fef470af6 100644
--- a/clang/lib/CodeGen/CGPointerAuth.cpp
+++ b/clang/lib/CodeGen/CGPointerAuth.cpp
@@ -308,9 +308,9 @@ CodeGenModule::getConstantSignedPointer(llvm::Constant *Pointer, unsigned Key,
     IntegerDiscriminator = llvm::ConstantInt::get(Int64Ty, 0);
   }
 
-  return llvm::ConstantPtrAuth::get(Pointer,
-                                    llvm::ConstantInt::get(Int32Ty, Key),
-                                    IntegerDiscriminator, AddressDiscriminator);
+  return llvm::ConstantPtrAuth::get(
+      Pointer, llvm::ConstantInt::get(Int32Ty, Key), IntegerDiscriminator,
+      AddressDiscriminator, llvm::Constant::getNullValue(UnqualPtrTy));
 }
 
 /// Does a given PointerAuthScheme require us to sign a value
diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
index ec2535ac85966..13521ba6cd00f 100644
--- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -431,6 +431,7 @@ enum ConstantsCodes {
   CST_CODE_CE_GEP_WITH_INRANGE = 31,  // [opty, flags, range, n x operands]
   CST_CODE_CE_GEP = 32,               // [opty, flags, n x operands]
   CST_CODE_PTRAUTH = 33,              // [ptr, key, disc, addrdisc]
+  CST_CODE_PTRAUTH2 = 34,             // [ptr, key, disc, addrdisc, DeactivationSymbol]
 };
 
 /// CastOpcodes - These are values used in the bitcode files to encode which
diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index a50217078d0ed..45d5352bf06a6 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -1022,10 +1022,10 @@ class ConstantPtrAuth final : public Constant {
   friend struct ConstantPtrAuthKeyType;
   friend class Constant;
 
-  constexpr static IntrusiveOperandsAllocMarker AllocMarker{4};
+  constexpr static IntrusiveOperandsAllocMarker AllocMarker{5};
 
   ConstantPtrAuth(Constant *Ptr, ConstantInt *Key, ConstantInt *Disc,
-                  Constant *AddrDisc);
+                  Constant *AddrDisc, Constant *DeactivationSymbol);
 
   void *operator new(size_t s) { return User::operator new(s, AllocMarker); }
 
@@ -1035,7 +1035,8 @@ class ConstantPtrAuth final : public Constant {
 public:
   /// Return a pointer signed with the specified parameters.
   static ConstantPtrAuth *get(Constant *Ptr, ConstantInt *Key,
-                              ConstantInt *Disc, Constant *AddrDisc);
+                              ConstantInt *Disc, Constant *AddrDisc,
+                              Constant *DeactivationSymbol);
 
   /// Produce a new ptrauth expression signing the given value using
   /// the same schema as is stored in one.
@@ -1067,6 +1068,10 @@ class ConstantPtrAuth final : public Constant {
     return !getAddrDiscriminator()->isNullValue();
   }
 
+  Constant *getDeactivationSymbol() const {
+    return cast<Constant>(Op<4>().get());
+  }
+
   /// A constant value for the address discriminator which has special
   /// significance to ctors/dtors lowering. Regular address discrimination can't
   /// be applied for them since uses of llvm.global_{c|d}tors are disallowed
@@ -1094,7 +1099,7 @@ class ConstantPtrAuth final : public Constant {
 
 template <>
 struct OperandTraits<ConstantPtrAuth>
-    : public FixedNumOperandTraits<ConstantPtrAuth, 4> {};
+    : public FixedNumOperandTraits<ConstantPtrAuth, 5> {};
 
 DEFINE_TRANSPARENT_OPERAND_ACCESSORS(ConstantPtrAuth, Constant)
 
diff --git a/llvm/include/llvm/SandboxIR/Constant.h b/llvm/include/llvm/SandboxIR/Constant.h
index 17f55e973cd76..5243a9476ac64 100644
--- a/llvm/include/llvm/SandboxIR/Constant.h
+++ b/llvm/include/llvm/SandboxIR/Constant.h
@@ -1096,7 +1096,8 @@ class ConstantPtrAuth final : public Constant {
 public:
   /// Return a pointer signed with the specified parameters.
   static ConstantPtrAuth *get(Constant *Ptr, ConstantInt *Key,
-                              ConstantInt *Disc, Constant *AddrDisc);
+                              ConstantInt *Disc, Constant *AddrDisc,
+                              Constant *DeactivationSymbol);
   /// The pointer that is signed in this ptrauth signed pointer.
   Constant *getPointer() const;
 
@@ -1111,6 +1112,8 @@ class ConstantPtrAuth final : public Constant {
   /// the only global-initializer user of the ptrauth signed pointer.
   Constant *getAddrDiscriminator() const;
 
+  Constant *getDeactivationSymbol() const;
+
   /// Whether there is any non-null address discriminator.
   bool hasAddressDiscriminator() const {
     return cast<llvm::ConstantPtrAuth>(Val)->hasAddressDiscriminator();
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index 960119bab0933..dfa014aa0bd7d 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -4226,11 +4226,12 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
   }
   case lltok::kw_ptrauth: {
     // ValID ::= 'ptrauth' '(' ptr @foo ',' i32 <key>
-    //                         (',' i64 <disc> (',' ptr addrdisc)? )? ')'
+    //                         (',' i64 <disc> (',' ptr addrdisc (',' ptr ds)? )? )? ')'
     Lex.Lex();
 
     Constant *Ptr, *Key;
-    Constant *Disc = nullptr, *AddrDisc = nullptr;
+    Constant *Disc = nullptr, *AddrDisc = nullptr,
+             *DeactivationSymbol = nullptr;
 
     if (parseToken(lltok::lparen,
                    "expected '(' in constant ptrauth expression") ||
@@ -4239,11 +4240,14 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
                    "expected comma in constant ptrauth expression") ||
         parseGlobalTypeAndValue(Key))
       return true;
-    // If present, parse the optional disc/addrdisc.
-    if (EatIfPresent(lltok::comma))
-      if (parseGlobalTypeAndValue(Disc) ||
-          (EatIfPresent(lltok::comma) && parseGlobalTypeAndValue(AddrDisc)))
-        return true;
+    // If present, parse the optional disc/addrdisc/ds.
+    if (EatIfPresent(lltok::comma) && parseGlobalTypeAndValue(Disc))
+      return true;
+    if (EatIfPresent(lltok::comma) && parseGlobalTypeAndValue(AddrDisc))
+      return true;
+    if (EatIfPresent(lltok::comma) &&
+        parseGlobalTypeAndValue(DeactivationSymbol))
+      return true;
     if (parseToken(lltok::rparen,
                    "expected ')' in constant ptrauth expression"))
       return true;
@@ -4274,7 +4278,16 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
       AddrDisc = ConstantPointerNull::get(PointerType::get(Context, 0));
     }
 
-    ID.ConstantVal = ConstantPtrAuth::get(Ptr, KeyC, DiscC, AddrDisc);
+    if (DeactivationSymbol) {
+      if (!DeactivationSymbol->getType()->isPointerTy())
+        return error(
+            ID.Loc, "constant ptrauth deactivation symbol must be a pointer");
+    } else {
+      DeactivationSymbol = ConstantPointerNull::get(PointerType::get(Context, 0));
+    }
+
+    ID.ConstantVal =
+        ConstantPtrAuth::get(Ptr, KeyC, DiscC, AddrDisc, DeactivationSymbol);
     ID.Kind = ValID::t_Constant;
     return false;
   }
diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index 40e755902b724..c09c3b4f7d38c 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -1611,7 +1611,13 @@ Expected<Value *> BitcodeReader::materializeValue(unsigned StartValID,
           if (!Disc)
             return error("ptrauth disc operand must be ConstantInt");
 
-          C = ConstantPtrAuth::get(ConstOps[0], Key, Disc, ConstOps[3]);
+          auto *DeactivationSymbol =
+              ConstOps.size() > 4 ? ConstOps[4]
+                                  : ConstantPointerNull::get(cast<PointerType>(
+                                        ConstOps[3]->getType()));
+
+          C = ConstantPtrAuth::get(ConstOps[0], Key, Disc, ConstOps[3],
+                                   DeactivationSymbol);
           break;
         }
         case BitcodeConstant::NoCFIOpcode: {
@@ -3811,6 +3817,16 @@ Error BitcodeReader::parseConstants() {
                                    (unsigned)Record[2], (unsigned)Record[3]});
       break;
     }
+    case bitc::CST_CODE_PTRAUTH2: {
+      if (Record.size() < 4)
+        return error("Invalid ptrauth record");
+      // Ptr, Key, Disc, AddrDisc, DeactivationSymbol
+      V = BitcodeConstant::create(
+          Alloc, CurTy, BitcodeConstant::ConstantPtrAuthOpcode,
+          {(unsigned)Record[0], (unsigned)Record[1], (unsigned)Record[2],
+           (unsigned)Record[3], (unsigned)Record[4]});
+      break;
+    }
     }
 
     assert(V->getType() == getTypeByID(CurTyID) && "Incorrect result type ID");
diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
index 79547b299a903..5efb321967008 100644
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -1630,12 +1630,14 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
   if (const ConstantPtrAuth *CPA = dyn_cast<ConstantPtrAuth>(CV)) {
     Out << "ptrauth (";
 
-    // ptrauth (ptr CST, i32 KEY[, i64 DISC[, ptr ADDRDISC]?]?)
+    // ptrauth (ptr CST, i32 KEY[, i64 DISC[, ptr ADDRDISC[, ptr DS]?]?]?)
     unsigned NumOpsToWrite = 2;
     if (!CPA->getOperand(2)->isNullValue())
       NumOpsToWrite = 3;
     if (!CPA->getOperand(3)->isNullValue())
       NumOpsToWrite = 4;
+    if (!CPA->getOperand(4)->isNullValue())
+      NumOpsToWrite = 5;
 
     ListSeparator LS;
     for (unsigned i = 0, e = NumOpsToWrite; i != e; ++i) {
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index fb659450bfeeb..007d36d19f373 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -2072,19 +2072,22 @@ Value *NoCFIValue::handleOperandChangeImpl(Value *From, Value *To) {
 //
 
 ConstantPtrAuth *ConstantPtrAuth::get(Constant *Ptr, ConstantInt *Key,
-                                      ConstantInt *Disc, Constant *AddrDisc) {
-  Constant *ArgVec[] = {Ptr, Key, Disc, AddrDisc};
+                                      ConstantInt *Disc, Constant *AddrDisc,
+                                      Constant *DeactivationSymbol) {
+  Constant *ArgVec[] = {Ptr, Key, Disc, AddrDisc, DeactivationSymbol};
   ConstantPtrAuthKeyType MapKey(ArgVec);
   LLVMContextImpl *pImpl = Ptr->getContext().pImpl;
   return pImpl->ConstantPtrAuths.getOrCreate(Ptr->getType(), MapKey);
 }
 
 ConstantPtrAuth *ConstantPtrAuth::getWithSameSchema(Constant *Pointer) const {
-  return get(Pointer, getKey(), getDiscriminator(), getAddrDiscriminator());
+  return get(Pointer, getKey(), getDiscriminator(), getAddrDiscriminator(),
+             getDeactivationSymbol());
 }
 
 ConstantPtrAuth::ConstantPtrAuth(Constant *Ptr, ConstantInt *Key,
-                                 ConstantInt *Disc, Constant *AddrDisc)
+                                 ConstantInt *Disc, Constant *AddrDisc,
+                                 Constant *DeactivationSymbol)
     : Constant(Ptr->getType(), Value::ConstantPtrAuthVal, AllocMarker) {
   assert(Ptr->getType()->isPointerTy());
   assert(Key->getBitWidth() == 32);
@@ -2094,6 +2097,7 @@ ConstantPtrAuth::ConstantPtrAuth(Constant *Ptr, ConstantInt *Key,
   setOperand(1, Key);
   setOperand(2, Disc);
   setOperand(3, AddrDisc);
+  setOperand(4, DeactivationSymbol);
 }
 
 /// Remove the constant from the constant table.
diff --git a/llvm/lib/IR/ConstantsContext.h b/llvm/lib/IR/ConstantsContext.h
index e5c9622e09927..bf9d8ab952271 100644
--- a/llvm/lib/IR/ConstantsContext.h
+++ b/llvm/lib/IR/ConstantsContext.h
@@ -545,7 +545,8 @@ struct ConstantPtrAuthKeyType {
 
   ConstantPtrAuth *create(TypeClass *Ty) const {
     return new ConstantPtrAuth(Operands[0], cast<ConstantInt>(Operands[1]),
-                               cast<ConstantInt>(Operands[2]), Operands[3]);
+                               cast<ConstantInt>(Operands[2]), Operands[3],
+                               Operands[4]);
   }
 };
 
diff --git a/llvm/lib/IR/Core.cpp b/llvm/lib/IR/Core.cpp
index f4b03e8cb8aa3..6190ebdac16d4 100644
--- a/llvm/lib/IR/Core.cpp
+++ b/llvm/lib/IR/Core.cpp
@@ -1687,7 +1687,9 @@ LLVMValueRef LLVMConstantPtrAuth(LLVMValueRef Ptr, LLVMValueRef Key,
                                  LLVMValueRef Disc, LLVMValueRef AddrDisc) {
   return wrap(ConstantPtrAuth::get(
       unwrap<Constant>(Ptr), unwrap<ConstantInt>(Key),
-      unwrap<ConstantInt>(Disc), unwrap<Constant>(AddrDisc)));
+      unwrap<ConstantInt>(Disc), unwrap<Constant>(AddrDisc),
+      ConstantPointerNull::get(
+          cast<PointerType>(unwrap<Constant>(AddrDisc)->getType()))));
 }
 
 /*-- Opcode mapping */
diff --git a/llvm/lib/SandboxIR/Constant.cpp b/llvm/lib/SandboxIR/Constant.cpp
index 3e13c935c4281..0a28cf9feeb4d 100644
--- a/llvm/lib/SandboxIR/Constant.cpp
+++ b/llvm/lib/SandboxIR/Constant.cpp
@@ -421,10 +421,12 @@ PointerType *NoCFIValue::getType() const {
 }
 
 ConstantPtrAuth *ConstantPtrAuth::get(Constant *Ptr, ConstantInt *Key,
-                                      ConstantInt *Disc, Constant *AddrDisc) {
+                                      ConstantInt *Disc, Constant *AddrDisc,
+                                      Constant *DeactivationSymbol) {
   auto *LLVMC = llvm::ConstantPtrAuth::get(
       cast<llvm::Constant>(Ptr->Val), cast<llvm::ConstantInt>(Key->Val),
-      cast<llvm::ConstantInt>(Disc->Val), cast<llvm::Constant>(AddrDisc->Val));
+      cast<llvm::ConstantInt>(Disc->Val), cast<llvm::Constant>(AddrDisc->Val),
+      cast<llvm::Constant>(DeactivationSymbol->Val));
   return cast<ConstantPtrAuth>(Ptr->getContext().getOrCreateConstant(LLVMC));
 }
 
@@ -448,6 +450,11 @@ Constant *ConstantPtrAuth::getAddrDiscriminator() const {
       cast<llvm::ConstantPtrAuth>(Val)->getAddrDiscriminator());
 }
 
+Constant *ConstantPtrAuth::getDeactivationSymbol() const {
+  return Ctx.getOrCreateConstant(
+      cast<llvm::ConstantPtrAuth>(Val)->getDeactivationSymbol());
+}
+
 ConstantPtrAuth *ConstantPtrAuth::getWithSameSchema(Constant *Pointer) const {
   auto *LLVMC = cast<llvm::ConstantPtrAuth>(Val)->getWithSameSchema(
       cast<llvm::Constant>(Pointer->Val));
diff --git a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
index 135f6cff0f78b..283493408699e 100644
--- a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
+++ b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
@@ -195,7 +195,7 @@ class AArch64AsmPrinter : public AsmPrinter {
 
   const MCExpr *emitPAuthRelocationAsIRelative(
       const MCExpr *Target, uint16_t Disc, AArch64PACKey::ID KeyID,
-      bool HasAddressDiversity, bool IsDSOLocal);
+      bool HasAddressDiversity, bool IsDSOLocal, const MCExpr *DSExpr);
 
   /// tblgen'erated driver function for lowering simple MI->MC
   /// pseudo instructions.
@@ -2270,15 +2270,17 @@ static void emitAddress(MCStreamer &Streamer, MCRegister Reg,
 }
 
 static bool targetSupportsPAuthRelocation(const Triple &TT,
-                                          const MCExpr *Target) {
+                                          const MCExpr *Target,
+                                          const MCExpr *DSExpr) {
   // No released version of glibc supports PAuth relocations.
   if (TT.isOSGlibc())
     return false;
 
   // We emit PAuth constants as IRELATIVE relocations in cases where the
   // constant cannot be represented as a PAuth relocation:
-  // 1) The signed value is not a symbol.
-  return !isa<MCConstantExpr>(Target);
+  // 1) There is a deactivation symbol.
+  // 2) The signed value is not a symbol.
+  return !DSExpr && !isa<MCConstantExpr>(Target);
 }
 
 static bool targetSupportsIRelativeRelocation(const Triple &TT) {
@@ -2295,7 +2297,7 @@ static bool targetSupportsIRelativeRelocation(const Triple &TT) {
 
 const MCExpr *AArch64AsmPrinter::emitPAuthRelocationAsIRelative(
     const MCExpr *Target, uint16_t Disc, AArch64PACKey::ID KeyID,
-    bool HasAddressDiversity, bool IsDSOLocal) {
+    bool HasAddressDiversity, bool IsDSOLocal, const MCExpr *DSExpr) {
   const Triple &TT = TM.getTargetTriple();
 
   // We only emit an IRELATIVE relocation if the target supports IRELATIVE and
@@ -2358,6 +2360,18 @@ const MCExpr *AArch64AsmPrinter::emitPAuthRelocationAsIRelative(
       MCSymbolRefExpr::create(EmuPAC, OutStreamer->getContext());
   OutStreamer->emitInstruction(MCInstBuilder(AArch64::B).addExpr(EmuPACRef),
                                *STI);
+
+  if (DSExpr) {
+    auto *PrePACInstExpr =
+        MCSymbolRefExpr::create(PrePACInst, OutStreamer->getContext());
+    OutStreamer->emitRelocDirective(*PrePACInstExpr, "R_AARCH64_INST32", DSExpr,
+                                    SMLoc(), *STI);
+  }
+
+  // We need a RET despite the above tail call because the deactivation symbol
+  // may replace it with a NOP.
+  OutStreamer->emitInstruction(MCInstBuilder(AArch64::RET).addReg(AArch64::LR),
+                               *STI);
   OutStreamer->popSection();
 
   return MCSymbolRefExpr::create(IFuncSym, OutStreamer->getContext());
@@ -2388,6 +2402,13 @@ AArch64AsmPrinter::lowerConstantPtrAuth(const ConstantPtrAuth &CPA) {
     Sym = MCConstantExpr::create(Offset.getSExtValue(), Ctx);
   }
 
+  const MCExpr *DSExpr = nullptr;
+  if (auto *DS = dyn_cast<GlobalValue>(CPA.getDeactivationSymbol())) {
+    if (isa<GlobalAlias>(DS))
+      return Sym;
+    DSExpr = MCSymbolRefExpr::create(getSymbol(DS), Ctx);
+  }
+
   uint64_t KeyID = CPA.getKey()->getZExtValue();
   // We later rely on valid KeyID value in AArch64PACKeyIDToString call from
   // AArch64AuthMCExpr::printImpl, so fail fast.
@@ -2404,9 +2425,13 @@ AArch64AsmPrinter::lowerConstantPtrAuth(const ConstantPtrAuth &CPA) {
   // Check if we need to represent this with an IRELATIVE and emit it if so.
   if (auto *IFuncSym = emitPAuthRelocationAsIRelative(
           Sym, Disc, AArch64PACKey::ID(KeyID), CPA.hasAddressDiscriminator(),
-          BaseGVB && BaseGVB->isDSOLocal()))
+          BaseGVB && BaseGVB->isDSOLocal(), DSExpr))
     return IFuncSym;
 
+  if (DSExpr)
+    report_fatal_error("deactivation symbols unsupported in constant "
+                       "expressions on this target");
+
   // Finally build the complete @AUTH expr.
   return AArch64AuthMCExpr::create(Sym, Disc, AArch64PACKey::ID(KeyID),
                                    CPA.hasAddressDiscriminator(), Ctx);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 12dd4cec85f59..58b98d8d93464 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -2946,9 +2946,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       if (NeedSign && isa<ConstantInt>(II->getArgOperand(4))) {
         auto *SignKey = cast<ConstantInt>(II->getArgOperand(3));
         auto *SignDisc = cast<ConstantInt>(II->getArgOperand(4));
-        auto *SignAddrDisc = ConstantPointerNull::get(Builder.getPtrTy());
+        auto *Null = ConstantPointerNull::get(Builder.getPtrTy());
         auto *NewCPA = ConstantPtrAuth::get(CPA->getPointer(), SignKey,
-                                            SignDisc, SignAddrDisc);
+                                            SignDisc, Null, Null);
         replaceInstUsesWith(
             *II, ConstantExpr::getPointerCast(NewCPA, II->getType()));
         return eraseInstFromFunction(*II);
diff --git a/llvm/lib/Transforms/Utils/ValueMapper.cpp b/llvm/lib/Transforms/Utils/ValueMapper.cpp
index 5e50536a99206..320bef6c8f240 100644
--- a/llvm/lib/Transforms/Utils/ValueMapper.cpp
+++ b/llvm/lib/Transforms/Utils/ValueMapper.cpp
@@ -526,8 +526,9 @@ Value *Mapper::mapValue(const Value *V) {
   if (isa<ConstantVector>(C))
     return getVM()[V] = ConstantVector::get(Ops);
   if (isa<ConstantPtrAuth>(C))
-    return getVM()[V] = ConstantPtrAuth::get(Ops[0], cast<ConstantInt>(Ops[1]),
-                                             cast<ConstantInt>(Ops[2]), Ops[3]);
+    return getVM()[V] =
+               ConstantPtrAuth::get(Ops[0], cast<ConstantInt>(Ops[1]),
+                                    cast<ConstantInt>(Ops[2]), Ops[3], Ops[4]);
   // If this is a no-operand constant, it must be because the type was remapped.
   if (isa<PoisonValue>(C))
     return getVM()[V] = PoisonValue::get(NewTy);
diff --git a/llvm/unittests/SandboxIR/SandboxIR...
[truncated]

@github-actions
Copy link

github-actions bot commented Mar 28, 2025

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff origin/main HEAD --extensions h,cpp -- clang/lib/CodeGen/CGPointerAuth.cpp llvm/include/llvm/Bitcode/LLVMBitCodes.h llvm/include/llvm/IR/Constants.h llvm/include/llvm/SandboxIR/Constant.h llvm/lib/AsmParser/LLParser.cpp llvm/lib/Bitcode/Reader/BitcodeReader.cpp llvm/lib/Bitcode/Writer/BitcodeWriter.cpp llvm/lib/IR/AsmWriter.cpp llvm/lib/IR/Constants.cpp llvm/lib/IR/ConstantsContext.h llvm/lib/IR/Core.cpp llvm/lib/IR/Verifier.cpp llvm/lib/SandboxIR/Constant.cpp llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp llvm/lib/Transforms/Utils/ValueMapper.cpp llvm/unittests/SandboxIR/SandboxIRTest.cpp llvm/unittests/Transforms/Utils/ValueMapperTest.cpp

⚠️
The reproduction instructions above might return results for more than one PR
in a stack if you are using a stacked PR workflow. You can limit the results by
changing origin/main to the base branch/commit you want to compare against.
⚠️

View the diff from clang-format here.
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index 4cdf192fc..7dc869e44 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -4228,7 +4228,8 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
   }
   case lltok::kw_ptrauth: {
     // ValID ::= 'ptrauth' '(' ptr @foo ',' i32 <key>
-    //                         (',' i64 <disc> (',' ptr addrdisc (',' ptr ds)? )? )? ')'
+    //                         (',' i64 <disc> (',' ptr addrdisc (',' ptr ds)?
+    //                         )? )? ')'
     Lex.Lex();
 
     Constant *Ptr, *Key;
@@ -4282,10 +4283,11 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
 
     if (DeactivationSymbol) {
       if (!DeactivationSymbol->getType()->isPointerTy())
-        return error(
-            ID.Loc, "constant ptrauth deactivation symbol must be a pointer");
+        return error(ID.Loc,
+                     "constant ptrauth deactivation symbol must be a pointer");
     } else {
-      DeactivationSymbol = ConstantPointerNull::get(PointerType::get(Context, 0));
+      DeactivationSymbol =
+          ConstantPointerNull::get(PointerType::get(Context, 0));
     }
 
     ID.ConstantVal =
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index 6f4a99b6a..6fa0c5861 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -2135,8 +2135,8 @@ bool ConstantPtrAuth::hasSpecialAddressDiscriminator(uint64_t Value) const {
 bool ConstantPtrAuth::isKnownCompatibleWith(const Value *Key,
                                             const Value *Discriminator,
                                             const DataLayout &DL) const {
-  // This function may only be validly called to analyze a ptrauth operation with
-  // no deactivation symbol, so if we have one it isn't compatible.
+  // This function may only be validly called to analyze a ptrauth operation
+  // with no deactivation symbol, so if we have one it isn't compatible.
   if (!getDeactivationSymbol()->isNullValue())
     return false;
 

pcc added a commit to pcc/llvm-project that referenced this pull request Apr 3, 2025
Deactivation symbol operands are supported in the code generator by
building on the previously added support for IRELATIVE relocations.

TODO:
- Fix broken test.
- Add bitcode and IR writer support.
- Add tests.

Pull Request: llvm#133537
pcc added a commit to pcc/llvm-project that referenced this pull request Apr 4, 2025
Deactivation symbol operands are supported in the code generator by
building on the previously added support for IRELATIVE relocations.

TODO:
- Fix broken test.
- Add bitcode and IR writer support.
- Add tests.

Pull Request: llvm#133537
pcc added 2 commits May 12, 2025 21:38
Created using spr 1.3.6-beta.1
Created using spr 1.3.6-beta.1
Copy link
Collaborator

@efriedma-quic efriedma-quic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing verifier checks?

unwrap<ConstantInt>(Disc), unwrap<Constant>(AddrDisc)));
unwrap<ConstantInt>(Disc), unwrap<Constant>(AddrDisc),
ConstantPointerNull::get(
cast<PointerType>(unwrap<Constant>(AddrDisc)->getType()))));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to extend the C API to give access to this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reckon that could be done in a followup if anyone needs it.

@pcc
Copy link
Contributor Author

pcc commented May 28, 2025

Missing verifier checks?

Right, I guess the new operand can either be null (no deactivation symbol) or a globalvariable.

pcc added 2 commits July 8, 2025 21:38
Created using spr 1.3.6-beta.1
Created using spr 1.3.6-beta.1
Created using spr 1.3.6-beta.1
pcc added a commit to pcc/llvm-project that referenced this pull request Aug 1, 2025
Deactivation symbol operands are supported in the code generator by
building on the previously added support for IRELATIVE relocations.

Pull Request: llvm#133537
pcc added 2 commits August 1, 2025 14:04
Created using spr 1.3.6-beta.1
Created using spr 1.3.6-beta.1
@pcc
Copy link
Contributor Author

pcc commented Aug 2, 2025

Right, I guess the new operand can either be null (no deactivation symbol) or a globalvariable.

It also can be a global value (to cover the case where it points to a defined alias). Added the verifier check.

"signed ptrauth constant discriminator must be i64 constant integer");

Check(isa<GlobalValue>(CPA->getDeactivationSymbol()) ||
CPA->getDeactivationSymbol()->isNullValue(),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe also check CPA->getDeactivationSymbol()->getType()->isPointerTy()?

Copy link
Contributor Author

@pcc pcc Sep 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's probably best to check that in the ConstantPtrAuth constructor; I've added a check for that there.

I had already added a check for this operand being a pointer in the .ll parser; I noticed that there was no test coverage for that so I added it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For things we can easily check in the verifier, I prefer to check them even if there's also an assertion, to catch issues in assert-disabled builds.

Also, missing check in the bitcode parser.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For things we can easily check in the verifier, I prefer to check them even if there's also an assertion, to catch issues in assert-disabled builds.

Isn't the verifier disabled by default in no-asserts builds? So I guess it wouldn't make much of a difference.

CmdArgs.push_back("-disable-llvm-verifier");

(Looks like the comments were meant to read "no-asserts", I sent #157769 for that.)

That being said I don't feel strongly so I added it.

Also, missing check in the bitcode parser.

Added.

Created using spr 1.3.6-beta.1
pcc added 4 commits September 5, 2025 16:34
Created using spr 1.3.6-beta.1
Created using spr 1.3.6-beta.1
Created using spr 1.3.6-beta.1
Created using spr 1.3.6-beta.1
Copy link
Collaborator

@efriedma-quic efriedma-quic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing LangRef update.

Otherwise looks fine, but probably someone actively doing ptrauth stuff should also look.

Created using spr 1.3.6-beta.1
@pcc
Copy link
Contributor Author

pcc commented Sep 11, 2025

Missing LangRef update.

Done

@ojhunt
Copy link
Contributor

ojhunt commented Sep 11, 2025

I think this is better tied in to the options parameter, which means I should actually get that landed.

We'd probably want to extend that to include disabled_symbol=<>

Copy link
Contributor

@ojhunt ojhunt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm concerned about this - I initially thought this was for the purpose of the structure field protection in the frontend, but this is modifying the actual pointer auth intrinsics in the backend which is very concerning given the work we need to do to merging and protection etc.

Record.push_back(VE.getValueID(NC->getGlobalValue()));
} else if (const auto *CPA = dyn_cast<ConstantPtrAuth>(C)) {
Code = bitc::CST_CODE_PTRAUTH;
Code = bitc::CST_CODE_PTRAUTH2;
Copy link
Contributor

@ojhunt ojhunt Sep 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be

Constant *DeactivationSymbol = CPA->getDeactivationSymbol();
Code =  DeactivationSymbol->isNullValue() ? CST_CODE_PTRAUTH : CST_CODE_PTRAUTH2;
...
if (!DeactivationSymbol->isNullValue())
   Record.push_back(VE.getValueID(DeactivationSymbol));

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated to correct the of the Code

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bitcode doesn't have forwards compatibility, only backwards compatibility. So I think it's fine to use PTRAUTH2 unconditionally here.

return error("ptrauth disc operand must be ConstantInt");

C = ConstantPtrAuth::get(ConstOps[0], Key, Disc, ConstOps[3]);
auto *DeactivationSymbol =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] while there is a type named on the rhs, I think based on the ternary rather than single option, this should be Constant * instead of auto *

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Created using spr 1.3.6-beta.1
Copy link
Contributor Author

@pcc pcc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm concerned about this - I initially thought this was for the purpose of the structure field protection in the frontend, but this is modifying the actual pointer auth intrinsics in the backend which is very concerning given the work we need to do to merging and protection etc.

Not sure I understand your concerns, can you be more specific?

Record.push_back(VE.getValueID(NC->getGlobalValue()));
} else if (const auto *CPA = dyn_cast<ConstantPtrAuth>(C)) {
Code = bitc::CST_CODE_PTRAUTH;
Code = bitc::CST_CODE_PTRAUTH2;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bitcode doesn't have forwards compatibility, only backwards compatibility. So I think it's fine to use PTRAUTH2 unconditionally here.

return error("ptrauth disc operand must be ConstantInt");

C = ConstantPtrAuth::get(ConstOps[0], Key, Disc, ConstOps[3]);
auto *DeactivationSymbol =
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@ojhunt
Copy link
Contributor

ojhunt commented Sep 11, 2025

I'm concerned about this - I initially thought this was for the purpose of the structure field protection in the frontend, but this is modifying the actual pointer auth intrinsics in the backend which is very concerning given the work we need to do to merging and protection etc.

Not sure I understand your concerns, can you be more specific?

I'm concerned about the interaction of these changes with ptrauth intrinsic optimizations, and the ability for attackers to gain control of the enablement flags.

But that said, this is a backend change so @ahmedbougacha should be the main reviewer.

@ojhunt
Copy link
Contributor

ojhunt commented Sep 11, 2025

(edit: I misunderstood Ahmed's opinion. Will check with him to clarify)

Created using spr 1.3.6-beta.1
@pcc
Copy link
Contributor Author

pcc commented Sep 11, 2025

I have checked in with @ahmedbougacha and his feeling is that this is fine as it requires a bunch of work to opt in, and for places where the security is important enough that we don't want people using this it's easy enough to block.

Thanks for checking.

I'm concerned about the interaction of these changes with ptrauth intrinsic optimizations

I took a look and found some cases where we needed to inhibit optimizations. There was no practical effect due to how PFP uses these intrinisics, but I implemented the inhibitions in #133536 and this PR.

the ability for attackers to gain control of the enablement flags.

This isn't possible, the symbols are resolved at static link time. See the RFC for more information: https://discourse.llvm.org/t/rfc-deactivation-symbols/85556

@ojhunt
Copy link
Contributor

ojhunt commented Sep 11, 2025

This isn't possible, the symbols are resolved at static link time. See the RFC for more information: https://discourse.llvm.org/t/rfc-deactivation-symbols/85556

Oh wait, I have completely misunderstood that - I have always assumed dynamic link and that's the reason for a bunch of the concerns I raised, that I now assume sounded really weird :D

@ojhunt
Copy link
Contributor

ojhunt commented Sep 12, 2025

I have checked in with @ahmedbougacha and his feeling is that this is fine as it requires a bunch of work to opt in, and for places where the security is important enough that we don't want people using this it's easy enough to block.

Thanks for checking.

as above I misunderstood what Ahmed was saying, and also the wording was terrible: the opinion on disabling and similar was mine - the concerns there were mine and I was trying to say I felt my concerns had been addressed.

@ahmedbougacha
Copy link
Member

Yep, this does seem reasonable to me as well (with a question in-line).
Thanks for the summons, sorry I haven't had the chance to take a look before!

break;
}
case bitc::CST_CODE_PTRAUTH2: {
if (Record.size() < 4)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be 5?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, fixed

LLVM_ABI static ConstantPtrAuth *get(Constant *Ptr, ConstantInt *Key,
ConstantInt *Disc, Constant *AddrDisc);
ConstantInt *Disc, Constant *AddrDisc,
Constant *DeactivationSymbol);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't have to do this here, but we probably should make the optional operands (in textual IR) optional here as well, and implicitly make them null? Now that I think about it, I'm not sure how idiomatic that would be

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I considered it, but since there are only a few places that need this, it seemed slightly better to be explicit about all the operands since that's consistent with what we have elsewhere.

What might be nice is if we initialized this using fields of a passed-in struct so that call sites are more readable, but that's a separate change.

pcc added 2 commits September 11, 2025 18:35
Created using spr 1.3.6-beta.1
Created using spr 1.3.6-beta.1
@pcc pcc requested a review from fmayer October 23, 2025 00:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

backend:AArch64 clang:codegen IR generation bugs: mangling, exceptions, etc. clang Clang issues not falling into any other category llvm:instcombine Covers the InstCombine, InstSimplify and AggressiveInstCombine passes llvm:ir llvm:SandboxIR llvm:transforms

Projects

Status: No status

Development

Successfully merging this pull request may close these issues.

5 participants