diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index cb8a6e08886df..7928bb940ad8b 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -521,7 +521,8 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [ Variadic:$callee_operands, Variadic:$normalDestOperands, Variadic:$unwindDestOperands, - OptionalAttr:$branch_weights); + OptionalAttr:$branch_weights, + DefaultValuedAttr:$CConv); let results = (outs Variadic); let successors = (successor AnySuccessor:$normalDest, AnySuccessor:$unwindDest); @@ -602,7 +603,8 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call", Variadic:$callee_operands, DefaultValuedAttr:$fastmathFlags, - OptionalAttr:$branch_weights); + OptionalAttr:$branch_weights, + DefaultValuedAttr:$CConv); // Append the aliasing related attributes defined in LLVM_MemAccessOpBase. let arguments = !con(args, aliasAttrs); let results = (outs Optional:$result); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index c22cff4c1328a..6c8ea382b1dd8 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -97,6 +97,52 @@ static Type getI1SameShape(Type type) { return i1Type; } +// Parses one of the keywords provided in the list `keywords` and returns the +// position of the parsed keyword in the list. If none of the keywords from the +// list is parsed, returns -1. +static int parseOptionalKeywordAlternative(OpAsmParser &parser, + ArrayRef keywords) { + for (const auto &en : llvm::enumerate(keywords)) { + if (succeeded(parser.parseOptionalKeyword(en.value()))) + return en.index(); + } + return -1; +} + +namespace { +template +struct EnumTraits {}; + +#define REGISTER_ENUM_TYPE(Ty) \ + template <> \ + struct EnumTraits { \ + static StringRef stringify(Ty value) { return stringify##Ty(value); } \ + static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); } \ + } + +REGISTER_ENUM_TYPE(Linkage); +REGISTER_ENUM_TYPE(UnnamedAddr); +REGISTER_ENUM_TYPE(CConv); +REGISTER_ENUM_TYPE(Visibility); +} // namespace + +/// Parse an enum from the keyword, or default to the provided default value. +/// The return type is the enum type by default, unless overridden with the +/// second template argument. +template +static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser, + OperationState &result, + EnumTy defaultValue) { + SmallVector names; + for (unsigned i = 0, e = EnumTraits::getMaxEnumVal(); i <= e; ++i) + names.push_back(EnumTraits::stringify(static_cast(i))); + + int index = parseOptionalKeywordAlternative(parser, names); + if (index == -1) + return static_cast(defaultValue); + return static_cast(index); +} + //===----------------------------------------------------------------------===// // Printing, parsing, folding and builder for LLVM::CmpOp. //===----------------------------------------------------------------------===// @@ -859,6 +905,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results, build(builder, state, results, TypeAttr::get(getLLVMFuncType(builder.getContext(), results, args)), callee, args, /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr, + /*CConv=*/nullptr, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); } @@ -880,7 +927,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state, ValueRange args) { build(builder, state, getCallOpResultTypes(calleeType), TypeAttr::get(calleeType), callee, args, /*fastmathFlags=*/nullptr, - /*branch_weights=*/nullptr, /*access_groups=*/nullptr, + /*branch_weights=*/nullptr, /*CConv=*/nullptr, + /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); } @@ -889,6 +937,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, build(builder, state, getCallOpResultTypes(calleeType), TypeAttr::get(calleeType), /*callee=*/nullptr, args, /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr, + /*CConv=*/nullptr, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); } @@ -899,9 +948,11 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func, build(builder, state, getCallOpResultTypes(calleeType), TypeAttr::get(calleeType), SymbolRefAttr::get(func), args, /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr, + /*CConv=*/nullptr, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); } + CallInterfaceCallable CallOp::getCallableForCallee() { // Direct call. if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) @@ -1054,9 +1105,14 @@ void CallOp::print(OpAsmPrinter &p) { isVarArg = calleeType.isVarArg(); } + p << ' '; + + // Print calling convention. + if (getCConv() != LLVM::CConv::C) + p << stringifyCConv(getCConv()) << ' '; + // Print the direct callee if present as a function attribute, or an indirect // callee (first operand) otherwise. - p << ' '; if (isDirect) p.printSymbolName(callee.value()); else @@ -1069,7 +1125,7 @@ void CallOp::print(OpAsmPrinter &p) { p << " vararg(" << calleeType << ")"; p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), - {"callee", "callee_type"}); + {getCConvAttrName(), "callee", "callee_type"}); p << " : "; if (!isDirect) @@ -1137,7 +1193,7 @@ static ParseResult parseOptionalCallFuncPtr( return success(); } -// ::= `llvm.call` (function-id | ssa-use) +// ::= `llvm.call` (cconv)? (function-id | ssa-use) // `(` ssa-use-list `)` // ( `vararg(` var-arg-func-type `)` )? // attribute-dict? `:` (type `,`)? function-type @@ -1146,6 +1202,12 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) { TypeAttr calleeType; SmallVector operands; + // Default to C Calling Convention if no keyword is provided. + result.addAttribute( + getCConvAttrName(result.name), + CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword( + parser, result, LLVM::CConv::C))); + // Parse a function pointer for indirect calls. if (parseOptionalCallFuncPtr(parser, operands)) return failure(); @@ -1191,7 +1253,7 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func, auto calleeType = func.getFunctionType(); build(builder, state, getCallOpResultTypes(calleeType), TypeAttr::get(calleeType), SymbolRefAttr::get(func), ops, normalOps, - unwindOps, nullptr, normal, unwind); + unwindOps, nullptr, nullptr, normal, unwind); } void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys, @@ -1200,7 +1262,7 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys, ValueRange unwindOps) { build(builder, state, tys, TypeAttr::get(getLLVMFuncType(builder.getContext(), tys, ops)), callee, - ops, normalOps, unwindOps, nullptr, normal, unwind); + ops, normalOps, unwindOps, nullptr, nullptr, normal, unwind); } void InvokeOp::build(OpBuilder &builder, OperationState &state, @@ -1209,7 +1271,7 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, Block *unwind, ValueRange unwindOps) { build(builder, state, getCallOpResultTypes(calleeType), TypeAttr::get(calleeType), callee, ops, normalOps, unwindOps, nullptr, - normal, unwind); + nullptr, normal, unwind); } SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) { @@ -1275,6 +1337,10 @@ void InvokeOp::print(OpAsmPrinter &p) { p << ' '; + // Print calling convention. + if (getCConv() != LLVM::CConv::C) + p << stringifyCConv(getCConv()) << ' '; + // Either function name or pointer if (isDirect) p.printSymbolName(callee.value()); @@ -1290,9 +1356,9 @@ void InvokeOp::print(OpAsmPrinter &p) { if (isVarArg) p << " vararg(" << calleeType << ")"; - p.printOptionalAttrDict( - (*this)->getAttrs(), - {InvokeOp::getOperandSegmentSizeAttr(), "callee", "callee_type"}); + p.printOptionalAttrDict((*this)->getAttrs(), + {InvokeOp::getOperandSegmentSizeAttr(), "callee", + "callee_type", InvokeOp::getCConvAttrName()}); p << " : "; if (!isDirect) @@ -1301,7 +1367,7 @@ void InvokeOp::print(OpAsmPrinter &p) { getResultTypes()); } -// ::= `llvm.invoke` (function-id | ssa-use) +// ::= `llvm.invoke` (cconv)? (function-id | ssa-use) // `(` ssa-use-list `)` // `to` bb-id (`[` ssa-use-and-type-list `]`)? // `unwind` bb-id (`[` ssa-use-and-type-list `]`)? @@ -1315,6 +1381,12 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) { SmallVector normalOperands, unwindOperands; Builder &builder = parser.getBuilder(); + // Default to C Calling Convention if no keyword is provided. + result.addAttribute( + getCConvAttrName(result.name), + CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword( + parser, result, LLVM::CConv::C))); + // Parse a function pointer for indirect calls. if (parseOptionalCallFuncPtr(parser, operands)) return failure(); @@ -1788,52 +1860,6 @@ void GlobalOp::print(OpAsmPrinter &p) { } } -// Parses one of the keywords provided in the list `keywords` and returns the -// position of the parsed keyword in the list. If none of the keywords from the -// list is parsed, returns -1. -static int parseOptionalKeywordAlternative(OpAsmParser &parser, - ArrayRef keywords) { - for (const auto &en : llvm::enumerate(keywords)) { - if (succeeded(parser.parseOptionalKeyword(en.value()))) - return en.index(); - } - return -1; -} - -namespace { -template -struct EnumTraits {}; - -#define REGISTER_ENUM_TYPE(Ty) \ - template <> \ - struct EnumTraits { \ - static StringRef stringify(Ty value) { return stringify##Ty(value); } \ - static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); } \ - } - -REGISTER_ENUM_TYPE(Linkage); -REGISTER_ENUM_TYPE(UnnamedAddr); -REGISTER_ENUM_TYPE(CConv); -REGISTER_ENUM_TYPE(Visibility); -} // namespace - -/// Parse an enum from the keyword, or default to the provided default value. -/// The return type is the enum type by default, unless overriden with the -/// second template argument. -template -static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser, - OperationState &result, - EnumTy defaultValue) { - SmallVector names; - for (unsigned i = 0, e = EnumTraits::getMaxEnumVal(); i <= e; ++i) - names.push_back(EnumTraits::stringify(static_cast(i))); - - int index = parseOptionalKeywordAlternative(parser, names); - if (index == -1) - return static_cast(defaultValue); - return static_cast(index); -} - static LogicalResult verifyComdat(Operation *op, std::optional attr) { if (!attr) diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp index 1c0f51a66bf5e..5494a13acb6e1 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -200,6 +200,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, call = builder.CreateCall(calleeType, operandsRef.front(), operandsRef.drop_front()); } + call->setCallingConv(convertCConvToLLVM(callOp.getCConv())); moduleTranslation.setAccessGroupsMetadata(callOp, call); moduleTranslation.setAliasScopeMetadata(callOp, call); moduleTranslation.setTBAAMetadata(callOp, call); @@ -275,7 +276,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, if (auto invOp = dyn_cast(opInst)) { auto operands = moduleTranslation.lookupValues(invOp.getCalleeOperands()); ArrayRef operandsRef(operands); - llvm::Instruction *result; + llvm::InvokeInst *result; if (auto attr = opInst.getAttrOfType("callee")) { result = builder.CreateInvoke( moduleTranslation.lookupFunction(attr.getValue()), @@ -290,6 +291,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, moduleTranslation.lookupBlock(invOp.getSuccessor(1)), operandsRef.drop_front()); } + result->setCallingConv(convertCConvToLLVM(invOp.getCConv())); moduleTranslation.mapBranch(invOp, result); // InvokeOp can only have 0 or 1 result if (invOp->getNumResults() != 0) { diff --git a/mlir/test/Dialect/LLVMIR/calling-conventions.mlir b/mlir/test/Dialect/LLVMIR/calling-conventions.mlir new file mode 100644 index 0000000000000..153dfa18b4419 --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/calling-conventions.mlir @@ -0,0 +1,72 @@ +// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s + +llvm.func @__gxx_personality_v0(...) -> i32 + +// CHECK: declare fastcc void @cconv_fastcc() +// CHECK: declare void @cconv_ccc() +// CHECK: declare tailcc void @cconv_tailcc() +// CHECK: declare ghccc void @cconv_ghccc() +llvm.func fastcc @cconv_fastcc() +llvm.func ccc @cconv_ccc() +llvm.func tailcc @cconv_tailcc() +llvm.func cc_10 @cconv_ghccc() + +// CHECK-LABEL: @test_ccs +llvm.func @test_ccs() { + // CHECK-NEXT: call fastcc void @cconv_fastcc() + // CHECK-NEXT: call void @cconv_ccc() + // CHECK-NEXT: call void @cconv_ccc() + // CHECK-NEXT: call tailcc void @cconv_tailcc() + // CHECK-NEXT: call ghccc void @cconv_ghccc() + // CHECK-NEXT: ret void + llvm.call fastcc @cconv_fastcc() : () -> () + llvm.call ccc @cconv_ccc() : () -> () + llvm.call @cconv_ccc() : () -> () + llvm.call tailcc @cconv_tailcc() : () -> () + llvm.call cc_10 @cconv_ghccc() : () -> () + llvm.return +} + +// CHECK-LABEL: @test_ccs_invoke +llvm.func @test_ccs_invoke() attributes { personality = @__gxx_personality_v0 } { + // CHECK-NEXT: invoke fastcc void @cconv_fastcc() + // CHECK-NEXT: to label %[[normal1:[0-9]+]] unwind label %[[unwind:[0-9]+]] + llvm.invoke fastcc @cconv_fastcc() to ^bb1 unwind ^bb6 : () -> () + +^bb1: + // CHECK: [[normal1]]: + // CHECK-NEXT: invoke void @cconv_ccc() + // CHECK-NEXT: to label %[[normal2:[0-9]+]] unwind label %[[unwind:[0-9]+]] + llvm.invoke ccc @cconv_ccc() to ^bb2 unwind ^bb6 : () -> () + +^bb2: + // CHECK: [[normal2]]: + // CHECK-NEXT: invoke void @cconv_ccc() + // CHECK-NEXT: to label %[[normal3:[0-9]+]] unwind label %[[unwind:[0-9]+]] + llvm.invoke @cconv_ccc() to ^bb3 unwind ^bb6 : () -> () + +^bb3: + // CHECK: [[normal3]]: + // CHECK-NEXT: invoke tailcc void @cconv_tailcc() + // CHECK-NEXT: to label %[[normal4:[0-9]+]] unwind label %[[unwind:[0-9]+]] + llvm.invoke tailcc @cconv_tailcc() to ^bb4 unwind ^bb6 : () -> () + +^bb4: + // CHECK: [[normal4]]: + // CHECK-NEXT: invoke ghccc void @cconv_ghccc() + // CHECK-NEXT: to label %[[normal5:[0-9]+]] unwind label %[[unwind:[0-9]+]] + llvm.invoke cc_10 @cconv_ghccc() to ^bb5 unwind ^bb6 : () -> () + +^bb5: + // CHECK: [[normal5]]: + // CHECK-NEXT: ret void + llvm.return + + // CHECK: [[unwind]]: + // CHECK-NEXT: landingpad { ptr, i32 } + // CHECK-NEXT: cleanup + // CHECK-NEXT: ret void +^bb6: + %0 = llvm.landingpad cleanup : !llvm.struct<(ptr, i32)> + llvm.return +} diff --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir index b45e6c4ef897b..63e20b1d8fc31 100644 --- a/mlir/test/Dialect/LLVMIR/func.mlir +++ b/mlir/test/Dialect/LLVMIR/func.mlir @@ -184,6 +184,24 @@ module { llvm.return } + // CHECK: llvm.func cc_10 @cconv4 + llvm.func cc_10 @cconv4() { + llvm.return + } + + // CHECK: llvm.func @test_ccs + llvm.func @test_ccs() { + // CHECK-NEXT: llvm.call @cconv1() : () -> () + // CHECK-NEXT: llvm.call @cconv2() : () -> () + // CHECK-NEXT: llvm.call fastcc @cconv3() : () -> () + // CHECK-NEXT: llvm.call cc_10 @cconv4() : () -> () + llvm.call @cconv1() : () -> () + llvm.call ccc @cconv2() : () -> () + llvm.call fastcc @cconv3() : () -> () + llvm.call cc_10 @cconv4() : () -> () + llvm.return + } + // CHECK-LABEL: llvm.func @variadic_def llvm.func @variadic_def(...) { llvm.return diff --git a/mlir/test/Dialect/LLVMIR/inlining.mlir b/mlir/test/Dialect/LLVMIR/inlining.mlir index 1296b8e031c13..b684be1f9626b 100644 --- a/mlir/test/Dialect/LLVMIR/inlining.mlir +++ b/mlir/test/Dialect/LLVMIR/inlining.mlir @@ -84,7 +84,7 @@ llvm.func internal fastcc @callee() -> (i32) attributes { function_entry_count = // CHECK-NEXT: llvm.return %[[CST]] llvm.func @caller() -> (i32) { // Include all call attributes that don't prevent inlining. - %0 = llvm.call @callee() { fastmathFlags = #llvm.fastmath, branch_weights = dense<42> : vector<1xi32> } : () -> (i32) + %0 = llvm.call fastcc @callee() { fastmathFlags = #llvm.fastmath, branch_weights = dense<42> : vector<1xi32> } : () -> (i32) llvm.return %0 : i32 }