- 
                Notifications
    You must be signed in to change notification settings 
- Fork 15k
          [OpenMP][mlir] Added num_teams, thread_limit translation to LLVM IR
          #68821
        
          New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This patch adds translation to LLVM IR for `num_teams` and `thread_limit` in for `omp.teams` operation.
| @llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm Author: Shraiysh (shraiysh) ChangesThis patch adds translation to LLVM IR for  Full diff: https://github.com/llvm/llvm-project/pull/68821.diff 2 Files Affected: 
 diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 1ec3bb8e7562a9e..ae974c14fac41a6 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -667,11 +667,9 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
   LogicalResult bodyGenStatus = success();
-  if (op.getNumTeamsLower() || op.getNumTeamsUpper() || op.getIfExpr() ||
-      op.getThreadLimit() || !op.getAllocatorsVars().empty() ||
-      op.getReductions()) {
+  if (op.getIfExpr() || !op.getAllocatorsVars().empty() || op.getReductions())
     return op.emitError("unhandled clauses for translation to LLVM IR");
-  }
+
   auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
     LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
         moduleTranslation, allocaIP);
@@ -680,9 +678,21 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
                         moduleTranslation, bodyGenStatus);
   };
 
+  llvm::Value *numTeamsLower = nullptr;
+  if (auto numTeamsLowerVar = op.getNumTeamsLower())
+    numTeamsLower = moduleTranslation.lookupValue(numTeamsLowerVar);
+
+  llvm::Value *numTeamsUpper = nullptr;
+  if (auto numTeamsUpperVar = op.getNumTeamsUpper())
+    numTeamsUpper = moduleTranslation.lookupValue(numTeamsUpperVar);
+
+  llvm::Value *threadLimit = nullptr;
+  if (auto threadLimitVar = op.getThreadLimit())
+    threadLimit = moduleTranslation.lookupValue(threadLimitVar);
+
   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
-  builder.restoreIP(
-      moduleTranslation.getOpenMPBuilder()->createTeams(ompLoc, bodyCB));
+  builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createTeams(
+      ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit));
   return bodyGenStatus;
 }
 
