@@ -456,6 +456,33 @@ markDeclareTarget(mlir::Operation *op, lower::AbstractConverter &converter,
456456 declareTargetOp.setDeclareTarget (deviceType, captureClause);
457457}
458458
459+ // / For an operation that takes `omp.private` values as region args, this util
460+ // / merges the private vars info into the region arguments list.
461+ // /
462+ // / \tparam OMPOP - the OpenMP op that takes `omp.private` inputs.
463+ // / \tparam InfoTy - the type of private info we want to merge; e.g. mlir::Type
464+ // / or mlir::Location fields of the private var list.
465+ // /
466+ // / \param [in] op - the op accepting `omp.private` inputs.
467+ // / \param [in] currentList - the current list of region info that we
468+ // / want to merge private info with. For example this could be the list of types
469+ // / or locations of previous arguments to \op's region.
470+ // / \param [in] infoAccessor - for a private variable, this returns the
471+ // / data we want to merge: type or location.
472+ // / \param [out] allRegionArgsInfo - the merged list of region info.
473+ template <typename OMPOp, typename InfoTy>
474+ static void
475+ mergePrivateVarsInfo (OMPOp op, llvm::ArrayRef<InfoTy> currentList,
476+ llvm::function_ref<InfoTy(mlir::Value)> infoAccessor,
477+ llvm::SmallVectorImpl<InfoTy> &allRegionArgsInfo) {
478+ mlir::OperandRange privateVars = op.getPrivateVars ();
479+
480+ llvm::transform (currentList, std::back_inserter (allRegionArgsInfo),
481+ [](InfoTy i) { return i; });
482+ llvm::transform (privateVars, std::back_inserter (allRegionArgsInfo),
483+ infoAccessor);
484+ }
485+
459486// ===----------------------------------------------------------------------===//
460487// Op body generation helper structures and functions
461488// ===----------------------------------------------------------------------===//
@@ -758,15 +785,28 @@ genBodyOfTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
758785 llvm::ArrayRef<const semantics::Symbol *> mapSyms,
759786 llvm::ArrayRef<mlir::Location> mapSymLocs,
760787 llvm::ArrayRef<mlir::Type> mapSymTypes,
788+ DataSharingProcessor &dsp,
761789 const mlir::Location ¤tLocation,
762790 const ConstructQueue &queue, ConstructQueue::iterator item) {
763791 assert (mapSymTypes.size () == mapSymLocs.size ());
764792
765793 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
766794 mlir::Region ®ion = targetOp.getRegion ();
767795
768- auto *regionBlock =
769- firOpBuilder.createBlock (®ion, {}, mapSymTypes, mapSymLocs);
796+ llvm::SmallVector<mlir::Type> allRegionArgTypes;
797+ mergePrivateVarsInfo (targetOp, mapSymTypes,
798+ llvm::function_ref<mlir::Type (mlir::Value)>{
799+ [](mlir::Value v) { return v.getType (); }},
800+ allRegionArgTypes);
801+
802+ llvm::SmallVector<mlir::Location> allRegionArgLocs;
803+ mergePrivateVarsInfo (targetOp, mapSymLocs,
804+ llvm::function_ref<mlir::Location (mlir::Value)>{
805+ [](mlir::Value v) { return v.getLoc (); }},
806+ allRegionArgLocs);
807+
808+ auto *regionBlock = firOpBuilder.createBlock (®ion, {}, allRegionArgTypes,
809+ allRegionArgLocs);
770810
771811 // Clones the `bounds` placing them inside the target region and returns them.
772812 auto cloneBound = [&](mlir::Value bound) {
@@ -830,6 +870,20 @@ genBodyOfTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
830870 });
831871 }
832872
873+ for (auto [argIndex, argSymbol] :
874+ llvm::enumerate (dsp.getAllSymbolsToPrivatize ())) {
875+ argIndex = mapSyms.size () + argIndex;
876+
877+ const mlir::BlockArgument &arg = region.getArgument (argIndex);
878+ converter.bindSymbol (*argSymbol,
879+ hlfir::translateToExtendedValue (
880+ currentLocation, firOpBuilder, hlfir::Entity{arg},
881+ /* contiguousHint=*/
882+ evaluate::IsSimplyContiguous (
883+ *argSymbol, converter.getFoldingContext ()))
884+ .first );
885+ }
886+
833887 // Check if cloning the bounds introduced any dependency on the outer region.
834888 // If so, then either clone them as well if they are MemoryEffectFree, or else
835889 // copy them to a new temporary and add them to the map and block_argument
@@ -907,6 +961,8 @@ genBodyOfTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
907961 } else {
908962 genNestedEvaluations (converter, eval);
909963 }
964+
965+ dsp.processStep2 (targetOp, /* isLoop=*/ false );
910966}
911967
912968template <typename OpTy, typename ... Args>
@@ -1048,15 +1104,18 @@ static void genTargetClauses(
10481104 devicePtrSyms);
10491105 cp.processMap (loc, stmtCtx, clauseOps, &mapSyms, &mapLocs, &mapTypes);
10501106 cp.processThreadLimit (stmtCtx, clauseOps);
1051- // TODO Support delayed privatization.
10521107
10531108 if (processHostOnlyClauses)
10541109 cp.processNowait (clauseOps);
10551110
10561111 cp.processTODO <clause::Allocate, clause::Defaultmap, clause::Firstprivate,
1057- clause::InReduction, clause::Private, clause:: Reduction,
1112+ clause::InReduction, clause::Reduction,
10581113 clause::UsesAllocators>(loc,
10591114 llvm::omp::Directive::OMPD_target);
1115+
1116+ // `target private(..)` is only supported in delayed privatization mode.
1117+ if (!enableDelayedPrivatization)
1118+ cp.processTODO <clause::Private>(loc, llvm::omp::Directive::OMPD_target);
10601119}
10611120
10621121static void genTargetDataClauses (
@@ -1289,7 +1348,6 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
12891348 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
12901349 lower::StatementContext stmtCtx;
12911350 mlir::omp::ParallelClauseOps clauseOps;
1292- llvm::SmallVector<const semantics::Symbol *> privateSyms;
12931351 llvm::SmallVector<mlir::Type> reductionTypes;
12941352 llvm::SmallVector<const semantics::Symbol *> reductionSyms;
12951353 genParallelClauses (converter, semaCtx, stmtCtx, item->clauses , loc,
@@ -1319,34 +1377,35 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
13191377 /* useDelayedPrivatization=*/ true , &symTable);
13201378
13211379 if (privatize)
1322- dsp.processStep1 (&clauseOps, &privateSyms );
1380+ dsp.processStep1 (&clauseOps);
13231381
13241382 auto genRegionEntryCB = [&](mlir::Operation *op) {
13251383 auto parallelOp = llvm::cast<mlir::omp::ParallelOp>(op);
13261384
13271385 llvm::SmallVector<mlir::Location> reductionLocs (
13281386 clauseOps.reductionVars .size (), loc);
13291387
1330- mlir::OperandRange privateVars = parallelOp.getPrivateVars ();
1331- mlir::Region ®ion = parallelOp.getRegion ();
1388+ llvm::SmallVector<mlir::Type> allRegionArgTypes;
1389+ mergePrivateVarsInfo (parallelOp, llvm::ArrayRef (reductionTypes),
1390+ llvm::function_ref<mlir::Type (mlir::Value)>{
1391+ [](mlir::Value v) { return v.getType (); }},
1392+ allRegionArgTypes);
13321393
1333- llvm::SmallVector<mlir::Type> privateVarTypes = reductionTypes;
1334- privateVarTypes.reserve (privateVarTypes.size () + privateVars.size ());
1335- llvm::transform (privateVars, std::back_inserter (privateVarTypes),
1336- [](mlir::Value v) { return v.getType (); });
1394+ llvm::SmallVector<mlir::Location> allRegionArgLocs;
1395+ mergePrivateVarsInfo (parallelOp, llvm::ArrayRef (reductionLocs),
1396+ llvm::function_ref<mlir::Location (mlir::Value)>{
1397+ [](mlir::Value v) { return v.getLoc (); }},
1398+ allRegionArgLocs);
13371399
1338- llvm::SmallVector<mlir::Location> privateVarLocs = reductionLocs;
1339- privateVarLocs.reserve (privateVarLocs.size () + privateVars.size ());
1340- llvm::transform (privateVars, std::back_inserter (privateVarLocs),
1341- [](mlir::Value v) { return v.getLoc (); });
1342-
1343- firOpBuilder.createBlock (®ion, /* insertPt=*/ {}, privateVarTypes,
1344- privateVarLocs);
1400+ mlir::Region ®ion = parallelOp.getRegion ();
1401+ firOpBuilder.createBlock (®ion, /* insertPt=*/ {}, allRegionArgTypes,
1402+ allRegionArgLocs);
13451403
13461404 llvm::SmallVector<const semantics::Symbol *> allSymbols = reductionSyms;
1347- allSymbols.append (privateSyms);
1405+ allSymbols.append (dsp.getAllSymbolsToPrivatize ().begin (),
1406+ dsp.getAllSymbolsToPrivatize ().end ());
1407+
13481408 for (auto [arg, prv] : llvm::zip_equal (allSymbols, region.getArguments ())) {
1349- fir::ExtendedValue hostExV = converter.getSymbolExtendedValue (*arg);
13501409 converter.bindSymbol (*arg, hlfir::translateToExtendedValue (
13511410 loc, firOpBuilder, hlfir::Entity{prv},
13521411 /* contiguousHint=*/
@@ -1541,11 +1600,22 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
15411600 deviceAddrLocs, deviceAddrTypes, devicePtrSyms,
15421601 devicePtrLocs, devicePtrTypes);
15431602
1603+ llvm::SmallVector<const semantics::Symbol *> privateSyms;
1604+ DataSharingProcessor dsp (converter, semaCtx, item->clauses , eval,
1605+ /* shouldCollectPreDeterminedSymbols=*/
1606+ lower::omp::isLastItemInQueue (item, queue),
1607+ /* useDelayedPrivatization=*/ true , &symTable);
1608+ dsp.processStep1 (&clauseOps);
1609+
15441610 // 5.8.1 Implicit Data-Mapping Attribute Rules
15451611 // The following code follows the implicit data-mapping rules to map all the
1546- // symbols used inside the region that have not been explicitly mapped using
1547- // the map clause.
1612+ // symbols used inside the region that do not have explicit data-environment
1613+ // attribute clauses (neither data-sharing; e.g. `private`, nor `map`
1614+ // clauses).
15481615 auto captureImplicitMap = [&](const semantics::Symbol &sym) {
1616+ if (dsp.getAllSymbolsToPrivatize ().contains (&sym))
1617+ return ;
1618+
15491619 if (llvm::find (mapSyms, &sym) == mapSyms.end ()) {
15501620 mlir::Value baseOp = converter.getSymbolAddress (sym);
15511621 if (!baseOp)
@@ -1632,7 +1702,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
16321702
16331703 auto targetOp = firOpBuilder.create <mlir::omp::TargetOp>(loc, clauseOps);
16341704 genBodyOfTargetOp (converter, symTable, semaCtx, eval, targetOp, mapSyms,
1635- mapLocs, mapTypes, loc, queue, item);
1705+ mapLocs, mapTypes, dsp, loc, queue, item);
16361706 return targetOp;
16371707}
16381708
0 commit comments