//===- SCFToOpenMP.cpp - Structured Control Flow to OpenMP conversion -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert scf.parallel operations into OpenMP
// parallel loops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

/// Matches a block containing a "simple" reduction. The expected shape of the
/// block is as follows.
///
///   ^bb(%arg0, %arg1):
///     %0 = OpTy(%arg0, %arg1)
///     scf.reduce.return %0
template <typename... OpTy>
static bool matchSimpleReduction(Block &block) {
  if (block.empty() || llvm::hasSingleElement(block) ||
      std::next(block.begin(), 2) != block.end())
    return false;

  if (block.getNumArguments() != 2)
    return false;

  SmallVector<Operation *, 4> combinerOps;
  Value reducedVal = matchReduction({block.getArguments()[1]},
                                    /*redPos=*/0, combinerOps);

  if (!reducedVal || !isa<BlockArgument>(reducedVal) || combinerOps.size() != 1)
    return false;

  return isa<OpTy...>(combinerOps[0]) &&
         isa<scf::ReduceReturnOp>(block.back()) &&
         block.front().getOperands() == block.getArguments();
}
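
// For illustration, with OpTy = arith::AddFOp this matches a reduction block
// of roughly the following shape (types abbreviated):
//
//   ^bb0(%arg0: f32, %arg1: f32):
//     %0 = arith.addf %arg0, %arg1 : f32
//     scf.reduce.return %0 : f32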

/// Matches a block containing a select-based min/max reduction. The types of
/// select and compare operations are provided as template arguments. The
/// comparison predicates suitable for min and max are provided as function
/// arguments. If a reduction is matched, `isMin` will be set if the reduction
/// computes the minimum and unset if it computes the maximum, otherwise it
/// remains unmodified. The expected shape of the block is as follows.
///
///   ^bb(%arg0, %arg1):
///     %0 = CompareOpTy(<one-of-predicates>, %arg0, %arg1)
///     %1 = SelectOpTy(%0, %arg0, %arg1)  // %arg0, %arg1 may be swapped here.
///     scf.reduce.return %1
template <
    typename CompareOpTy, typename SelectOpTy,
    typename Predicate = decltype(std::declval<CompareOpTy>().getPredicate())>
static bool
matchSelectReduction(Block &block, ArrayRef<Predicate> lessThanPredicates,
                     ArrayRef<Predicate> greaterThanPredicates, bool &isMin) {
  static_assert(
      llvm::is_one_of<SelectOpTy, arith::SelectOp, LLVM::SelectOp>::value,
      "only arithmetic and llvm select ops are supported");

  // Expect exactly three operations in the block.
  if (block.empty() || llvm::hasSingleElement(block) ||
      std::next(block.begin(), 2) == block.end() ||
      std::next(block.begin(), 3) != block.end())
    return false;

  auto compare = dyn_cast<CompareOpTy>(block.front());
  auto select = dyn_cast<SelectOpTy>(block.front().getNextNode());
  auto terminator = dyn_cast<scf::ReduceReturnOp>(block.back());
  if (!compare || !select || !terminator)
    return false;

  // Block arguments must be compared.
  if (compare->getOperands() != block.getArguments())
    return false;

  // Detect whether the comparison is less-than or greater-than; otherwise bail.
  bool isLess;
  if (llvm::is_contained(lessThanPredicates, compare.getPredicate())) {
    isLess = true;
  } else if (llvm::is_contained(greaterThanPredicates,
                                compare.getPredicate())) {
    isLess = false;
  } else {
    return false;
  }

  if (select.getCondition() != compare.getResult())
    return false;

  // Detect if the operands are swapped between cmpf and select. Match the
  // comparison type with the requested type or with the opposite of the
  // requested type if the operands are swapped. Use generic accessors because
  // std and LLVM versions of select have different operand names but identical
  // positions.
  constexpr unsigned kTrueValue = 1;
  constexpr unsigned kFalseValue = 2;
  bool sameOperands = select.getOperand(kTrueValue) == compare.getLhs() &&
                      select.getOperand(kFalseValue) == compare.getRhs();
  bool swappedOperands = select.getOperand(kTrueValue) == compare.getRhs() &&
                         select.getOperand(kFalseValue) == compare.getLhs();
  if (!sameOperands && !swappedOperands)
    return false;

  if (select.getResult() != terminator.getResult())
    return false;

  // The reduction is a min if it uses less-than predicates with same operands
  // or greater-than predicates with swapped operands. Similarly for max.
  isMin = (isLess && sameOperands) || (!isLess && swappedOperands);
  return isMin || (isLess && swappedOperands) || (!isLess && sameOperands);
}
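
// For illustration, a floating-point "min" reduction block of roughly the
// following shape matches with CompareOpTy = arith::CmpFOp and
// SelectOpTy = arith::SelectOp, and sets isMin:
//
//   ^bb0(%arg0: f32, %arg1: f32):
//     %0 = arith.cmpf olt, %arg0, %arg1 : f32
//     %1 = arith.select %0, %arg0, %arg1 : f32
//     scf.reduce.return %1 : f32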

/// Returns the float semantics for the given float type.
static const llvm::fltSemantics &fltSemanticsForType(FloatType type) {
  if (type.isF16())
    return llvm::APFloat::IEEEhalf();
  if (type.isF32())
    return llvm::APFloat::IEEEsingle();
  if (type.isF64())
    return llvm::APFloat::IEEEdouble();
  if (type.isF128())
    return llvm::APFloat::IEEEquad();
  if (type.isBF16())
    return llvm::APFloat::BFloat();
  if (type.isF80())
    return llvm::APFloat::x87DoubleExtended();
  llvm_unreachable("unknown float type");
}

/// Returns an attribute with the minimum (if `min` is set) or the maximum
/// value (otherwise) for the given float type.
static Attribute minMaxValueForFloat(Type type, bool min) {
  auto fltType = cast<FloatType>(type);
  return FloatAttr::get(
      type, llvm::APFloat::getLargest(fltSemanticsForType(fltType), min));
}

/// Returns an attribute with the signed integer minimum (if `min` is set) or
/// the maximum value (otherwise) for the given integer type, regardless of its
/// signedness semantics (only the width is considered).
static Attribute minMaxValueForSignedInt(Type type, bool min) {
  auto intType = cast<IntegerType>(type);
  unsigned bitwidth = intType.getWidth();
  return IntegerAttr::get(type, min ? llvm::APInt::getSignedMinValue(bitwidth)
                                    : llvm::APInt::getSignedMaxValue(bitwidth));
}

/// Returns an attribute with the unsigned integer minimum (if `min` is set) or
/// the maximum value (otherwise) for the given integer type, regardless of its
/// signedness semantics (only the width is considered).
static Attribute minMaxValueForUnsignedInt(Type type, bool min) {
  auto intType = cast<IntegerType>(type);
  unsigned bitwidth = intType.getWidth();
  return IntegerAttr::get(type, min ? llvm::APInt::getZero(bitwidth)
                                    : llvm::APInt::getAllOnes(bitwidth));
}
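
// For illustration, for an i8 type these helpers produce:
//   minMaxValueForSignedInt(i8, /*min=*/true)    -> -128 (0x80)
//   minMaxValueForSignedInt(i8, /*min=*/false)   ->  127 (0x7f)
//   minMaxValueForUnsignedInt(i8, /*min=*/true)  ->    0 (0x00)
//   minMaxValueForUnsignedInt(i8, /*min=*/false) ->  255 (0xff)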

/// Creates an OpenMP reduction declaration and inserts it into the provided
/// symbol table. The declaration has a constant initializer with the neutral
/// value `initValue`, and the `reductionIndex`-th reduction combiner carried
/// over from `reduce`.
static omp::DeclareReductionOp
createDecl(PatternRewriter &builder, SymbolTable &symbolTable,
           scf::ReduceOp reduce, int64_t reductionIndex, Attribute initValue) {
  OpBuilder::InsertionGuard guard(builder);
  Type type = reduce.getOperands()[reductionIndex].getType();
  auto decl = builder.create<omp::DeclareReductionOp>(reduce.getLoc(),
                                                      "__scf_reduction", type);
  symbolTable.insert(decl);

  builder.createBlock(&decl.getInitializerRegion(),
                      decl.getInitializerRegion().end(), {type},
                      {reduce.getOperands()[reductionIndex].getLoc()});
  builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
  Value init =
      builder.create<LLVM::ConstantOp>(reduce.getLoc(), type, initValue);
  builder.create<omp::YieldOp>(reduce.getLoc(), init);

  Operation *terminator =
      &reduce.getReductions()[reductionIndex].front().back();
  assert(isa<scf::ReduceReturnOp>(terminator) &&
         "expected reduce op to be terminated by reduce return");
  builder.setInsertionPoint(terminator);
  builder.replaceOpWithNewOp<omp::YieldOp>(terminator,
                                           terminator->getOperands());
  builder.inlineRegionBefore(reduce.getReductions()[reductionIndex],
                             decl.getReductionRegion(),
                             decl.getReductionRegion().end());
  return decl;
}
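
// For illustration, for an f32 add reduction this produces a declaration
// roughly of the following form (exact printed syntax may differ):
//
//   omp.declare_reduction @__scf_reduction : f32
//   init {
//   ^bb0(%arg0: f32):
//     %0 = llvm.mlir.constant(0.000000e+00 : f32) : f32
//     omp.yield(%0 : f32)
//   } combiner {
//   ^bb0(%arg0: f32, %arg1: f32):
//     %0 = arith.addf %arg0, %arg1 : f32
//     omp.yield(%0 : f32)
//   }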

/// Adds an atomic reduction combiner to the given OpenMP reduction declaration
/// using llvm.atomicrmw of the given kind.
static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder,
                                            LLVM::AtomicBinOp atomicKind,
                                            omp::DeclareReductionOp decl,
                                            scf::ReduceOp reduce,
                                            int64_t reductionIndex) {
  OpBuilder::InsertionGuard guard(builder);
  auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
  Location reduceOperandLoc = reduce.getOperands()[reductionIndex].getLoc();
  builder.createBlock(&decl.getAtomicReductionRegion(),
                      decl.getAtomicReductionRegion().end(), {ptrType, ptrType},
                      {reduceOperandLoc, reduceOperandLoc});
  Block *atomicBlock = &decl.getAtomicReductionRegion().back();
  builder.setInsertionPointToEnd(atomicBlock);
  Value loaded = builder.create<LLVM::LoadOp>(reduce.getLoc(), decl.getType(),
                                              atomicBlock->getArgument(1));
  builder.create<LLVM::AtomicRMWOp>(reduce.getLoc(), atomicKind,
                                    atomicBlock->getArgument(0), loaded,
                                    LLVM::AtomicOrdering::monotonic);
  builder.create<omp::YieldOp>(reduce.getLoc(), ArrayRef<Value>());
  return decl;
}
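
// For illustration, for an integer add reduction the atomic region added here
// looks roughly as follows (exact printed syntax may differ):
//
//   atomic {
//   ^bb0(%dst: !llvm.ptr, %src: !llvm.ptr):
//     %0 = llvm.load %src : !llvm.ptr -> i32
//     %1 = llvm.atomicrmw add %dst, %0 monotonic : !llvm.ptr, i32
//     omp.yield
//   }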

/// Creates an OpenMP reduction declaration that corresponds to the given SCF
/// reduction and returns it. Recognizes common reductions in order to identify
/// the neutral value, necessary for the OpenMP declaration. If the reduction
/// cannot be recognized, returns null.
static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
                                                scf::ReduceOp reduce,
                                                int64_t reductionIndex) {
  Operation *container = SymbolTable::getNearestSymbolTable(reduce);
  SymbolTable symbolTable(container);

  // Insert reduction declarations in the symbol-table ancestor before the
  // ancestor of the current insertion point.
  Operation *insertionPoint = reduce;
  while (insertionPoint->getParentOp() != container)
    insertionPoint = insertionPoint->getParentOp();
  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPoint(insertionPoint);

  assert(llvm::hasSingleElement(reduce.getReductions()[reductionIndex]) &&
         "expected reduction region to have a single element");

  // Match simple binary reductions that can be expressed with atomicrmw.
  Type type = reduce.getOperands()[reductionIndex].getType();
  Block &reduction = reduce.getReductions()[reductionIndex].front();
  if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) {
    omp::DeclareReductionOp decl =
        createDecl(builder, symbolTable, reduce, reductionIndex,
                   builder.getFloatAttr(type, 0.0));
    return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce,
                        reductionIndex);
  }
  if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) {
    omp::DeclareReductionOp decl =
        createDecl(builder, symbolTable, reduce, reductionIndex,
                   builder.getIntegerAttr(type, 0));
    return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce,
                        reductionIndex);
  }
  if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) {
    omp::DeclareReductionOp decl =
        createDecl(builder, symbolTable, reduce, reductionIndex,
                   builder.getIntegerAttr(type, 0));
    return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce,
                        reductionIndex);
  }
  if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) {
    omp::DeclareReductionOp decl =
        createDecl(builder, symbolTable, reduce, reductionIndex,
                   builder.getIntegerAttr(type, 0));
    return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce,
                        reductionIndex);
  }
  if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) {
    omp::DeclareReductionOp decl = createDecl(
        builder, symbolTable, reduce, reductionIndex,
        builder.getIntegerAttr(
            type, llvm::APInt::getAllOnes(type.getIntOrFloatBitWidth())));
    return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce,
                        reductionIndex);
  }

  // Match simple binary reductions that cannot be expressed with atomicrmw.
  // TODO: add atomic region using cmpxchg (which needs atomic load to be
  // available as an op).
  if (matchSimpleReduction<arith::MulFOp, LLVM::FMulOp>(reduction)) {
    return createDecl(builder, symbolTable, reduce, reductionIndex,
                      builder.getFloatAttr(type, 1.0));
  }
  if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(reduction)) {
    return createDecl(builder, symbolTable, reduce, reductionIndex,
                      builder.getIntegerAttr(type, 1));
  }

  // Match select-based min/max reductions.
  bool isMin;
  if (matchSelectReduction<arith::CmpFOp, arith::SelectOp>(
          reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE},
          {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) ||
      matchSelectReduction<LLVM::FCmpOp, LLVM::SelectOp>(
          reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole},
          {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) {
    return createDecl(builder, symbolTable, reduce, reductionIndex,
                      minMaxValueForFloat(type, !isMin));
  }
  if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
          reduction, {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle},
          {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) ||
      matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
          reduction, {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle},
          {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) {
    omp::DeclareReductionOp decl =
        createDecl(builder, symbolTable, reduce, reductionIndex,
                   minMaxValueForSignedInt(type, !isMin));
    return addAtomicRMW(builder,
                        isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max,
                        decl, reduce, reductionIndex);
  }
  if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
          reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},
          {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, isMin) ||
      matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
          reduction, {LLVM::ICmpPredicate::ult, LLVM::ICmpPredicate::ule},
          {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) {
    omp::DeclareReductionOp decl =
        createDecl(builder, symbolTable, reduce, reductionIndex,
                   minMaxValueForUnsignedInt(type, !isMin));
    return addAtomicRMW(
        builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
        decl, reduce, reductionIndex);
  }

  return nullptr;
}

namespace {
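
// The pattern below rewrites scf.parallel (and its scf.reduce terminator)
// into an OpenMP construct nest that looks roughly as follows (sketch only;
// reduction clauses, induction variables and types are omitted):
//
//   omp.parallel {
//     omp.wsloop {
//       omp.loop_nest (%i) : index = (%lb) to (%ub) step (%step) {
//         memref.alloca_scope {
//           ... original scf.parallel body ...
//         }
//         omp.yield
//       }
//     }
//     omp.terminator
//   }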
struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
  static constexpr unsigned kUseOpenMPDefaultNumThreads = 0;
  unsigned numThreads;

  ParallelOpLowering(MLIRContext *context,
                     unsigned numThreads = kUseOpenMPDefaultNumThreads)
      : OpRewritePattern<scf::ParallelOp>(context), numThreads(numThreads) {}

  LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
                                PatternRewriter &rewriter) const override {
    // Declare reductions.
    // TODO: consider checking whether a compatible reduction declaration
    // already exists here and using it instead of redeclaring.
    SmallVector<Attribute> reductionSyms;
    SmallVector<omp::DeclareReductionOp> ompReductionDecls;
    auto reduce = cast<scf::ReduceOp>(parallelOp.getBody()->getTerminator());
    for (int64_t i = 0, e = parallelOp.getNumReductions(); i < e; ++i) {
      omp::DeclareReductionOp decl = declareReduction(rewriter, reduce, i);
      ompReductionDecls.push_back(decl);
      if (!decl)
        return failure();
      reductionSyms.push_back(
          SymbolRefAttr::get(rewriter.getContext(), decl.getSymName()));
    }

    // Allocate reduction variables. Make sure we don't overflow the stack
    // with local `alloca`s by saving and restoring the stack pointer.
    Location loc = parallelOp.getLoc();
    Value one = rewriter.create<LLVM::ConstantOp>(
        loc, rewriter.getIntegerType(64), rewriter.getI64IntegerAttr(1));
    SmallVector<Value> reductionVariables;
    reductionVariables.reserve(parallelOp.getNumReductions());
    auto ptrType = LLVM::LLVMPointerType::get(parallelOp.getContext());
    for (Value init : parallelOp.getInitVals()) {
      assert((LLVM::isCompatibleType(init.getType()) ||
              isa<LLVM::PointerElementTypeInterface>(init.getType())) &&
             "cannot create a reduction variable if the type is not an LLVM "
             "pointer element");
      Value storage =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, init.getType(), one, 0);
      rewriter.create<LLVM::StoreOp>(loc, init, storage);
      reductionVariables.push_back(storage);
    }

    // Replace the reduction operations contained in this loop. Must be done
    // here rather than in a separate pattern to have access to the list of
    // reduction variables.
    for (auto [x, y, rD] : llvm::zip_equal(
             reductionVariables, reduce.getOperands(), ompReductionDecls)) {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPoint(reduce);
      Region &redRegion = rD.getReductionRegion();
      // The SCF dialect by definition contains only structured operations
      // and hence the SCF reduction region will contain a single block.
      // The ompReductionDecls region is a copy of the SCF reduction region
      // and hence has the same property.
      assert(redRegion.hasOneBlock() &&
             "expect reduction region to have one block");
      Value pvtRedVar = parallelOp.getRegion().addArgument(x.getType(), loc);
      Value pvtRedVal = rewriter.create<LLVM::LoadOp>(reduce.getLoc(),
                                                      rD.getType(), pvtRedVar);
      // Make a copy of the reduction combiner region in the body.
      mlir::OpBuilder builder(rewriter.getContext());
      builder.setInsertionPoint(reduce);
      mlir::IRMapping mapper;
      assert(redRegion.getNumArguments() == 2 &&
             "expect reduction region to have two arguments");
      mapper.map(redRegion.getArgument(0), pvtRedVal);
      mapper.map(redRegion.getArgument(1), y);
      for (auto &op : redRegion.getOps()) {
        Operation *cloneOp = builder.clone(op, mapper);
        if (auto yieldOp = dyn_cast<omp::YieldOp>(*cloneOp)) {
          assert(yieldOp && yieldOp.getResults().size() == 1 &&
                 "expect YieldOp in reduction region to return one result");
          Value redVal = yieldOp.getResults()[0];
          rewriter.create<LLVM::StoreOp>(loc, redVal, pvtRedVar);
          rewriter.eraseOp(yieldOp);
          break;
        }
      }
    }
    rewriter.eraseOp(reduce);

    Value numThreadsVar;
    if (numThreads > 0) {
      numThreadsVar = rewriter.create<LLVM::ConstantOp>(
          loc, rewriter.getI32IntegerAttr(numThreads));
    }
    // Create the parallel wrapper.
    auto ompParallel = rewriter.create<omp::ParallelOp>(
        loc,
        /* allocate_vars = */ llvm::SmallVector<Value>{},
        /* allocator_vars = */ llvm::SmallVector<Value>{},
        /* if_expr = */ Value{},
        /* num_threads = */ numThreadsVar,
        /* private_vars = */ ValueRange(),
        /* private_syms = */ nullptr,
        /* proc_bind_kind = */ omp::ClauseProcBindKindAttr{},
        /* reduction_mod = */ nullptr,
        /* reduction_vars = */ llvm::SmallVector<Value>{},
        /* reduction_byref = */ DenseBoolArrayAttr{},
        /* reduction_syms = */ ArrayAttr{});
    {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.createBlock(&ompParallel.getRegion());

      {
        OpBuilder::InsertionGuard allocaGuard(rewriter);
        // Create worksharing loop wrapper.
        auto wsloopOp = rewriter.create<omp::WsloopOp>(parallelOp.getLoc());
        if (!reductionVariables.empty()) {
          wsloopOp.setReductionSymsAttr(
              ArrayAttr::get(rewriter.getContext(), reductionSyms));
          wsloopOp.getReductionVarsMutable().append(reductionVariables);
          llvm::SmallVector<bool> reductionByRef;
          // Always false: these reductions reduce scalars, which do not need
          // to be passed by reference.
          reductionByRef.resize(reductionVariables.size(), false);
          wsloopOp.setReductionByref(
              DenseBoolArrayAttr::get(rewriter.getContext(), reductionByRef));
        }
        rewriter.create<omp::TerminatorOp>(loc); // omp.parallel terminator.

        // The wrapper's entry block arguments will define the reduction
        // variables.
        llvm::SmallVector<mlir::Type> reductionTypes;
        reductionTypes.reserve(reductionVariables.size());
        llvm::transform(reductionVariables, std::back_inserter(reductionTypes),
                        [](mlir::Value v) { return v.getType(); });
        rewriter.createBlock(
            &wsloopOp.getRegion(), {}, reductionTypes,
            llvm::SmallVector<mlir::Location>(reductionVariables.size(),
                                              parallelOp.getLoc()));

        // Create loop nest and populate region with contents of scf.parallel.
        auto loopOp = rewriter.create<omp::LoopNestOp>(
            parallelOp.getLoc(), parallelOp.getLowerBound(),
            parallelOp.getUpperBound(), parallelOp.getStep());

        rewriter.inlineRegionBefore(parallelOp.getRegion(), loopOp.getRegion(),
                                    loopOp.getRegion().begin());

        // Remove reduction-related block arguments from omp.loop_nest and
        // redirect uses to the corresponding omp.wsloop block argument.
        mlir::Block &loopOpEntryBlock = loopOp.getRegion().front();
        unsigned numLoops = parallelOp.getNumLoops();
        rewriter.replaceAllUsesWith(
            loopOpEntryBlock.getArguments().drop_front(numLoops),
            wsloopOp.getRegion().getArguments());
        loopOpEntryBlock.eraseArguments(
            numLoops, loopOpEntryBlock.getNumArguments() - numLoops);

        Block *ops =
            rewriter.splitBlock(&loopOpEntryBlock, loopOpEntryBlock.begin());
        rewriter.setInsertionPointToStart(&loopOpEntryBlock);

        auto scope = rewriter.create<memref::AllocaScopeOp>(parallelOp.getLoc(),
                                                            TypeRange());
        rewriter.create<omp::YieldOp>(loc, ValueRange());
        Block *scopeBlock = rewriter.createBlock(&scope.getBodyRegion());
        rewriter.mergeBlocks(ops, scopeBlock);
        rewriter.setInsertionPointToEnd(&*scope.getBodyRegion().begin());
        rewriter.create<memref::AllocaScopeReturnOp>(loc, ValueRange());
      }
    }

    // Load loop results.
    SmallVector<Value> results;
    results.reserve(reductionVariables.size());
    for (auto [variable, type] :
         llvm::zip(reductionVariables, parallelOp.getResultTypes())) {
      Value res = rewriter.create<LLVM::LoadOp>(loc, type, variable);
      results.push_back(res);
    }
    rewriter.replaceOp(parallelOp, results);

    return success();
  }
};

/// Applies the conversion patterns in the given module.
static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) {
  ConversionTarget target(*module.getContext());
  target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
  target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
                         memref::MemRefDialect>();

  RewritePatternSet patterns(module.getContext());
  patterns.add<ParallelOpLowering>(module.getContext(), numThreads);
  FrozenRewritePatternSet frozen(std::move(patterns));
  return applyPartialConversion(module, target, frozen);
}

/// A pass converting SCF operations to OpenMP operations.
struct SCFToOpenMPPass
    : public impl::ConvertSCFToOpenMPPassBase<SCFToOpenMPPass> {
  using Base::Base;

  /// Pass entry point.
  void runOnOperation() override {
    if (failed(applyPatterns(getOperation(), numThreads)))
      signalPassFailure();
  }
};

} // namespace
)))