diff --git a/mlir/test/Target/LLVMIR/openmp-teams.mlir b/mlir/test/Target/LLVMIR/openmp-teams.mlir
index 18fc2bb5a3c61b2..87ef90223ed704a 100644
--- a/mlir/test/Target/LLVMIR/openmp-teams.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-teams.mlir
@@ -124,3 +124,114 @@ llvm.func @omp_teams_branching_shared(%condition: i1, %arg0: i32, %arg1: f32, %a
 // CHECK-NEXT: br label
 // CHECK: ret void
 
+// -----
+
+llvm.func @beforeTeams()
+llvm.func @duringTeams()
+llvm.func @afterTeams()
+
+// CHECK-LABEL: @omp_teams_thread_limit
+// CHECK-SAME: (i32 [[THREAD_LIMIT:.+]])
+llvm.func @omp_teams_thread_limit(%threadLimit: i32) {
+    // CHECK-NEXT: call void @beforeTeams()
+    llvm.call @beforeTeams() : () -> ()
+    // CHECK: [[THREAD_NUM:%.+]] = call i32 @__kmpc_global_thread_num
+    // CHECK-NEXT: call void @__kmpc_push_num_teams_51({{.+}}, i32 [[THREAD_NUM]], i32 0, i32 0, i32 [[THREAD_LIMIT]])
+    // CHECK: call void (ptr, i32, ptr, ...) @__kmpc_fork_teams(ptr @1, i32 0, ptr [[OUTLINED_FN:.+]])
+    omp.teams thread_limit(%threadLimit : i32) {
+        llvm.call @duringTeams() : () -> ()
+        omp.terminator
+    }
+    // CHECK: call void @afterTeams
+    llvm.call @afterTeams() : () -> ()
+    // CHECK: ret void
+    llvm.return
+}
+
+// CHECK: define internal void [[OUTLINED_FN]](ptr {{.+}}, ptr {{.+}})
+// CHECK: call void @duringTeams()
+// CHECK: ret void
+
+// -----
+
+llvm.func @beforeTeams()
+llvm.func @duringTeams()
+llvm.func @afterTeams()
+
+// CHECK-LABEL: @omp_teams_num_teams_upper
+// CHECK-SAME: (i32 [[NUM_TEAMS_UPPER:.+]])
+llvm.func @omp_teams_num_teams_upper(%numTeamsUpper: i32) {
+    // CHECK-NEXT: call void @beforeTeams()
+    llvm.call @beforeTeams() : () -> ()
+    // CHECK: [[THREAD_NUM:%.+]] = call i32 @__kmpc_global_thread_num
+    // CHECK-NEXT: call void @__kmpc_push_num_teams_51({{.+}}, i32 [[THREAD_NUM]], i32 [[NUM_TEAMS_UPPER]], i32 [[NUM_TEAMS_UPPER]], i32 0)
+    // CHECK: call void (ptr, i32, ptr, ...) @__kmpc_fork_teams(ptr @1, i32 0, ptr [[OUTLINED_FN:.+]])
+    omp.teams num_teams(to %numTeamsUpper : i32) {
+        llvm.call @duringTeams() : () -> ()
+        omp.terminator
+    }
+    // CHECK: call void @afterTeams
+    llvm.call @afterTeams() : () -> ()
+    // CHECK: ret void
+    llvm.return
+}
+
+// CHECK: define internal void [[OUTLINED_FN]](ptr {{.+}}, ptr {{.+}})
+// CHECK: call void @duringTeams()
+// CHECK: ret void
+
+// -----
+
+llvm.func @beforeTeams()
+llvm.func @duringTeams()
+llvm.func @afterTeams()
+
+// CHECK-LABEL: @omp_teams_num_teams_lower_and_upper
+// CHECK-SAME: (i32 [[NUM_TEAMS_LOWER:.+]], i32 [[NUM_TEAMS_UPPER:.+]])
+llvm.func @omp_teams_num_teams_lower_and_upper(%numTeamsLower: i32, %numTeamsUpper: i32) {
+    // CHECK-NEXT: call void @beforeTeams()
+    llvm.call @beforeTeams() : () -> ()
+    // CHECK: [[THREAD_NUM:%.+]] = call i32 @__kmpc_global_thread_num
+    // CHECK-NEXT: call void @__kmpc_push_num_teams_51({{.+}}, i32 [[THREAD_NUM]], i32 [[NUM_TEAMS_LOWER]], i32 [[NUM_TEAMS_UPPER]], i32 0)
+    // CHECK: call void (ptr, i32, ptr, ...) @__kmpc_fork_teams(ptr @1, i32 0, ptr [[OUTLINED_FN:.+]])
+    omp.teams num_teams(%numTeamsLower : i32 to %numTeamsUpper: i32) {
+        llvm.call @duringTeams() : () -> ()
+        omp.terminator
+    }
+    // CHECK: call void @afterTeams
+    llvm.call @afterTeams() : () -> ()
+    // CHECK: ret void
+    llvm.return
+}
+
+// CHECK: define internal void [[OUTLINED_FN]](ptr {{.+}}, ptr {{.+}})
+// CHECK: call void @duringTeams()
+// CHECK: ret void
+
+// -----
+
+llvm.func @beforeTeams()
+llvm.func @duringTeams()
+llvm.func @afterTeams()
+
+// CHECK-LABEL: @omp_teams_num_teams_and_thread_limit
+// CHECK-SAME: (i32 [[NUM_TEAMS_LOWER:.+]], i32 [[NUM_TEAMS_UPPER:.+]], i32 [[THREAD_LIMIT:.+]])
+llvm.func @omp_teams_num_teams_and_thread_limit(%numTeamsLower: i32, %numTeamsUpper: i32, %threadLimit: i32) {
+    // CHECK-NEXT: call void @beforeTeams()
+    llvm.call @beforeTeams() : () -> ()
+    // CHECK: [[THREAD_NUM:%.+]] = call i32 @__kmpc_global_thread_num
+    // CHECK-NEXT: call void @__kmpc_push_num_teams_51({{.+}}, i32 [[THREAD_NUM]], i32 [[NUM_TEAMS_LOWER]], i32 [[NUM_TEAMS_UPPER]], i32 [[THREAD_LIMIT]])
+    // CHECK: call void (ptr, i32, ptr, ...) @__kmpc_fork_teams(ptr @1, i32 0, ptr [[OUTLINED_FN:.+]])
+    omp.teams num_teams(%numTeamsLower : i32 to %numTeamsUpper: i32) thread_limit(%threadLimit: i32) {
+        llvm.call @duringTeams() : () -> ()
+        omp.terminator
+    }
+    // CHECK: call void @afterTeams
+    llvm.call @afterTeams() : () -> ()
+    // CHECK: ret void
+    llvm.return
+}
+
+// CHECK: define internal void [[OUTLINED_FN]](ptr {{.+}}, ptr {{.+}})
+// CHECK: call void @duringTeams()
+// CHECK: ret void
 | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG.
| llvm::Value *numTeamsLower = nullptr; | ||
| if (auto numTeamsLowerVar = op.getNumTeamsLower()) | ||
| numTeamsLower = moduleTranslation.lookupValue(numTeamsLowerVar); | ||
|  | ||
| llvm::Value *numTeamsUpper = nullptr; | ||
| if (auto numTeamsUpperVar = op.getNumTeamsUpper()) | ||
| numTeamsUpper = moduleTranslation.lookupValue(numTeamsUpperVar); | ||
|  | ||
| llvm::Value *threadLimit = nullptr; | ||
| if (auto threadLimitVar = op.getThreadLimit()) | ||
| threadLimit = moduleTranslation.lookupValue(threadLimitVar); | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: Spell the auto?
This patch adds translation to LLVM IR for
num_teamsandthread_limitin foromp.teamsoperation.