diff --git a/clang/lib/CodeGen/CGPointerAuth.cpp b/clang/lib/CodeGen/CGPointerAuth.cpp index 375f87a68ee6d..326ffcc6512be 100644 --- a/clang/lib/CodeGen/CGPointerAuth.cpp +++ b/clang/lib/CodeGen/CGPointerAuth.cpp @@ -440,9 +440,9 @@ CodeGenModule::getConstantSignedPointer(llvm::Constant *Pointer, unsigned Key, IntegerDiscriminator = llvm::ConstantInt::get(Int64Ty, 0); } - return llvm::ConstantPtrAuth::get(Pointer, - llvm::ConstantInt::get(Int32Ty, Key), - IntegerDiscriminator, AddressDiscriminator); + return llvm::ConstantPtrAuth::get( + Pointer, llvm::ConstantInt::get(Int32Ty, Key), IntegerDiscriminator, + AddressDiscriminator, llvm::Constant::getNullValue(UnqualPtrTy)); } /// Does a given PointerAuthScheme require us to sign a value diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst index f2cb85b291c61..4881ff59bd9ac 100644 --- a/llvm/docs/LangRef.rst +++ b/llvm/docs/LangRef.rst @@ -3199,6 +3199,8 @@ A "convergencectrl" operand bundle is only valid on a ``convergent`` operation. When present, the operand bundle must contain exactly one value of token type. See the :doc:`ConvergentOperations` document for details. +.. _deactivationsymbol: + Deactivation Symbol Operand Bundles ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -5265,7 +5267,7 @@ need to refer to the actual function body. Pointer Authentication Constants -------------------------------- -``ptrauth (ptr CST, i32 KEY[, i64 DISC[, ptr ADDRDISC]?]?)`` +``ptrauth (ptr CST, i32 KEY[, i64 DISC[, ptr ADDRDISC[, ptr DS]?]?]?)`` A '``ptrauth``' constant represents a pointer with a cryptographic authentication signature embedded into some bits, as described in the @@ -5294,6 +5296,11 @@ Otherwise, the expression is equivalent to: %tmp2 = call i64 @llvm.ptrauth.sign(i64 ptrtoint (ptr CST to i64), i32 KEY, i64 %tmp1) %val = inttoptr i64 %tmp2 to ptr +If the deactivation symbol operand ``DS`` has a non-null value, +the semantics are as if a :ref:`deactivation-symbol operand bundle +` were added to the ``llvm.ptrauth.sign`` intrinsic +calls above, with ``DS`` as the only operand. + .. _constantexprs: Constant Expressions diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h index 464f475098ec5..5df2f3c4f4a9c 100644 --- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h +++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h @@ -437,6 +437,8 @@ enum ConstantsCodes { CST_CODE_CE_GEP_WITH_INRANGE = 31, // [opty, flags, range, n x operands] CST_CODE_CE_GEP = 32, // [opty, flags, n x operands] CST_CODE_PTRAUTH = 33, // [ptr, key, disc, addrdisc] + CST_CODE_PTRAUTH2 = 34, // [ptr, key, disc, addrdisc, + // deactivation_symbol] }; /// CastOpcodes - These are values used in the bitcode files to encode which diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h index e06e6adbc3130..e3f2eb9fa44b8 100644 --- a/llvm/include/llvm/IR/Constants.h +++ b/llvm/include/llvm/IR/Constants.h @@ -1033,10 +1033,10 @@ class ConstantPtrAuth final : public Constant { friend struct ConstantPtrAuthKeyType; friend class Constant; - constexpr static IntrusiveOperandsAllocMarker AllocMarker{4}; + constexpr static IntrusiveOperandsAllocMarker AllocMarker{5}; ConstantPtrAuth(Constant *Ptr, ConstantInt *Key, ConstantInt *Disc, - Constant *AddrDisc); + Constant *AddrDisc, Constant *DeactivationSymbol); void *operator new(size_t s) { return User::operator new(s, AllocMarker); } @@ -1046,7 +1046,8 @@ class ConstantPtrAuth final : public Constant { public: /// Return a pointer signed with the specified parameters. LLVM_ABI static ConstantPtrAuth *get(Constant *Ptr, ConstantInt *Key, - ConstantInt *Disc, Constant *AddrDisc); + ConstantInt *Disc, Constant *AddrDisc, + Constant *DeactivationSymbol); /// Produce a new ptrauth expression signing the given value using /// the same schema as is stored in one. @@ -1078,6 +1079,10 @@ class ConstantPtrAuth final : public Constant { return !getAddrDiscriminator()->isNullValue(); } + Constant *getDeactivationSymbol() const { + return cast(Op<4>().get()); + } + /// A constant value for the address discriminator which has special /// significance to ctors/dtors lowering. Regular address discrimination can't /// be applied for them since uses of llvm.global_{c|d}tors are disallowed @@ -1106,7 +1111,7 @@ class ConstantPtrAuth final : public Constant { template <> struct OperandTraits - : public FixedNumOperandTraits {}; + : public FixedNumOperandTraits {}; DEFINE_TRANSPARENT_OPERAND_ACCESSORS(ConstantPtrAuth, Constant) diff --git a/llvm/include/llvm/SandboxIR/Constant.h b/llvm/include/llvm/SandboxIR/Constant.h index 6f682a7059d10..2fe923f6c3866 100644 --- a/llvm/include/llvm/SandboxIR/Constant.h +++ b/llvm/include/llvm/SandboxIR/Constant.h @@ -1363,7 +1363,8 @@ class ConstantPtrAuth final : public Constant { public: /// Return a pointer signed with the specified parameters. LLVM_ABI static ConstantPtrAuth *get(Constant *Ptr, ConstantInt *Key, - ConstantInt *Disc, Constant *AddrDisc); + ConstantInt *Disc, Constant *AddrDisc, + Constant *DeactivationSymbol); /// The pointer that is signed in this ptrauth signed pointer. LLVM_ABI Constant *getPointer() const; @@ -1378,6 +1379,8 @@ class ConstantPtrAuth final : public Constant { /// the only global-initializer user of the ptrauth signed pointer. LLVM_ABI Constant *getAddrDiscriminator() const; + Constant *getDeactivationSymbol() const; + /// Whether there is any non-null address discriminator. bool hasAddressDiscriminator() const { return cast(Val)->hasAddressDiscriminator(); diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp index 55899660fa84a..4cdf192fcc5cd 100644 --- a/llvm/lib/AsmParser/LLParser.cpp +++ b/llvm/lib/AsmParser/LLParser.cpp @@ -4228,11 +4228,12 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) { } case lltok::kw_ptrauth: { // ValID ::= 'ptrauth' '(' ptr @foo ',' i32 - // (',' i64 (',' ptr addrdisc)? )? ')' + // (',' i64 (',' ptr addrdisc (',' ptr ds)? )? )? ')' Lex.Lex(); Constant *Ptr, *Key; - Constant *Disc = nullptr, *AddrDisc = nullptr; + Constant *Disc = nullptr, *AddrDisc = nullptr, + *DeactivationSymbol = nullptr; if (parseToken(lltok::lparen, "expected '(' in constant ptrauth expression") || @@ -4241,11 +4242,14 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) { "expected comma in constant ptrauth expression") || parseGlobalTypeAndValue(Key)) return true; - // If present, parse the optional disc/addrdisc. - if (EatIfPresent(lltok::comma)) - if (parseGlobalTypeAndValue(Disc) || - (EatIfPresent(lltok::comma) && parseGlobalTypeAndValue(AddrDisc))) - return true; + // If present, parse the optional disc/addrdisc/ds. + if (EatIfPresent(lltok::comma) && parseGlobalTypeAndValue(Disc)) + return true; + if (EatIfPresent(lltok::comma) && parseGlobalTypeAndValue(AddrDisc)) + return true; + if (EatIfPresent(lltok::comma) && + parseGlobalTypeAndValue(DeactivationSymbol)) + return true; if (parseToken(lltok::rparen, "expected ')' in constant ptrauth expression")) return true; @@ -4276,7 +4280,16 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) { AddrDisc = ConstantPointerNull::get(PointerType::get(Context, 0)); } - ID.ConstantVal = ConstantPtrAuth::get(Ptr, KeyC, DiscC, AddrDisc); + if (DeactivationSymbol) { + if (!DeactivationSymbol->getType()->isPointerTy()) + return error( + ID.Loc, "constant ptrauth deactivation symbol must be a pointer"); + } else { + DeactivationSymbol = ConstantPointerNull::get(PointerType::get(Context, 0)); + } + + ID.ConstantVal = + ConstantPtrAuth::get(Ptr, KeyC, DiscC, AddrDisc, DeactivationSymbol); ID.Kind = ValID::t_Constant; return false; } diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp index aaee1f0a7687c..adfd37bdc66d4 100644 --- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp +++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp @@ -1609,7 +1609,16 @@ Expected BitcodeReader::materializeValue(unsigned StartValID, if (!Disc) return error("ptrauth disc operand must be ConstantInt"); - C = ConstantPtrAuth::get(ConstOps[0], Key, Disc, ConstOps[3]); + Constant *DeactivationSymbol = + ConstOps.size() > 4 ? ConstOps[4] + : ConstantPointerNull::get(cast( + ConstOps[3]->getType())); + if (!DeactivationSymbol->getType()->isPointerTy()) + return error( + "ptrauth deactivation symbol operand must be a pointer"); + + C = ConstantPtrAuth::get(ConstOps[0], Key, Disc, ConstOps[3], + DeactivationSymbol); break; } case BitcodeConstant::NoCFIOpcode: { @@ -3811,6 +3820,16 @@ Error BitcodeReader::parseConstants() { (unsigned)Record[2], (unsigned)Record[3]}); break; } + case bitc::CST_CODE_PTRAUTH2: { + if (Record.size() < 5) + return error("Invalid ptrauth record"); + // Ptr, Key, Disc, AddrDisc, DeactivationSymbol + V = BitcodeConstant::create( + Alloc, CurTy, BitcodeConstant::ConstantPtrAuthOpcode, + {(unsigned)Record[0], (unsigned)Record[1], (unsigned)Record[2], + (unsigned)Record[3], (unsigned)Record[4]}); + break; + } } assert(V->getType() == getTypeByID(CurTyID) && "Incorrect result type ID"); diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp index 7ed140d392fca..4e4e8773aa6e3 100644 --- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp +++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp @@ -3013,11 +3013,12 @@ void ModuleBitcodeWriter::writeConstants(unsigned FirstVal, unsigned LastVal, Record.push_back(VE.getTypeID(NC->getGlobalValue()->getType())); Record.push_back(VE.getValueID(NC->getGlobalValue())); } else if (const auto *CPA = dyn_cast(C)) { - Code = bitc::CST_CODE_PTRAUTH; + Code = bitc::CST_CODE_PTRAUTH2; Record.push_back(VE.getValueID(CPA->getPointer())); Record.push_back(VE.getValueID(CPA->getKey())); Record.push_back(VE.getValueID(CPA->getDiscriminator())); Record.push_back(VE.getValueID(CPA->getAddrDiscriminator())); + Record.push_back(VE.getValueID(CPA->getDeactivationSymbol())); } else { #ifndef NDEBUG C->dump(); diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp index ae086bcd3902d..0fc146432bd2f 100644 --- a/llvm/lib/IR/AsmWriter.cpp +++ b/llvm/lib/IR/AsmWriter.cpp @@ -1673,12 +1673,14 @@ static void writeConstantInternal(raw_ostream &Out, const Constant *CV, if (const auto *CPA = dyn_cast(CV)) { Out << "ptrauth ("; - // ptrauth (ptr CST, i32 KEY[, i64 DISC[, ptr ADDRDISC]?]?) + // ptrauth (ptr CST, i32 KEY[, i64 DISC[, ptr ADDRDISC[, ptr DS]?]?]?) unsigned NumOpsToWrite = 2; if (!CPA->getOperand(2)->isNullValue()) NumOpsToWrite = 3; if (!CPA->getOperand(3)->isNullValue()) NumOpsToWrite = 4; + if (!CPA->getOperand(4)->isNullValue()) + NumOpsToWrite = 5; ListSeparator LS; for (unsigned i = 0, e = NumOpsToWrite; i != e; ++i) { diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp index 2c2950c70d346..6f4a99b6a62f2 100644 --- a/llvm/lib/IR/Constants.cpp +++ b/llvm/lib/IR/Constants.cpp @@ -2061,28 +2061,33 @@ Value *NoCFIValue::handleOperandChangeImpl(Value *From, Value *To) { // ConstantPtrAuth *ConstantPtrAuth::get(Constant *Ptr, ConstantInt *Key, - ConstantInt *Disc, Constant *AddrDisc) { - Constant *ArgVec[] = {Ptr, Key, Disc, AddrDisc}; + ConstantInt *Disc, Constant *AddrDisc, + Constant *DeactivationSymbol) { + Constant *ArgVec[] = {Ptr, Key, Disc, AddrDisc, DeactivationSymbol}; ConstantPtrAuthKeyType MapKey(ArgVec); LLVMContextImpl *pImpl = Ptr->getContext().pImpl; return pImpl->ConstantPtrAuths.getOrCreate(Ptr->getType(), MapKey); } ConstantPtrAuth *ConstantPtrAuth::getWithSameSchema(Constant *Pointer) const { - return get(Pointer, getKey(), getDiscriminator(), getAddrDiscriminator()); + return get(Pointer, getKey(), getDiscriminator(), getAddrDiscriminator(), + getDeactivationSymbol()); } ConstantPtrAuth::ConstantPtrAuth(Constant *Ptr, ConstantInt *Key, - ConstantInt *Disc, Constant *AddrDisc) + ConstantInt *Disc, Constant *AddrDisc, + Constant *DeactivationSymbol) : Constant(Ptr->getType(), Value::ConstantPtrAuthVal, AllocMarker) { assert(Ptr->getType()->isPointerTy()); assert(Key->getBitWidth() == 32); assert(Disc->getBitWidth() == 64); assert(AddrDisc->getType()->isPointerTy()); + assert(DeactivationSymbol->getType()->isPointerTy()); setOperand(0, Ptr); setOperand(1, Key); setOperand(2, Disc); setOperand(3, AddrDisc); + setOperand(4, DeactivationSymbol); } /// Remove the constant from the constant table. @@ -2130,6 +2135,11 @@ bool ConstantPtrAuth::hasSpecialAddressDiscriminator(uint64_t Value) const { bool ConstantPtrAuth::isKnownCompatibleWith(const Value *Key, const Value *Discriminator, const DataLayout &DL) const { + // This function may only be validly called to analyze a ptrauth operation with + // no deactivation symbol, so if we have one it isn't compatible. + if (!getDeactivationSymbol()->isNullValue()) + return false; + // If the keys are different, there's no chance for this to be compatible. if (getKey() != Key) return false; diff --git a/llvm/lib/IR/ConstantsContext.h b/llvm/lib/IR/ConstantsContext.h index 51fb40bad201d..584391b53800a 100644 --- a/llvm/lib/IR/ConstantsContext.h +++ b/llvm/lib/IR/ConstantsContext.h @@ -539,7 +539,8 @@ struct ConstantPtrAuthKeyType { ConstantPtrAuth *create(TypeClass *Ty) const { return new ConstantPtrAuth(Operands[0], cast(Operands[1]), - cast(Operands[2]), Operands[3]); + cast(Operands[2]), Operands[3], + Operands[4]); } }; diff --git a/llvm/lib/IR/Core.cpp b/llvm/lib/IR/Core.cpp index 3f1cc1e4e6d0e..efd4e5375f6d2 100644 --- a/llvm/lib/IR/Core.cpp +++ b/llvm/lib/IR/Core.cpp @@ -1699,7 +1699,9 @@ LLVMValueRef LLVMConstantPtrAuth(LLVMValueRef Ptr, LLVMValueRef Key, LLVMValueRef Disc, LLVMValueRef AddrDisc) { return wrap(ConstantPtrAuth::get( unwrap(Ptr), unwrap(Key), - unwrap(Disc), unwrap(AddrDisc))); + unwrap(Disc), unwrap(AddrDisc), + ConstantPointerNull::get( + cast(unwrap(AddrDisc)->getType())))); } /*-- Opcode mapping */ diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp index c9ff86b7df16b..26726ba721b01 100644 --- a/llvm/lib/IR/Verifier.cpp +++ b/llvm/lib/IR/Verifier.cpp @@ -2677,6 +2677,14 @@ void Verifier::visitConstantPtrAuth(const ConstantPtrAuth *CPA) { Check(CPA->getDiscriminator()->getBitWidth() == 64, "signed ptrauth constant discriminator must be i64 constant integer"); + + Check(CPA->getDeactivationSymbol()->getType()->isPointerTy(), + "signed ptrauth constant deactivation symbol must be a pointer"); + + Check(isa(CPA->getDeactivationSymbol()) || + CPA->getDeactivationSymbol()->isNullValue(), + "signed ptrauth constant deactivation symbol must be a global value " + "or null"); } bool Verifier::verifyAttributeCount(AttributeList Attrs, unsigned Params) { diff --git a/llvm/lib/SandboxIR/Constant.cpp b/llvm/lib/SandboxIR/Constant.cpp index 9de88ef2cf0a0..eb14797af081c 100644 --- a/llvm/lib/SandboxIR/Constant.cpp +++ b/llvm/lib/SandboxIR/Constant.cpp @@ -412,10 +412,12 @@ PointerType *NoCFIValue::getType() const { } ConstantPtrAuth *ConstantPtrAuth::get(Constant *Ptr, ConstantInt *Key, - ConstantInt *Disc, Constant *AddrDisc) { + ConstantInt *Disc, Constant *AddrDisc, + Constant *DeactivationSymbol) { auto *LLVMC = llvm::ConstantPtrAuth::get( cast(Ptr->Val), cast(Key->Val), - cast(Disc->Val), cast(AddrDisc->Val)); + cast(Disc->Val), cast(AddrDisc->Val), + cast(DeactivationSymbol->Val)); return cast(Ptr->getContext().getOrCreateConstant(LLVMC)); } @@ -439,6 +441,11 @@ Constant *ConstantPtrAuth::getAddrDiscriminator() const { cast(Val)->getAddrDiscriminator()); } +Constant *ConstantPtrAuth::getDeactivationSymbol() const { + return Ctx.getOrCreateConstant( + cast(Val)->getDeactivationSymbol()); +} + ConstantPtrAuth *ConstantPtrAuth::getWithSameSchema(Constant *Pointer) const { auto *LLVMC = cast(Val)->getWithSameSchema( cast(Pointer->Val)); diff --git a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp index 353ddee4d8c88..e366c6389094a 100644 --- a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp +++ b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp @@ -222,7 +222,7 @@ class AArch64AsmPrinter : public AsmPrinter { const MCExpr *emitPAuthRelocationAsIRelative( const MCExpr *Target, uint16_t Disc, AArch64PACKey::ID KeyID, - bool HasAddressDiversity, bool IsDSOLocal); + bool HasAddressDiversity, bool IsDSOLocal, const MCExpr *DSExpr); /// tblgen'erated driver function for lowering simple MI->MC /// pseudo instructions. @@ -2387,15 +2387,17 @@ static void emitAddress(MCStreamer &Streamer, MCRegister Reg, } static bool targetSupportsPAuthRelocation(const Triple &TT, - const MCExpr *Target) { + const MCExpr *Target, + const MCExpr *DSExpr) { // No released version of glibc supports PAuth relocations. if (TT.isOSGlibc()) return false; // We emit PAuth constants as IRELATIVE relocations in cases where the // constant cannot be represented as a PAuth relocation: - // 1) The signed value is not a symbol. - return !isa(Target); + // 1) There is a deactivation symbol. + // 2) The signed value is not a symbol. + return !DSExpr && !isa(Target); } static bool targetSupportsIRelativeRelocation(const Triple &TT) { @@ -2412,12 +2414,12 @@ static bool targetSupportsIRelativeRelocation(const Triple &TT) { const MCExpr *AArch64AsmPrinter::emitPAuthRelocationAsIRelative( const MCExpr *Target, uint16_t Disc, AArch64PACKey::ID KeyID, - bool HasAddressDiversity, bool IsDSOLocal) { + bool HasAddressDiversity, bool IsDSOLocal, const MCExpr *DSExpr) { const Triple &TT = TM.getTargetTriple(); // We only emit an IRELATIVE relocation if the target supports IRELATIVE and // does not support the kind of PAuth relocation that we are trying to emit. - if (targetSupportsPAuthRelocation(TT, Target) || + if (targetSupportsPAuthRelocation(TT, Target, DSExpr) || !targetSupportsIRelativeRelocation(TT)) return nullptr; @@ -2461,6 +2463,16 @@ const MCExpr *AArch64AsmPrinter::emitPAuthRelocationAsIRelative( emitMOVZ(AArch64::X1, Disc, 0); } + if (DSExpr) { + MCSymbol *PrePACInst = OutStreamer->getContext().createTempSymbol(); + OutStreamer->emitLabel(PrePACInst); + + auto *PrePACInstExpr = + MCSymbolRefExpr::create(PrePACInst, OutStreamer->getContext()); + OutStreamer->emitRelocDirective(*PrePACInstExpr, "R_AARCH64_PATCHINST", + DSExpr, SMLoc()); + } + // We don't know the subtarget because this is being emitted for a global // initializer. Because the performance of IFUNC resolvers is unimportant, we // always call the EmuPAC runtime, which will end up using the PAC instruction @@ -2471,6 +2483,12 @@ const MCExpr *AArch64AsmPrinter::emitPAuthRelocationAsIRelative( MCSymbolRefExpr::create(EmuPAC, OutStreamer->getContext()); OutStreamer->emitInstruction(MCInstBuilder(AArch64::B).addExpr(EmuPACRef), *STI); + + // We need a RET despite the above tail call because the deactivation symbol + // may replace it with a NOP. + if (DSExpr) + OutStreamer->emitInstruction( + MCInstBuilder(AArch64::RET).addReg(AArch64::LR), *STI); OutStreamer->popSection(); return MCSymbolRefExpr::create(IRelativeSym, AArch64::S_FUNCINIT, @@ -2502,6 +2520,13 @@ AArch64AsmPrinter::lowerConstantPtrAuth(const ConstantPtrAuth &CPA) { Sym = MCConstantExpr::create(Offset.getSExtValue(), Ctx); } + const MCExpr *DSExpr = nullptr; + if (auto *DS = dyn_cast(CPA.getDeactivationSymbol())) { + if (isa(DS)) + return Sym; + DSExpr = MCSymbolRefExpr::create(getSymbol(DS), Ctx); + } + uint64_t KeyID = CPA.getKey()->getZExtValue(); // We later rely on valid KeyID value in AArch64PACKeyIDToString call from // AArch64AuthMCExpr::printImpl, so fail fast. @@ -2522,9 +2547,13 @@ AArch64AsmPrinter::lowerConstantPtrAuth(const ConstantPtrAuth &CPA) { // Check if we need to represent this with an IRELATIVE and emit it if so. if (auto *IFuncSym = emitPAuthRelocationAsIRelative( Sym, Disc, AArch64PACKey::ID(KeyID), CPA.hasAddressDiscriminator(), - BaseGVB && BaseGVB->isDSOLocal())) + BaseGVB && BaseGVB->isDSOLocal(), DSExpr)) return IFuncSym; + if (DSExpr) + report_fatal_error("deactivation symbols unsupported in constant " + "expressions on this target"); + // Finally build the complete @AUTH expr. return AArch64AuthMCExpr::create(Sym, Disc, AArch64PACKey::ID(KeyID), CPA.hasAddressDiscriminator(), Ctx); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index 169c0966773c0..b02abe1824817 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -3110,9 +3110,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { if (NeedSign && isa(II->getArgOperand(4))) { auto *SignKey = cast(II->getArgOperand(3)); auto *SignDisc = cast(II->getArgOperand(4)); - auto *SignAddrDisc = ConstantPointerNull::get(Builder.getPtrTy()); + auto *Null = ConstantPointerNull::get(Builder.getPtrTy()); auto *NewCPA = ConstantPtrAuth::get(CPA->getPointer(), SignKey, - SignDisc, SignAddrDisc); + SignDisc, Null, Null); replaceInstUsesWith( *II, ConstantExpr::getPointerCast(NewCPA, II->getType())); return eraseInstFromFunction(*II); diff --git a/llvm/lib/Transforms/Utils/ValueMapper.cpp b/llvm/lib/Transforms/Utils/ValueMapper.cpp index 8d8a60b6918fe..fca17893f47f4 100644 --- a/llvm/lib/Transforms/Utils/ValueMapper.cpp +++ b/llvm/lib/Transforms/Utils/ValueMapper.cpp @@ -526,8 +526,9 @@ Value *Mapper::mapValue(const Value *V) { if (isa(C)) return getVM()[V] = ConstantVector::get(Ops); if (isa(C)) - return getVM()[V] = ConstantPtrAuth::get(Ops[0], cast(Ops[1]), - cast(Ops[2]), Ops[3]); + return getVM()[V] = + ConstantPtrAuth::get(Ops[0], cast(Ops[1]), + cast(Ops[2]), Ops[3], Ops[4]); // If this is a no-operand constant, it must be because the type was remapped. if (isa(C)) return getVM()[V] = PoisonValue::get(NewTy); diff --git a/llvm/test/Assembler/invalid-ptrauth-const6.ll b/llvm/test/Assembler/invalid-ptrauth-const6.ll new file mode 100644 index 0000000000000..6e8e1d386acc8 --- /dev/null +++ b/llvm/test/Assembler/invalid-ptrauth-const6.ll @@ -0,0 +1,6 @@ +; RUN: not llvm-as < %s 2>&1 | FileCheck %s + +@var = global i32 0 + +; CHECK: error: constant ptrauth deactivation symbol must be a pointer +@ptr = global ptr ptrauth (ptr @var, i32 0, i64 65535, ptr null, i64 0) diff --git a/llvm/test/Bitcode/compatibility.ll b/llvm/test/Bitcode/compatibility.ll index e21786e5ee330..53cbe2d6ffd37 100644 --- a/llvm/test/Bitcode/compatibility.ll +++ b/llvm/test/Bitcode/compatibility.ll @@ -217,9 +217,13 @@ declare void @g.f1() ; CHECK: @g.sanitize_address_dyninit = global i32 0, sanitize_address_dyninit ; CHECK: @g.sanitize_multiple = global i32 0, sanitize_memtag, sanitize_address_dyninit +@ds = external global i32 + ; ptrauth constant @auth_var = global ptr ptrauth (ptr @g1, i32 0, i64 65535, ptr null) ; CHECK: @auth_var = global ptr ptrauth (ptr @g1, i32 0, i64 65535) +@auth_var.ds = global ptr ptrauth (ptr @g1, i32 0, i64 65535, ptr null, ptr @ds) +; CHECK: @auth_var.ds = global ptr ptrauth (ptr @g1, i32 0, i64 65535, ptr null, ptr @ds) ;; Aliases ; Format: @ = [Linkage] [Visibility] [DLLStorageClass] [ThreadLocal] diff --git a/llvm/test/CodeGen/AArch64/ptrauth-irelative.ll b/llvm/test/CodeGen/AArch64/ptrauth-irelative.ll index 7857051668dfb..4ee1c19a86490 100644 --- a/llvm/test/CodeGen/AArch64/ptrauth-irelative.ll +++ b/llvm/test/CodeGen/AArch64/ptrauth-irelative.ll @@ -25,6 +25,23 @@ ; CHECK-NEXT: .xword [[FUNC]]@FUNCINIT @dsolocalref = constant ptr ptrauth (ptr @dsolocal, i32 2, i64 2, ptr null), align 8 +@ds = external global i8 + +; CHECK: dsolocalrefds: +; CHECK-NEXT: [[PLACE:.*]]: +; CHECK-NEXT: .section .text.startup +; CHECK-NEXT: [[FUNC:.*]]: +; CHECK-NEXT: adrp x0, dsolocal +; CHECK-NEXT: add x0, x0, :lo12:dsolocal +; CHECK-NEXT: mov x1, #2 +; CHECK-NEXT: [[LABEL:.L.*]]: +; CHECK-NEXT: .reloc [[LABEL]], R_AARCH64_PATCHINST, ds +; CHECK-NEXT: b __emupac_pacda +; CHECK-NEXT: ret +; CHECK-NEXT: .section .rodata +; CHECK-NEXT: .xword [[FUNC]]@FUNCINIT +@dsolocalrefds = constant ptr ptrauth (ptr @dsolocal, i32 2, i64 2, ptr null, ptr @ds), align 8 + ; CHECK: dsolocalref8: ; CHECK-NEXT: [[PLACE:.*]]: ; CHECK-NEXT: .section .text.startup diff --git a/llvm/test/Transforms/InstCombine/ptrauth-intrinsics.ll b/llvm/test/Transforms/InstCombine/ptrauth-intrinsics.ll index 09d9649b09cc1..22c330fe7ae61 100644 --- a/llvm/test/Transforms/InstCombine/ptrauth-intrinsics.ll +++ b/llvm/test/Transforms/InstCombine/ptrauth-intrinsics.ll @@ -188,6 +188,15 @@ define i64 @test_ptrauth_nop_ds2(ptr %p) { ret i64 %authed } +define i64 @test_ptrauth_nop_ds_constant() { +; CHECK-LABEL: @test_ptrauth_nop_ds_constant( +; CHECK-NEXT: [[AUTHED:%.*]] = call i64 @llvm.ptrauth.auth(i64 ptrtoint (ptr ptrauth (ptr @foo, i32 1, i64 1234, ptr null, ptr @ds) to i64), i32 1, i64 1234) +; CHECK-NEXT: ret i64 [[AUTHED]] +; + %authed = call i64 @llvm.ptrauth.auth(i64 ptrtoint(ptr ptrauth(ptr @foo, i32 1, i64 1234, ptr null, ptr @ds) to i64), i32 1, i64 1234) + ret i64 %authed +} + declare i64 @llvm.ptrauth.auth(i64, i32, i64) declare i64 @llvm.ptrauth.sign(i64, i32, i64) declare i64 @llvm.ptrauth.resign(i64, i32, i64, i32, i64) diff --git a/llvm/test/Verifier/ptrauth-constant.ll b/llvm/test/Verifier/ptrauth-constant.ll new file mode 100644 index 0000000000000..7a6d9d2634bc8 --- /dev/null +++ b/llvm/test/Verifier/ptrauth-constant.ll @@ -0,0 +1,6 @@ +; RUN: not opt -passes=verify < %s 2>&1 | FileCheck %s + +@g = external global i8 + +; CHECK: signed ptrauth constant deactivation symbol must be a global value or null +@ptr = global ptr ptrauth (ptr @g, i32 0, i64 65535, ptr null, ptr inttoptr (i64 16 to ptr)) diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp index 33928ac118e0c..168502f89cbf8 100644 --- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp +++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp @@ -1385,7 +1385,7 @@ define ptr @foo() { // Check get(), getKey(), getDiscriminator(), getAddrDiscriminator(). auto *NewPtrAuth = sandboxir::ConstantPtrAuth::get( &F, PtrAuth->getKey(), PtrAuth->getDiscriminator(), - PtrAuth->getAddrDiscriminator()); + PtrAuth->getAddrDiscriminator(), PtrAuth->getDeactivationSymbol()); EXPECT_EQ(NewPtrAuth, PtrAuth); // Check hasAddressDiscriminator(). EXPECT_EQ(PtrAuth->hasAddressDiscriminator(), diff --git a/llvm/unittests/Transforms/Utils/ValueMapperTest.cpp b/llvm/unittests/Transforms/Utils/ValueMapperTest.cpp index 86ad41fa7ad50..90b58236060d2 100644 --- a/llvm/unittests/Transforms/Utils/ValueMapperTest.cpp +++ b/llvm/unittests/Transforms/Utils/ValueMapperTest.cpp @@ -450,6 +450,10 @@ TEST(ValueMapperTest, mapValuePtrAuth) { PtrTy, false, GlobalValue::ExternalLinkage, nullptr, "Storage0"); std::unique_ptr Storage1 = std::make_unique( PtrTy, false, GlobalValue::ExternalLinkage, nullptr, "Storage1"); + std::unique_ptr DS0 = std::make_unique( + PtrTy, false, GlobalValue::ExternalLinkage, nullptr, "DS0"); + std::unique_ptr DS1 = std::make_unique( + PtrTy, false, GlobalValue::ExternalLinkage, nullptr, "DS1"); ConstantInt *ConstKey = ConstantInt::get(Int32Ty, 1); ConstantInt *ConstDisc = ConstantInt::get(Int64Ty, 1234); @@ -457,11 +461,12 @@ TEST(ValueMapperTest, mapValuePtrAuth) { ValueToValueMapTy VM; VM[Var0.get()] = Var1.get(); VM[Storage0.get()] = Storage1.get(); + VM[DS0.get()] = DS1.get(); - ConstantPtrAuth *Value = - ConstantPtrAuth::get(Var0.get(), ConstKey, ConstDisc, Storage0.get()); - ConstantPtrAuth *MappedValue = - ConstantPtrAuth::get(Var1.get(), ConstKey, ConstDisc, Storage1.get()); + ConstantPtrAuth *Value = ConstantPtrAuth::get(Var0.get(), ConstKey, ConstDisc, + Storage0.get(), DS0.get()); + ConstantPtrAuth *MappedValue = ConstantPtrAuth::get( + Var1.get(), ConstKey, ConstDisc, Storage1.get(), DS1.get()); EXPECT_EQ(ValueMapper(VM).mapValue(*Value), MappedValue); }