From 9e28dbde93d383e70f0949918ff5516e04a6009b Mon Sep 17 00:00:00 2001 From: Vasileios Porpodas Date: Mon, 22 Jul 2024 15:18:02 -0700 Subject: [PATCH] [SandboxIR] Implement SelectInst This patch implements sandboxir::SelectInst which mirrors llvm::SelectInst. --- llvm/include/llvm/SandboxIR/SandboxIR.h | 58 ++++++++++++++++ .../llvm/SandboxIR/SandboxIRValues.def | 1 + llvm/lib/SandboxIR/SandboxIR.cpp | 63 +++++++++++++++++ llvm/unittests/SandboxIR/SandboxIRTest.cpp | 68 +++++++++++++++++++ 4 files changed, 190 insertions(+) diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h index cd77897ccbb94..0c67206d307ef 100644 --- a/llvm/include/llvm/SandboxIR/SandboxIR.h +++ b/llvm/include/llvm/SandboxIR/SandboxIR.h @@ -75,6 +75,7 @@ class BasicBlock; class Context; class Function; class Instruction; +class SelectInst; class LoadInst; class ReturnInst; class StoreInst; @@ -177,6 +178,7 @@ class Value { friend class Context; // For getting `Val`. friend class User; // For getting `Val`. friend class Use; // For getting `Val`. + friend class SelectInst; // For getting `Val`. friend class LoadInst; // For getting `Val`. friend class StoreInst; // For getting `Val`. friend class ReturnInst; // For getting `Val`. @@ -411,6 +413,8 @@ class Constant : public sandboxir::User { } public: + static Constant *createInt(Type *Ty, uint64_t V, Context &Ctx, + bool IsSigned = false); /// For isa/dyn_cast. static bool classof(const sandboxir::Value *From) { return From->getSubclassID() == ClassID::Constant || @@ -499,6 +503,7 @@ class Instruction : public sandboxir::User { /// A SandboxIR Instruction may map to multiple LLVM IR Instruction. This /// returns its topmost LLVM IR instruction. llvm::Instruction *getTopmostLLVMInstruction() const; + friend class SelectInst; // For getTopmostLLVMInstruction(). friend class LoadInst; // For getTopmostLLVMInstruction(). friend class StoreInst; // For getTopmostLLVMInstruction(). friend class ReturnInst; // For getTopmostLLVMInstruction(). @@ -566,6 +571,52 @@ class Instruction : public sandboxir::User { #endif }; +class SelectInst : public Instruction { + /// Use Context::createSelectInst(). Don't call the + /// constructor directly. + SelectInst(llvm::SelectInst *CI, Context &Ctx) + : Instruction(ClassID::Select, Opcode::Select, CI, Ctx) {} + friend Context; // for SelectInst() + Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final { + return getOperandUseDefault(OpIdx, Verify); + } + SmallVector getLLVMInstrs() const final { + return {cast(Val)}; + } + static Value *createCommon(Value *Cond, Value *True, Value *False, + const Twine &Name, IRBuilder<> &Builder, + Context &Ctx); + +public: + unsigned getUseOperandNo(const Use &Use) const final { + return getUseOperandNoDefault(Use); + } + unsigned getNumOfIRInstrs() const final { return 1u; } + static Value *create(Value *Cond, Value *True, Value *False, + Instruction *InsertBefore, Context &Ctx, + const Twine &Name = ""); + static Value *create(Value *Cond, Value *True, Value *False, + BasicBlock *InsertAtEnd, Context &Ctx, + const Twine &Name = ""); + Value *getCondition() { return getOperand(0); } + Value *getTrueValue() { return getOperand(1); } + Value *getFalseValue() { return getOperand(2); } + + void setCondition(Value *New) { setOperand(0, New); } + void setTrueValue(Value *New) { setOperand(1, New); } + void setFalseValue(Value *New) { setOperand(2, New); } + void swapValues() { cast(Val)->swapValues(); } + /// For isa/dyn_cast. + static bool classof(const Value *From); +#ifndef NDEBUG + void verify() const final { + assert(isa(Val) && "Expected SelectInst!"); + } + void dump(raw_ostream &OS) const override; + LLVM_DUMP_METHOD void dump() const override; +#endif +}; + class LoadInst final : public Instruction { /// Use LoadInst::create() instead of calling the constructor. LoadInst(llvm::LoadInst *LI, Context &Ctx) @@ -803,6 +854,11 @@ class Context { Value *getOrCreateValue(llvm::Value *LLVMV) { return getOrCreateValueInternal(LLVMV, 0); } + /// Get or create a sandboxir::Constant from an existing LLVM IR \p LLVMC. + Constant *getOrCreateConstant(llvm::Constant *LLVMC) { + return cast(getOrCreateValueInternal(LLVMC, 0)); + } + friend class Constant; // For getOrCreateConstant(). /// Create a sandboxir::BasicBlock for an existing LLVM IR \p BB. This will /// also create all contents of the block. BasicBlock *createBasicBlock(llvm::BasicBlock *BB); @@ -812,6 +868,8 @@ class Context { IRBuilder LLVMIRBuilder; auto &getLLVMIRBuilder() { return LLVMIRBuilder; } + SelectInst *createSelectInst(llvm::SelectInst *SI); + friend SelectInst; // For createSelectInst() LoadInst *createLoadInst(llvm::LoadInst *LI); friend LoadInst; // For createLoadInst() StoreInst *createStoreInst(llvm::StoreInst *SI); diff --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def index b2f88741af8d9..efa9155755587 100644 --- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def +++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def @@ -25,6 +25,7 @@ DEF_USER(Constant, Constant) #endif // ClassID, Opcode(s), Class DEF_INSTR(Opaque, OP(Opaque), OpaqueInst) +DEF_INSTR(Select, OP(Select), SelectInst) DEF_INSTR(Load, OP(Load), LoadInst) DEF_INSTR(Store, OP(Store), StoreInst) DEF_INSTR(Ret, OP(Ret), ReturnInst) diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp index 4cf45fa87693a..51c9af8a6e1fe 100644 --- a/llvm/lib/SandboxIR/SandboxIR.cpp +++ b/llvm/lib/SandboxIR/SandboxIR.cpp @@ -455,6 +455,51 @@ void Instruction::dump() const { } #endif // NDEBUG +Value *SelectInst::createCommon(Value *Cond, Value *True, Value *False, + const Twine &Name, IRBuilder<> &Builder, + Context &Ctx) { + llvm::Value *NewV = + Builder.CreateSelect(Cond->Val, True->Val, False->Val, Name); + if (auto *NewSI = dyn_cast(NewV)) + return Ctx.createSelectInst(NewSI); + assert(isa(NewV) && "Expected constant"); + return Ctx.getOrCreateConstant(cast(NewV)); +} + +Value *SelectInst::create(Value *Cond, Value *True, Value *False, + Instruction *InsertBefore, Context &Ctx, + const Twine &Name) { + llvm::Instruction *BeforeIR = InsertBefore->getTopmostLLVMInstruction(); + auto &Builder = Ctx.getLLVMIRBuilder(); + Builder.SetInsertPoint(BeforeIR); + return createCommon(Cond, True, False, Name, Builder, Ctx); +} + +Value *SelectInst::create(Value *Cond, Value *True, Value *False, + BasicBlock *InsertAtEnd, Context &Ctx, + const Twine &Name) { + auto *IRInsertAtEnd = cast(InsertAtEnd->Val); + auto &Builder = Ctx.getLLVMIRBuilder(); + Builder.SetInsertPoint(IRInsertAtEnd); + return createCommon(Cond, True, False, Name, Builder, Ctx); +} + +bool SelectInst::classof(const Value *From) { + return From->getSubclassID() == ClassID::Select; +} + +#ifndef NDEBUG +void SelectInst::dump(raw_ostream &OS) const { + dumpCommonPrefix(OS); + dumpCommonSuffix(OS); +} + +void SelectInst::dump() const { + dump(dbgs()); + dbgs() << "\n"; +} +#endif // NDEBUG + LoadInst *LoadInst::create(Type *Ty, Value *Ptr, MaybeAlign Align, Instruction *InsertBefore, Context &Ctx, const Twine &Name) { @@ -592,7 +637,15 @@ void OpaqueInst::dump() const { dump(dbgs()); dbgs() << "\n"; } +#endif // NDEBUG + +Constant *Constant::createInt(Type *Ty, uint64_t V, Context &Ctx, + bool IsSigned) { + llvm::Constant *LLVMC = llvm::ConstantInt::get(Ty, V, IsSigned); + return Ctx.getOrCreateConstant(LLVMC); +} +#ifndef NDEBUG void Constant::dump(raw_ostream &OS) const { dumpCommonPrefix(OS); dumpCommonSuffix(OS); @@ -700,6 +753,11 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) { assert(isa(LLVMV) && "Expected Instruction"); switch (cast(LLVMV)->getOpcode()) { + case llvm::Instruction::Select: { + auto *LLVMSel = cast(LLVMV); + It->second = std::unique_ptr(new SelectInst(LLVMSel, *this)); + return It->second.get(); + } case llvm::Instruction::Load: { auto *LLVMLd = cast(LLVMV); It->second = std::unique_ptr(new LoadInst(LLVMLd, *this)); @@ -733,6 +791,11 @@ BasicBlock *Context::createBasicBlock(llvm::BasicBlock *LLVMBB) { return BB; } +SelectInst *Context::createSelectInst(llvm::SelectInst *SI) { + auto NewPtr = std::unique_ptr(new SelectInst(SI, *this)); + return cast(registerValue(std::move(NewPtr))); +} + LoadInst *Context::createLoadInst(llvm::LoadInst *LI) { auto NewPtr = std::unique_ptr(new LoadInst(LI, *this)); return cast(registerValue(std::move(NewPtr))); diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp index b0d6ae85950d7..ba90b4f811f8e 100644 --- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp +++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp @@ -561,6 +561,74 @@ define void @foo(i8 %v1) { EXPECT_EQ(I0->getNextNode(), Ret); } +TEST_F(SandboxIRTest, SelectInst) { + parseIR(C, R"IR( +define void @foo(i1 %c0, i8 %v0, i8 %v1, i1 %c1) { + %sel = select i1 %c0, i8 %v0, i8 %v1 + ret void +} +)IR"); + llvm::Function *LLVMF = &*M->getFunction("foo"); + sandboxir::Context Ctx(C); + sandboxir::Function *F = Ctx.createFunction(LLVMF); + auto *Cond0 = F->getArg(0); + auto *V0 = F->getArg(1); + auto *V1 = F->getArg(2); + auto *Cond1 = F->getArg(3); + auto *BB = &*F->begin(); + auto It = BB->begin(); + auto *Select = cast(&*It++); + auto *Ret = &*It++; + + // Check getCondition(). + EXPECT_EQ(Select->getCondition(), Cond0); + // Check getTrueValue(). + EXPECT_EQ(Select->getTrueValue(), V0); + // Check getFalseValue(). + EXPECT_EQ(Select->getFalseValue(), V1); + // Check setCondition(). + Select->setCondition(Cond1); + EXPECT_EQ(Select->getCondition(), Cond1); + // Check setTrueValue(). + Select->setTrueValue(V1); + EXPECT_EQ(Select->getTrueValue(), V1); + // Check setFalseValue(). + Select->setFalseValue(V0); + EXPECT_EQ(Select->getFalseValue(), V0); + + { + // Check SelectInst::create() InsertBefore. + auto *NewSel = cast(sandboxir::SelectInst::create( + Cond0, V0, V1, /*InsertBefore=*/Ret, Ctx)); + EXPECT_EQ(NewSel->getCondition(), Cond0); + EXPECT_EQ(NewSel->getTrueValue(), V0); + EXPECT_EQ(NewSel->getFalseValue(), V1); + EXPECT_EQ(NewSel->getNextNode(), Ret); + } + { + // Check SelectInst::create() InsertAtEnd. + auto *NewSel = cast( + sandboxir::SelectInst::create(Cond0, V0, V1, /*InsertAtEnd=*/BB, Ctx)); + EXPECT_EQ(NewSel->getCondition(), Cond0); + EXPECT_EQ(NewSel->getTrueValue(), V0); + EXPECT_EQ(NewSel->getFalseValue(), V1); + EXPECT_EQ(NewSel->getPrevNode(), Ret); + } + { + // Check SelectInst::create() Folded. + auto *False = + sandboxir::Constant::createInt(llvm::Type::getInt1Ty(C), 0, Ctx, + /*IsSigned=*/false); + auto *FortyTwo = + sandboxir::Constant::createInt(llvm::Type::getInt1Ty(C), 42, Ctx, + /*IsSigned=*/false); + auto *NewSel = + sandboxir::SelectInst::create(False, FortyTwo, FortyTwo, Ret, Ctx); + EXPECT_TRUE(isa(NewSel)); + EXPECT_EQ(NewSel, FortyTwo); + } +} + TEST_F(SandboxIRTest, LoadInst) { parseIR(C, R"IR( define void @foo(ptr %arg0, ptr %arg1) {