@@ -74,6 +74,12 @@ class NVVM_Op<string mnemonic, list<Trait> traits = []> :
7474  LLVM_OpBase<NVVM_Dialect, mnemonic, traits> {
7575}
7676
77+ /// Base class for ops that declare the BasicPtxBuilderOpInterface methods; the
78+ /// interface is appended to `traits`, so extra traits can be passed without losing it.
79+ class NVVM_PTXBuilder_Op<string mnemonic, list<Trait> traits = []> :
80+   LLVM_OpBase<NVVM_Dialect, mnemonic, !listconcat(traits, [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>])> {
81+ }
82+ 
7783//===----------------------------------------------------------------------===//
7884// NVVM attribute definitions
7985//===----------------------------------------------------------------------===//
@@ -206,21 +212,31 @@ def NVVM_ReduxOp :
206212//===----------------------------------------------------------------------===//
207213
208214/// mbarrier.init instruction with generic pointer type
209- def NVVM_MBarrierInitOp : NVVM_Op <"mbarrier.init">,
210-   Arguments<(ins LLVM_i64ptr_any:$addr, I32:$count)> {
215+ def NVVM_MBarrierInitOp : NVVM_PTXBuilder_Op <"mbarrier.init">,
216+   Arguments<(ins LLVM_i64ptr_any:$addr, I32:$count, PtxPredicate:$predicate )> {
211217  string llvmBuilder = [{
212218      createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_init, {$addr, $count});
213219  }];
214-   let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands)";
220+   let assemblyFormat = "$addr `,` $count (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)"; // the `, predicate = ...` clause is optional
221+   let extraClassDeclaration = [{
222+     bool hasIntrinsic() { return !getPredicate(); }
223+   }]; // hasIntrinsic() is true only when no predicate operand is present
224+   let extraClassDefinition = [{
225+     std::string $cppClass::getPtx() { return std::string("mbarrier.init.b64 [%0], %1;"); }
226+   }]; // getPtx() returns the PTX template string for this op
215227}
216228
217229/// mbarrier.init instruction with shared pointer type
218- def NVVM_MBarrierInitSharedOp : NVVM_Op <"mbarrier.init.shared">,
219-   Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$count)> {
230+ def NVVM_MBarrierInitSharedOp : NVVM_PTXBuilder_Op <"mbarrier.init.shared">,
231+   Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$count, PtxPredicate:$predicate )> {
220232  string llvmBuilder = [{
221233      createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_init_shared, {$addr, $count});
222234  }];
223-   let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands)";
235+   let assemblyFormat = "$addr `,` $count (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)"; // the `, predicate = ...` clause is optional
236+   let extraClassDeclaration = "bool hasIntrinsic() { return !getPredicate(); }"; // true only when no predicate operand is present
237+   let extraClassDefinition = [{
238+     std::string $cppClass::getPtx() { return std::string("mbarrier.init.shared.b64 [%0], %1;"); }
239+   }]; // getPtx() returns the PTX template string for this op
224240}
225241
226242def NVVM_MBarrierInvalOp : NVVM_Op<"mbarrier.inval">,
@@ -275,26 +291,23 @@ def NVVM_MBarrierArriveNocompleteSharedOp : NVVM_Op<"mbarrier.arrive.nocomplete.
275291  let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands) `->` type($res)";
276292}
277293
278- def NVVM_MBarrierArriveExpectTxOp : NVVM_Op<"mbarrier.arrive.expect_tx",
279-                     [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,  
280-   Arguments<(ins LLVM_i64ptr_any:$addr, I32:$txcount)> {
281-   let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands)";
294+ def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx">,  
295+   Arguments<(ins LLVM_i64ptr_any:$addr, I32:$txcount, PtxPredicate:$predicate)> { // $predicate is optional in the assembly syntax below
296+   let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)"; // the `, predicate = ...` clause may be omitted
282297  let extraClassDefinition = [{
283298    std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;"); }
284299  }];
285300}
286301
287- def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_Op<"mbarrier.arrive.expect_tx.shared", 
288-                     [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,  
289-   Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$txcount)> {    
290-   let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands)";
302+ def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx.shared">,  
303+   Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$txcount, PtxPredicate:$predicate)> {    // $predicate is optional in the assembly syntax below
304+   let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)"; // the `, predicate = ...` clause may be omitted
291305  let extraClassDefinition = [{
292306    std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"); }
293307  }];
294308}
295309
296- def NVVM_MBarrierTryWaitParityOp : NVVM_Op<"mbarrier.try_wait.parity", 
297-                     [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,  
310+ def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity">,  
298311  Arguments<(ins LLVM_i64ptr_any:$addr, I32:$phase, I32:$ticks)> {  
299312  let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
300313  let extraClassDefinition = [{
@@ -313,8 +326,7 @@ def NVVM_MBarrierTryWaitParityOp : NVVM_Op<"mbarrier.try_wait.parity",
313326  }];
314327}
315328
316- def NVVM_MBarrierTryWaitParitySharedOp : NVVM_Op<"mbarrier.try_wait.parity.shared", 
317-                     [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,  
329+ def NVVM_MBarrierTryWaitParitySharedOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity.shared">,  
318330  Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$phase, I32:$ticks)> {  
319331  let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
320332  let extraClassDefinition = [{
@@ -488,7 +500,7 @@ def LoadCacheModifierKind : I32EnumAttr<"LoadCacheModifierKind",
488500
489501def LoadCacheModifierAttr : EnumAttr<NVVM_Dialect, LoadCacheModifierKind, "load_cache_modifier">;
490502
491- def NVVM_CpAsyncOp : NVVM_Op <"cp.async.shared.global", [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>] >,
503+ def NVVM_CpAsyncOp : NVVM_PTXBuilder_Op <"cp.async.shared.global">,
492504  Arguments<(ins LLVM_i8Ptr_shared:$dst,
493505                 LLVM_i8Ptr_global:$src,
494506                 I32Attr:$size,
@@ -1359,12 +1371,24 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
13591371// NVVM TMA Ops
13601372//===----------------------------------------------------------------------===//
13611373
1362- def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global", [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
1374+ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : 
1375+   NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global", 
1376+   [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>, 
1377+   AttrSizedOperandSegments]>,
13631378  Arguments<(ins  LLVM_i64ptr_shared:$dstMem,
13641379                  LLVM_i64ptr_any:$tmaDescriptor,
13651380                  LLVM_i64ptr_shared:$mbar,
1366-                   Variadic<I32>:$coordinates)> {
1367-   let assemblyFormat = "$dstMem `,` $tmaDescriptor `,` $mbar `,` `box` `[`$coordinates `]` attr-dict  `:` type(operands)";
1381+                   Variadic<I32>:$coordinates,
1382+                   PtxPredicate:$predicate)> {
1383+   let assemblyFormat = [{ 
1384+     $dstMem `,` 
1385+     $tmaDescriptor `,` 
1386+     $mbar `,` 
1387+     `box` `[`$coordinates `]` 
1388+     (`,` `predicate` `=` $predicate^)? 
1389+     attr-dict  `:` type(operands)
1390+   }];
1391+ 
13681392  let extraClassDefinition = [{
13691393    std::string $cppClass::getPtx() {
13701394      int dim = getCoordinates().size();
@@ -1382,11 +1406,21 @@ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : NVVM_Op<"cp.async.bulk.tenso
13821406  let hasVerifier = 1;
13831407}
13841408
1385- def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp : NVVM_Op<"cp.async.bulk.tensor.global.shared.cta", [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
1409+ def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp : 
1410+   NVVM_Op<"cp.async.bulk.tensor.global.shared.cta", 
1411+   [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>, 
1412+   AttrSizedOperandSegments]>,
13861413  Arguments<(ins  LLVM_i64ptr_any:$tmaDescriptor,
13871414                  LLVM_i64ptr_shared:$srcMem,
1388-                   Variadic<I32>:$coordinates)> {
1389-   let assemblyFormat = "$tmaDescriptor `,` $srcMem `,` `box` `[`$coordinates `]` attr-dict  `:` type(operands)";
1415+                   Variadic<I32>:$coordinates,
1416+                   PtxPredicate:$predicate)> {
1417+   let assemblyFormat = [{ 
1418+     $tmaDescriptor `,` 
1419+     $srcMem `,` 
1420+     `box` `[`$coordinates `]` 
1421+     (`,` `predicate` `=` $predicate^)?  
1422+     attr-dict  `:` type(operands)
1423+   }];
13901424  let extraClassDefinition = [{
13911425    std::string $cppClass::getPtx() {
13921426      int dim = getCoordinates().size();
@@ -1408,8 +1442,7 @@ def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp : NVVM_Op<"cp.async.bulk.tensor.gl
14081442// NVVM Wgmma Ops
14091443//===----------------------------------------------------------------------===//
14101444
1411- def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned", 
1412-                     [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]> {
1445+ def NVVM_WgmmaFenceAlignedOp : NVVM_PTXBuilder_Op<"wgmma.fence.aligned"> {
14131446  let arguments = (ins);
14141447  let description = [{
14151448    Enforce an ordering of register accesses between warpgroup level matrix 
@@ -1423,8 +1456,7 @@ def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned",
14231456  }];
14241457}
14251458
1426- def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned", 
1427-                     [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
1459+ def NVVM_WgmmaGroupSyncAlignedOp : NVVM_PTXBuilder_Op<"wgmma.commit.group.sync.aligned">,
14281460  Arguments<(ins )> {
14291461  let assemblyFormat = "attr-dict";
14301462  let description = [{
@@ -1437,8 +1469,7 @@ def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned",
14371469  }];
14381470}
14391471
1440- def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned", 
1441-                     [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>{
1472+ def NVVM_WgmmaWaitGroupSyncOp : NVVM_PTXBuilder_Op<"wgmma.wait.group.sync.aligned">{
14421473  let arguments = (ins I32Attr:$group);
14431474  let assemblyFormat = "attr-dict $group";
14441475  let description = [{
0 commit comments