//===-- ArrayValueCopy.cpp ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Builder/BoxValue.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/Factory.h"
#include "flang/Optimizer/Builder/Runtime/Derived.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"

namespace fir {
#define GEN_PASS_DEF_ARRAYVALUECOPY
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

#define DEBUG_TYPE "flang-array-value-copy"
using namespace fir;
using namespace mlir;

using OperationUseMapT = llvm::DenseMap<mlir::Operation *, mlir::Operation *>;

namespace {
/// Array copy analysis.
/// Perform an interference analysis between array values.
///
/// Lowering will generate a sequence of the following form.
///
///   %a_1 = fir.array_load %array_1(%shape) : ...
///   ...
///   %a_j = fir.array_load %array_j(%shape) : ...
///   ...
///   %a_n = fir.array_load %array_n(%shape) : ...
///   ...
///   %v_i = fir.array_fetch %a_i, ...
///   %a_j1 = fir.array_update %a_j, ...
///   ...
///   fir.array_merge_store %a_j, %a_jn to %array_j : ...
///
/// The analysis is to determine if there are any conflicts. A conflict is when
/// one of the following cases occurs.
///
/// 1. There is an `array_update` to an array value, a_j, such that a_j was
/// loaded from the same array memory reference (array_j) but with a different
/// shape than the other array values a_i, where i != j. [Possible overlapping
/// arrays.]
///
/// 2. There is either an array_fetch or array_update of a_j with a different
/// set of index values. [Possible loop-carried dependence.]
///
/// If none of the array values overlap in storage and the accesses are not
/// loop-carried, then the arrays are conflict-free and no copies are required.
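///
/// For illustration (hypothetical Fortran, not from this file): in
/// `a(2:10) = a(1:9)` the loaded and stored sections of `a` overlap in
/// storage, so a conflict is reported and a temporary copy is required,
/// whereas `a(1:9) = b(1:9)` on distinct arrays is conflict-free and can be
/// updated in place.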
class ArrayCopyAnalysisBase {
public:
  using ConflictSetT = llvm::SmallPtrSet<mlir::Operation *, 16>;
  using UseSetT = llvm::SmallPtrSet<mlir::OpOperand *, 8>;
  using LoadMapSetsT = llvm::DenseMap<mlir::Operation *, UseSetT>;
  using AmendAccessSetT = llvm::SmallPtrSet<mlir::Operation *, 4>;

  ArrayCopyAnalysisBase(mlir::Operation *op, bool optimized)
      : operation{op}, optimizeConflicts(optimized) {
    construct(op);
  }
  virtual ~ArrayCopyAnalysisBase() = default;

  mlir::Operation *getOperation() const { return operation; }

  /// Return true iff the `array_merge_store` has potential conflicts.
  bool hasPotentialConflict(mlir::Operation *op) const {
    LLVM_DEBUG(llvm::dbgs()
               << "looking for a conflict on " << *op
               << " and the set has a total of " << conflicts.size() << '\n');
    return conflicts.contains(op);
  }

  /// Return the use map.
  /// The use map maps array access, amend, fetch and update operations back to
  /// the array load that is the original source of the array value.
  /// It maps an array_load to an array_merge_store, if and only if the loaded
  /// array value has pending modifications to be merged.
  const OperationUseMapT &getUseMap() const { return useMap; }

  /// Return true iff \p op is in the set of array_access ops directly
  /// associated with array_amend ops.
  bool inAmendAccessSet(mlir::Operation *op) const {
    return amendAccesses.count(op);
  }

  /// For ArrayLoad `load`, return the transitive set of all OpOperands.
  UseSetT getLoadUseSet(mlir::Operation *load) const {
    assert(loadMapSets.count(load) && "analysis missed an array load?");
    return loadMapSets.lookup(load);
  }

  void arrayMentions(llvm::SmallVectorImpl<mlir::Operation *> &mentions,
                     ArrayLoadOp load);

private:
  void construct(mlir::Operation *topLevelOp);

  mlir::Operation *operation; // operation that analysis ran upon
  ConflictSetT conflicts;     // set of conflicts (loads and merge stores)
  OperationUseMapT useMap;
  LoadMapSetsT loadMapSets;
  // Set of array_access ops associated with array_amend ops.
  AmendAccessSetT amendAccesses;
  bool optimizeConflicts;
};
// Optimized array copy analysis that takes into account Fortran
// variable attributes to prove that no conflict is possible
// and reduce the number of temporary arrays.
class ArrayCopyAnalysisOptimized : public ArrayCopyAnalysisBase {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ArrayCopyAnalysisOptimized)

  ArrayCopyAnalysisOptimized(mlir::Operation *op)
      : ArrayCopyAnalysisBase(op, /*optimized=*/true) {}
};
// Unoptimized array copy analysis used at O0.
class ArrayCopyAnalysis : public ArrayCopyAnalysisBase {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ArrayCopyAnalysis)

  ArrayCopyAnalysis(mlir::Operation *op)
      : ArrayCopyAnalysisBase(op, /*optimized=*/false) {}
};
} // namespace
namespace {
/// Helper class to collect all array operations that produced an array value.
class ReachCollector {
public:
  ReachCollector(llvm::SmallVectorImpl<mlir::Operation *> &reach,
                 mlir::Region *loopRegion)
      : reach{reach}, loopRegion{loopRegion} {}

  void collectArrayMentionFrom(mlir::Operation *op, mlir::ValueRange range) {
    if (range.empty()) {
      collectArrayMentionFrom(op, mlir::Value{});
      return;
    }
    for (mlir::Value v : range)
      collectArrayMentionFrom(v);
  }
  // Collect all the array_access ops in `block`. This recursively looks into
  // blocks in ops with regions.
  // FIXME: This is temporarily relying on the array_amend appearing in a
  // do_loop Region. This phase ordering assumption can be eliminated by using
  // dominance information to find the array_access ops or by scanning the
  // transitive closure of the amending array_access's users and the defs that
  // reach them.
  void collectAccesses(llvm::SmallVector<ArrayAccessOp> &result,
                       mlir::Block *block) {
    for (auto &op : *block) {
      if (auto access = mlir::dyn_cast<ArrayAccessOp>(op)) {
        LLVM_DEBUG(llvm::dbgs() << "adding access: " << access << '\n');
        result.push_back(access);
        continue;
      }
      for (auto &region : op.getRegions())
        for (auto &bb : region.getBlocks())
          collectAccesses(result, &bb);
    }
  }
  void collectArrayMentionFrom(mlir::Operation *op, mlir::Value val) {
    // `val` is defined by an Op, process the defining Op.
    // If `val` is defined by a region containing Op, we want to drill down
    // and through that Op's region(s).
    LLVM_DEBUG(llvm::dbgs() << "popset: " << *op << '\n');
    auto popFn = [&](auto rop) {
      assert(val && "op must have a result value");
      auto resNum = mlir::cast<mlir::OpResult>(val).getResultNumber();
      llvm::SmallVector<mlir::Value> results;
      rop.resultToSourceOps(results, resNum);
      for (auto u : results)
        collectArrayMentionFrom(u);
    };
    if (auto rop = mlir::dyn_cast<DoLoopOp>(op)) {
      popFn(rop);
      return;
    }
    if (auto rop = mlir::dyn_cast<IterWhileOp>(op)) {
      popFn(rop);
      return;
    }
    if (auto rop = mlir::dyn_cast<fir::IfOp>(op)) {
      popFn(rop);
      return;
    }
    if (auto box = mlir::dyn_cast<EmboxOp>(op)) {
      for (auto *user : box.getMemref().getUsers())
        if (user != op)
          collectArrayMentionFrom(user, user->getResults());
      return;
    }
    if (auto mergeStore = mlir::dyn_cast<ArrayMergeStoreOp>(op)) {
      if (opIsInsideLoops(mergeStore))
        collectArrayMentionFrom(mergeStore.getSequence());
      return;
    }

    if (mlir::isa<AllocaOp, AllocMemOp>(op)) {
      // Look for any stores inside the loops, and collect an array operation
      // that produced the value being stored to it.
      for (auto *user : op->getUsers())
        if (auto store = mlir::dyn_cast<fir::StoreOp>(user))
          if (opIsInsideLoops(store))
            collectArrayMentionFrom(store.getValue());
      return;
    }

    // Scan the uses of amend's memref.
    if (auto amend = mlir::dyn_cast<ArrayAmendOp>(op)) {
      reach.push_back(op);
      llvm::SmallVector<ArrayAccessOp> accesses;
      collectAccesses(accesses, op->getBlock());
      for (auto access : accesses)
        collectArrayMentionFrom(access.getResult());
    }

    // Otherwise, Op does not contain a region so just chase its operands.
    if (mlir::isa<ArrayAccessOp, ArrayLoadOp, ArrayUpdateOp, ArrayModifyOp,
                  ArrayFetchOp>(op)) {
      LLVM_DEBUG(llvm::dbgs() << "add " << *op << " to reachable set\n");
      reach.push_back(op);
    }

    // Include all array_access ops using an array_load.
    if (auto arrLd = mlir::dyn_cast<ArrayLoadOp>(op))
      for (auto *user : arrLd.getResult().getUsers())
        if (mlir::isa<ArrayAccessOp>(user)) {
          LLVM_DEBUG(llvm::dbgs() << "add " << *user << " to reachable set\n");
          reach.push_back(user);
        }

    // Array modify assignment is performed on the result. So the analysis must
    // look at what is done with the result.
    if (mlir::isa<ArrayModifyOp>(op))
      for (auto *user : op->getResult(0).getUsers())
        followUsers(user);

    if (mlir::isa<fir::CallOp>(op)) {
      LLVM_DEBUG(llvm::dbgs() << "add " << *op << " to reachable set\n");
      reach.push_back(op);
    }

    for (auto u : op->getOperands())
      collectArrayMentionFrom(u);
  }
  void collectArrayMentionFrom(mlir::BlockArgument ba) {
    auto *parent = ba.getOwner()->getParentOp();
    // If inside an Op holding a region, the block argument corresponds to an
    // argument passed to the containing Op.
    auto popFn = [&](auto rop) {
      collectArrayMentionFrom(rop.blockArgToSourceOp(ba.getArgNumber()));
    };
    if (auto rop = mlir::dyn_cast<DoLoopOp>(parent)) {
      popFn(rop);
      return;
    }
    if (auto rop = mlir::dyn_cast<IterWhileOp>(parent)) {
      popFn(rop);
      return;
    }
    // Otherwise, a block argument is provided via the pred blocks.
    for (auto *pred : ba.getOwner()->getPredecessors()) {
      auto u = pred->getTerminator()->getOperand(ba.getArgNumber());
      collectArrayMentionFrom(u);
    }
  }
  // Recursively trace operands to find all array operations relating to the
  // values merged.
  void collectArrayMentionFrom(mlir::Value val) {
    if (!val || visited.contains(val))
      return;
    visited.insert(val);

    // Process a block argument.
    if (auto ba = mlir::dyn_cast<mlir::BlockArgument>(val)) {
      collectArrayMentionFrom(ba);
      return;
    }

    // Process an Op.
    if (auto *op = val.getDefiningOp()) {
      collectArrayMentionFrom(op, val);
      return;
    }

    emitFatalError(val.getLoc(), "unhandled value");
  }
  /// Return all ops that produce the array value that is stored into the
  /// `array_merge_store`.
  static void reachingValues(llvm::SmallVectorImpl<mlir::Operation *> &reach,
                             mlir::Value seq) {
    reach.clear();
    mlir::Region *loopRegion = nullptr;
    if (auto doLoop = mlir::dyn_cast_or_null<DoLoopOp>(seq.getDefiningOp()))
      loopRegion = &doLoop->getRegion(0);
    ReachCollector collector(reach, loopRegion);
    collector.collectArrayMentionFrom(seq);
  }
private:
  /// Is \p op inside the loop nest region?
  /// FIXME: replace this structural dependence with graph properties.
  bool opIsInsideLoops(mlir::Operation *op) const {
    auto *region = op->getParentRegion();
    while (region) {
      if (region == loopRegion)
        return true;
      region = region->getParentRegion();
    }
    return false;
  }
  /// Recursively trace the uses of an operation's results, calling
  /// collectArrayMentionFrom on the direct and indirect user operands.
  void followUsers(mlir::Operation *op) {
    for (auto userOperand : op->getOperands())
      collectArrayMentionFrom(userOperand);
    // Go through potential converts/coordinate_op.
    for (auto indirectUser : op->getUsers())
      followUsers(indirectUser);
  }
  llvm::SmallVectorImpl<mlir::Operation *> &reach;
  llvm::SmallPtrSet<mlir::Value, 16> visited;
  /// Region of the loop nest that produced the array value.
  mlir::Region *loopRegion;
};
} // namespace
/// Find all the array operations that access the array value that is loaded by
/// the array load operation, `load`.
void ArrayCopyAnalysisBase::arrayMentions(
    llvm::SmallVectorImpl<mlir::Operation *> &mentions, ArrayLoadOp load) {
  mentions.clear();
  auto lmIter = loadMapSets.find(load);
  if (lmIter != loadMapSets.end()) {
    for (auto *opnd : lmIter->second) {
      auto *owner = opnd->getOwner();
      if (mlir::isa<ArrayAccessOp, ArrayAmendOp, ArrayFetchOp, ArrayUpdateOp,
                    ArrayModifyOp>(owner))
        mentions.push_back(owner);
    }
    return;
  }

  UseSetT visited;
  llvm::SmallVector<mlir::OpOperand *> queue; // uses of ArrayLoad[orig]

  auto appendToQueue = [&](mlir::Value val) {
    for (auto &use : val.getUses())
      if (!visited.count(&use)) {
        visited.insert(&use);
        queue.push_back(&use);
      }
  };

  // Build the set of uses of `original`.
  // let USES = { uses of original fir.load }
  appendToQueue(load);

  // Process the worklist until done.
  while (!queue.empty()) {
    mlir::OpOperand *operand = queue.pop_back_val();
    mlir::Operation *owner = operand->getOwner();

    auto structuredLoop = [&](auto ro) {
      if (auto blockArg = ro.iterArgToBlockArg(operand->get())) {
        int64_t arg = blockArg.getArgNumber();
        mlir::Value output = ro.getResult(ro.getFinalValue() ? arg : arg - 1);
        appendToQueue(output);
        appendToQueue(blockArg);
      }
    };
    // TODO: this needs to be updated to use the control-flow interface.
    auto branchOp = [&](mlir::Block *dest, OperandRange operands) {
      if (operands.empty())
        return;

      // Check if this operand is within the range.
      unsigned operandIndex = operand->getOperandNumber();
      unsigned operandsStart = operands.getBeginOperandIndex();
      if (operandIndex < operandsStart ||
          operandIndex >= (operandsStart + operands.size()))
        return;

      // Index the successor.
      unsigned argIndex = operandIndex - operandsStart;
      appendToQueue(dest->getArgument(argIndex));
    };
    // Thread uses into structured loop bodies and return value uses.
    if (auto ro = mlir::dyn_cast<DoLoopOp>(owner)) {
      structuredLoop(ro);
    } else if (auto ro = mlir::dyn_cast<IterWhileOp>(owner)) {
      structuredLoop(ro);
    } else if (auto rs = mlir::dyn_cast<ResultOp>(owner)) {
      // Thread any uses of fir.if that return the marked array value.
      mlir::Operation *parent = rs->getParentRegion()->getParentOp();
      if (auto ifOp = mlir::dyn_cast<fir::IfOp>(parent))
        appendToQueue(ifOp.getResult(operand->getOperandNumber()));
    } else if (mlir::isa<ArrayFetchOp>(owner)) {
      // Keep track of array value fetches.
      LLVM_DEBUG(llvm::dbgs()
                 << "add fetch {" << *owner << "} to array value set\n");
      mentions.push_back(owner);
    } else if (auto update = mlir::dyn_cast<ArrayUpdateOp>(owner)) {
      // Keep track of array value updates and thread the return value uses.
      LLVM_DEBUG(llvm::dbgs()
                 << "add update {" << *owner << "} to array value set\n");
      mentions.push_back(owner);
      appendToQueue(update.getResult());
    } else if (auto update = mlir::dyn_cast<ArrayModifyOp>(owner)) {
      // Keep track of array value modification and thread the return value
      // uses.
      LLVM_DEBUG(llvm::dbgs()
                 << "add modify {" << *owner << "} to array value set\n");
      mentions.push_back(owner);
      appendToQueue(update.getResult(1));
    } else if (auto mention = mlir::dyn_cast<ArrayAccessOp>(owner)) {
      mentions.push_back(owner);
    } else if (auto amend = mlir::dyn_cast<ArrayAmendOp>(owner)) {
      mentions.push_back(owner);
      appendToQueue(amend.getResult());
    } else if (auto br = mlir::dyn_cast<mlir::cf::BranchOp>(owner)) {
      branchOp(br.getDest(), br.getDestOperands());
    } else if (auto br = mlir::dyn_cast<mlir::cf::CondBranchOp>(owner)) {
      branchOp(br.getTrueDest(), br.getTrueOperands());
      branchOp(br.getFalseDest(), br.getFalseOperands());
    } else if (mlir::isa<ArrayMergeStoreOp>(owner)) {
      // do nothing
    } else {
      llvm::report_fatal_error("array value reached unexpected op");
    }
  }
  loadMapSets.insert({load, visited});
}
static bool hasPointerType(mlir::Type type) {
  if (auto boxTy = mlir::dyn_cast<BoxType>(type))
    type = boxTy.getEleTy();
  return mlir::isa<fir::PointerType>(type);
}
// This is a NF performance hack. It makes a simple test that the slices of the
// load, \p ld, and the merge store, \p st, are trivially mutually exclusive.
static bool mutuallyExclusiveSliceRange(ArrayLoadOp ld, ArrayMergeStoreOp st) {
  // If the same array_load, then no further testing is warranted.
  if (ld.getResult() == st.getOriginal())
    return false;

  auto getSliceOp = [](mlir::Value val) -> SliceOp {
    if (!val)
      return {};
    auto sliceOp = mlir::dyn_cast_or_null<SliceOp>(val.getDefiningOp());
    if (!sliceOp)
      return {};
    return sliceOp;
  };

  auto ldSlice = getSliceOp(ld.getSlice());
  auto stSlice = getSliceOp(st.getSlice());
  if (!ldSlice || !stSlice)
    return false;

  // Resign on subobject slices.
  if (!ldSlice.getFields().empty() || !stSlice.getFields().empty() ||
      !ldSlice.getSubstr().empty() || !stSlice.getSubstr().empty())
    return false;

  // Crudely test that the two slices do not overlap by looking for the
  // following general condition. If the slices look like (i:j) and (j+1:k) then
  // these ranges do not overlap. The addend must be a constant.
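  // For example (hypothetical Fortran, for illustration only): in
  // `a(1:j) = a(j+1:2*j)` the store slice is (1:j) and the load slice is
  // (j+1:2*j); the load's lower bound equals the store's upper bound plus
  // the positive constant 1, so the two ranges cannot overlap.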
  auto ldTriples = ldSlice.getTriples();
  auto stTriples = stSlice.getTriples();
  const auto size = ldTriples.size();
  if (size != stTriples.size())
    return false;

  auto displacedByConstant = [](mlir::Value v1, mlir::Value v2) {
    auto removeConvert = [](mlir::Value v) -> mlir::Operation * {
      auto *op = v.getDefiningOp();
      while (auto conv = mlir::dyn_cast_or_null<ConvertOp>(op))
        op = conv.getValue().getDefiningOp();
      return op;
    };

    auto isPositiveConstant = [](mlir::Value v) -> bool {
      if (auto conOp =
              mlir::dyn_cast<mlir::arith::ConstantOp>(v.getDefiningOp()))
        if (auto iattr = mlir::dyn_cast<mlir::IntegerAttr>(conOp.getValue()))
          return iattr.getInt() > 0;
      return false;
    };

    auto *op1 = removeConvert(v1);
    auto *op2 = removeConvert(v2);
    if (!op1 || !op2)
      return false;
    if (auto addi = mlir::dyn_cast<mlir::arith::AddIOp>(op2))
      if ((addi.getLhs().getDefiningOp() == op1 &&
           isPositiveConstant(addi.getRhs())) ||
          (addi.getRhs().getDefiningOp() == op1 &&
           isPositiveConstant(addi.getLhs())))
        return true;
    if (auto subi = mlir::dyn_cast<mlir::arith::SubIOp>(op1))
      if (subi.getLhs().getDefiningOp() == op2 &&
          isPositiveConstant(subi.getRhs()))
        return true;
    return false;
  };

  for (std::remove_const_t<decltype(size)> i = 0; i < size; i += 3) {
    // If both are loop invariant, skip to the next triple.
    if (mlir::isa_and_nonnull<fir::UndefOp>(ldTriples[i + 1].getDefiningOp()) &&
        mlir::isa_and_nonnull<fir::UndefOp>(stTriples[i + 1].getDefiningOp())) {
      // Unless either is a vector index, then be conservative.
      if (mlir::isa_and_nonnull<fir::UndefOp>(ldTriples[i].getDefiningOp()) ||
          mlir::isa_and_nonnull<fir::UndefOp>(stTriples[i].getDefiningOp()))
        return false;
      continue;
    }
    // If identical, skip to the next triple.
    if (ldTriples[i] == stTriples[i] && ldTriples[i + 1] == stTriples[i + 1] &&
        ldTriples[i + 2] == stTriples[i + 2])
      continue;
    // If ubound and lbound are the same with a constant offset, skip to the
    // next triple.
    if (displacedByConstant(ldTriples[i + 1], stTriples[i]) ||
        displacedByConstant(stTriples[i + 1], ldTriples[i]))
      continue;
    return false;
  }
  LLVM_DEBUG(llvm::dbgs() << "detected non-overlapping slice ranges on " << ld
                          << " and " << st << ", which is not a conflict\n");
  return true;
}
/// Is there a conflict between the array value that was updated and to be
/// stored to `st` and the set of arrays loaded (`reach`) and used to compute
/// the updated value?
/// If `optimize` is true, use the variable attributes to prove that
/// there is no conflict.
static bool conflictOnLoad(llvm::ArrayRef<mlir::Operation *> reach,
                           ArrayMergeStoreOp st, bool optimize) {
  mlir::Value load;
  mlir::Value addr = st.getMemref();
  const bool storeHasPointerType = hasPointerType(addr.getType());
  for (auto *op : reach)
    if (auto ld = mlir::dyn_cast<ArrayLoadOp>(op)) {
      mlir::Type ldTy = ld.getMemref().getType();
      auto globalOpName = mlir::OperationName(fir::GlobalOp::getOperationName(),
                                              ld.getContext());
      if (ld.getMemref() == addr) {
        if (mutuallyExclusiveSliceRange(ld, st))
          continue;
        if (ld.getResult() != st.getOriginal())
          return true;
        if (load) {
          // TODO: extend this to allow checking if the first `load` and this
          // `ld` are mutually exclusive accesses but not identical.
          return true;
        }
        load = ld;
      } else if (storeHasPointerType) {
        if (optimize && !hasPointerType(ldTy) &&
            !valueMayHaveFirAttributes(
                ld.getMemref(),
                {getTargetAttrName(),
                 fir::GlobalOp::getTargetAttrName(globalOpName).strref()}))
          continue;

        return true;
      } else if (hasPointerType(ldTy)) {
        if (optimize && !storeHasPointerType &&
            !valueMayHaveFirAttributes(
                addr,
                {getTargetAttrName(),
                 fir::GlobalOp::getTargetAttrName(globalOpName).strref()}))
          continue;

        return true;
      }
      // TODO: Check if types can also allow ruling out some cases. For now,
      // the fact that equivalences are using the pointer attribute to enforce
      // aliasing prevents any attempt to do so, and in general, it may be
      // wrong to use this if any of the types is a complex or a derived type
      // for which it is possible to create a pointer to a part with a
      // different type than the whole, although this deserves some more
      // investigation because existing compiler behavior seems to diverge
      // here.
    }
  return false;
}
/// Is there an access vector conflict on the array being merged into? If the
/// access vectors diverge, then assume that there are potentially overlapping
/// loop-carried references.
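/// For example (hypothetical Fortran, illustration only): lowering
/// `forall (i = 2:n) a(i) = a(i-1)` fetches the loaded array value at index
/// i-1 while updating it at index i; the diverging access vectors force the
/// conservative assumption of a loop-carried dependence.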
static bool conflictOnMerge(llvm::ArrayRef<mlir::Operation *> mentions) {
  if (mentions.size() < 2)
    return false;
  llvm::SmallVector<mlir::Value> indices;
  LLVM_DEBUG(llvm::dbgs() << "check merge conflict on " << mentions.size()
                          << " mentions on the list\n");
  bool valSeen = false;
  bool refSeen = false;
  for (auto *op : mentions) {
    llvm::SmallVector<mlir::Value> compareVector;
    if (auto u = mlir::dyn_cast<ArrayUpdateOp>(op)) {
      valSeen = true;
      if (indices.empty()) {
        indices = u.getIndices();
        continue;
      }
      compareVector = u.getIndices();
    } else if (auto f = mlir::dyn_cast<ArrayModifyOp>(op)) {
      valSeen = true;
      if (indices.empty()) {
        indices = f.getIndices();
        continue;
      }
      compareVector = f.getIndices();
    } else if (auto f = mlir::dyn_cast<ArrayFetchOp>(op)) {
      valSeen = true;
      if (indices.empty()) {
        indices = f.getIndices();
        continue;
      }
      compareVector = f.getIndices();
    } else if (auto f = mlir::dyn_cast<ArrayAccessOp>(op)) {
      refSeen = true;
      if (indices.empty()) {
        indices = f.getIndices();
        continue;
      }
      compareVector = f.getIndices();
    } else if (mlir::isa<ArrayAmendOp>(op)) {
      refSeen = true;
      continue;
    } else {
      mlir::emitError(op->getLoc(), "unexpected operation in analysis");
    }
    if (compareVector.size() != indices.size() ||
        llvm::any_of(llvm::zip(compareVector, indices), [&](auto pair) {
          return std::get<0>(pair) != std::get<1>(pair);
        }))
      return true;
    LLVM_DEBUG(llvm::dbgs() << "vectors compare equal\n");
  }
  return valSeen && refSeen;
}
/// With element-by-reference semantics, an amended array with more than one
/// access to the same loaded array is conservatively considered a conflict.
/// Note: the array copy can still be eliminated in subsequent optimizations.
static bool conflictOnReference(llvm::ArrayRef<mlir::Operation *> mentions) {
  LLVM_DEBUG(llvm::dbgs() << "checking reference semantics " << mentions.size()
                          << '\n');
  if (mentions.size() < 3)
    return false;
  unsigned amendCount = 0;
  unsigned accessCount = 0;
  for (auto *op : mentions) {
    if (mlir::isa<ArrayAmendOp>(op) && ++amendCount > 1) {
      LLVM_DEBUG(llvm::dbgs() << "conflict: multiple amends of array value\n");
      return true;
    }
    if (mlir::isa<ArrayAccessOp>(op) && ++accessCount > 1) {
      LLVM_DEBUG(llvm::dbgs()
                 << "conflict: multiple accesses of array value\n");
      return true;
    }
    if (mlir::isa<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>(op)) {
      LLVM_DEBUG(llvm::dbgs()
                 << "conflict: array value has both uses by-value and uses "
                    "by-reference. conservative assumption.\n");
      return true;
    }
  }
  return false;
}
static mlir::Operation *
amendingAccess(llvm::ArrayRef<mlir::Operation *> mentions) {
  for (auto *op : mentions)
    if (auto amend = mlir::dyn_cast<ArrayAmendOp>(op))
      return amend.getMemref().getDefiningOp();
  return {};
}
// Are any conflicts present? The conflicts detected here are described above.
static bool conflictDetected(llvm::ArrayRef<mlir::Operation *> reach,
                             llvm::ArrayRef<mlir::Operation *> mentions,
                             ArrayMergeStoreOp st, bool optimize) {
  return conflictOnLoad(reach, st, optimize) || conflictOnMerge(mentions);
}
// Assume that any call to a function that uses host-associations will be
// modifying the output array.
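// For example (hypothetical Fortran): if `f` is an internal procedure that
// host-associates the array `a`, then in `a = f(x)` the call may read or
// write `a` directly, so a temporary copy is conservatively required.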
static bool
conservativeCallConflict(llvm::ArrayRef<mlir::Operation *> reaches) {
  return llvm::any_of(reaches, [](mlir::Operation *op) {
    if (auto call = mlir::dyn_cast<fir::CallOp>(op))
      if (auto callee = mlir::dyn_cast<mlir::SymbolRefAttr>(
              call.getCallableForCallee())) {
        auto module = op->getParentOfType<mlir::ModuleOp>();
        return isInternalProcedure(
            module.lookupSymbol<mlir::func::FuncOp>(callee));
      }
    return false;
  });
}
/// Constructor of the array copy analysis.
/// This performs the analysis and saves the intermediate results.
void ArrayCopyAnalysisBase::construct(mlir::Operation *topLevelOp) {
  topLevelOp->walk([&](Operation *op) {
    if (auto st = mlir::dyn_cast<fir::ArrayMergeStoreOp>(op)) {
      llvm::SmallVector<mlir::Operation *> values;
      ReachCollector::reachingValues(values, st.getSequence());
      bool callConflict = conservativeCallConflict(values);
      llvm::SmallVector<mlir::Operation *> mentions;
      arrayMentions(mentions,
                    mlir::cast<ArrayLoadOp>(st.getOriginal().getDefiningOp()));
      bool conflict = conflictDetected(values, mentions, st, optimizeConflicts);
      bool refConflict = conflictOnReference(mentions);
      if (callConflict || conflict || refConflict) {
        LLVM_DEBUG(llvm::dbgs()
                   << "CONFLICT: copies required for " << st << '\n'
                   << " adding conflicts on: " << *op << " and "
                   << st.getOriginal() << '\n');
        conflicts.insert(op);
        conflicts.insert(st.getOriginal().getDefiningOp());
        if (auto *access = amendingAccess(mentions))
          amendAccesses.insert(access);
      }
      auto *ld = st.getOriginal().getDefiningOp();
      LLVM_DEBUG(llvm::dbgs()
                 << "map: adding {" << *ld << " -> " << st << "}\n");
      useMap.insert({ld, op});
    } else if (auto load = mlir::dyn_cast<ArrayLoadOp>(op)) {
      llvm::SmallVector<mlir::Operation *> mentions;
      arrayMentions(mentions, load);
      LLVM_DEBUG(llvm::dbgs() << "process load: " << load
                              << ", mentions: " << mentions.size() << '\n');
      for (auto *acc : mentions) {
        LLVM_DEBUG(llvm::dbgs() << " mention: " << *acc << '\n');
        if (mlir::isa<ArrayAccessOp, ArrayAmendOp, ArrayFetchOp, ArrayUpdateOp,
                      ArrayModifyOp>(acc)) {
          if (useMap.count(acc)) {
            mlir::emitError(
                load.getLoc(),
                "The parallel semantics of multiple array_merge_stores per "
                "array_load are not supported.");
            continue;
          }
          LLVM_DEBUG(llvm::dbgs()
                     << "map: adding {" << *acc << "} -> {" << load << "}\n");
          useMap.insert({acc, op});
        }
      }
    }
  });
}
//===----------------------------------------------------------------------===//
// Conversions for converting out of array value form.
//===----------------------------------------------------------------------===//

namespace {
class ArrayLoadConversion : public mlir::OpRewritePattern<ArrayLoadOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(ArrayLoadOp load,
                  mlir::PatternRewriter &rewriter) const override {
    LLVM_DEBUG(llvm::dbgs() << "replace load " << load << " with undef.\n");
    rewriter.replaceOpWithNewOp<UndefOp>(load, load.getType());
    return mlir::success();
  }
};

class ArrayMergeStoreConversion
    : public mlir::OpRewritePattern<ArrayMergeStoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(ArrayMergeStoreOp store,
                  mlir::PatternRewriter &rewriter) const override {
    LLVM_DEBUG(llvm::dbgs() << "marking store " << store << " as dead.\n");
    rewriter.eraseOp(store);
    return mlir::success();
  }
};
} // namespace
static mlir::Type getEleTy(mlir::Type ty) {
  auto eleTy = unwrapSequenceType(unwrapPassByRefType(ty));
  // FIXME: keep ptr/heap/ref information.
  return ReferenceType::get(eleTy);
}
// This is an unsafe way to deduce this (won't be true in internal
// procedure or inside select-rank for assumed-size). Only here to satisfy
// legacy code until removed.
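// For example (hypothetical Fortran): an assumed-size dummy declared
// `real :: a(2, *)` is lowered with -1 as its last extent, the sentinel
// value this predicate tests for.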
static bool isAssumedSize(llvm::SmallVectorImpl<mlir::Value> &extents) {
  if (extents.empty())
    return false;
  auto cstLen = fir::getIntIfConstant(extents.back());
  return cstLen.has_value() && *cstLen == -1;
}
// Extract extents from the ShapeOp/ShapeShiftOp into the result vector.
static bool getAdjustedExtents(mlir::Location loc,
                               mlir::PatternRewriter &rewriter,
                               ArrayLoadOp arrLoad,
                               llvm::SmallVectorImpl<mlir::Value> &result,
                               mlir::Value shape) {
  bool copyUsingSlice = false;
  auto *shapeOp = shape.getDefiningOp();
  if (auto s = mlir::dyn_cast_or_null<ShapeOp>(shapeOp)) {
    auto e = s.getExtents();
    result.insert(result.end(), e.begin(), e.end());
  } else if (auto s = mlir::dyn_cast_or_null<ShapeShiftOp>(shapeOp)) {
    auto e = s.getExtents();
    result.insert(result.end(), e.begin(), e.end());
  } else {
    emitFatalError(loc, "not a fir.shape/fir.shape_shift op");
  }
  auto idxTy = rewriter.getIndexType();
  if (isAssumedSize(result)) {
    // Use slice information to compute the extent of the column.
    auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
    mlir::Value size = one;
    if (mlir::Value sliceArg = arrLoad.getSlice()) {
      if (auto sliceOp =
              mlir::dyn_cast_or_null<SliceOp>(sliceArg.getDefiningOp())) {
        auto triples = sliceOp.getTriples();
        const std::size_t tripleSize = triples.size();
        auto module = arrLoad->getParentOfType<mlir::ModuleOp>();
        FirOpBuilder builder(rewriter, module);
        size = builder.genExtentFromTriplet(loc, triples[tripleSize - 3],
                                            triples[tripleSize - 2],
                                            triples[tripleSize - 1], idxTy);
        copyUsingSlice = true;
      }
    }
    result[result.size() - 1] = size;
  }
  return copyUsingSlice;
}
/// Place the extents of the array load, \p arrLoad, into \p result and
/// return a ShapeOp or ShapeShiftOp with the same extents. If \p arrLoad is
/// loading a `!fir.box`, code will be generated to read the extents from the
/// boxed value, and the returned shape Op will be built with the extents read
/// from the box. Otherwise, the extents will be extracted from the ShapeOp (or
/// ShapeShiftOp) argument of \p arrLoad. \p copyUsingSlice will be set to true
/// if slicing of the output array is to be done in the copy-in/copy-out rather
/// than in the elemental computation step.
static mlir::Value getOrReadExtentsAndShapeOp(
    mlir::Location loc, mlir::PatternRewriter &rewriter, ArrayLoadOp arrLoad,
    llvm::SmallVectorImpl<mlir::Value> &result, bool &copyUsingSlice) {
  assert(result.empty());
  if (arrLoad->hasAttr(fir::getOptionalAttrName()))
    emitFatalError(
        loc, "shapes from array load of OPTIONAL arrays must not be used");
  if (auto boxTy = mlir::dyn_cast<BoxType>(arrLoad.getMemref().getType())) {
    auto rank =
        mlir::cast<SequenceType>(dyn_cast_ptrOrBoxEleTy(boxTy)).getDimension();
    auto idxTy = rewriter.getIndexType();
    for (decltype(rank) dim = 0; dim < rank; ++dim) {
      auto dimVal = rewriter.create<mlir::arith::ConstantIndexOp>(loc, dim);
      auto dimInfo = rewriter.create<BoxDimsOp>(loc, idxTy, idxTy, idxTy,
                                                arrLoad.getMemref(), dimVal);
      result.emplace_back(dimInfo.getResult(1));
    }
    if (!arrLoad.getShape()) {
      auto shapeType = ShapeType::get(rewriter.getContext(), rank);
      return rewriter.create<ShapeOp>(loc, shapeType, result);
    }
    auto shiftOp = arrLoad.getShape().getDefiningOp<ShiftOp>();
    auto shapeShiftType = ShapeShiftType::get(rewriter.getContext(), rank);
    llvm::SmallVector<mlir::Value> shapeShiftOperands;
    for (auto [lb, extent] : llvm::zip(shiftOp.getOrigins(), result)) {
      shapeShiftOperands.push_back(lb);
      shapeShiftOperands.push_back(extent);
    }
    return rewriter.create<ShapeShiftOp>(loc, shapeShiftType,
                                         shapeShiftOperands);
  }
  copyUsingSlice =
      getAdjustedExtents(loc, rewriter, arrLoad, result, arrLoad.getShape());
  return arrLoad.getShape();
}
static mlir::Type toRefType(mlir::Type ty) {
  if (fir::isa_ref_type(ty))
    return ty;
  return fir::ReferenceType::get(ty);
}
static llvm::SmallVector<mlir::Value>
getTypeParamsIfRawData(mlir::Location loc, FirOpBuilder &builder,
                       ArrayLoadOp arrLoad, mlir::Type ty) {
  if (mlir::isa<BoxType>(ty))
    return {};
  return fir::factory::getTypeParams(loc, builder, arrLoad);
}
static mlir::Value genCoorOp(mlir::PatternRewriter &rewriter,
                             mlir::Location loc, mlir::Type eleTy,
                             mlir::Type resTy, mlir::Value alloc,
                             mlir::Value shape, mlir::Value slice,
                             mlir::ValueRange indices, ArrayLoadOp load,
                             bool skipOrig = false) {
  llvm::SmallVector<mlir::Value> originated;
  if (skipOrig)
    originated.assign(indices.begin(), indices.end());
  else
    originated = factory::originateIndices(loc, rewriter, alloc.getType(),
                                           shape, indices);
  auto seqTy = dyn_cast_ptrOrBoxEleTy(alloc.getType());
  assert(seqTy && mlir::isa<SequenceType>(seqTy));
  const auto dimension = mlir::cast<SequenceType>(seqTy).getDimension();
  auto module = load->getParentOfType<mlir::ModuleOp>();
  FirOpBuilder builder(rewriter, module);
  auto typeparams = getTypeParamsIfRawData(loc, builder, load, alloc.getType());
  mlir::Value result = rewriter.create<ArrayCoorOp>(
      loc, eleTy, alloc, shape, slice,
      llvm::ArrayRef<mlir::Value>{originated}.take_front(dimension),
      typeparams);
  if (dimension < originated.size())
    result = rewriter.create<fir::CoordinateOp>(
        loc, resTy, result,
        llvm::ArrayRef<mlir::Value>{originated}.drop_front(dimension));
  return result;
}
static mlir::Value getCharacterLen(mlir::Location loc, FirOpBuilder &builder,
                                   ArrayLoadOp load, CharacterType charTy) {
  auto charLenTy = builder.getCharacterLengthType();
  if (charTy.hasDynamicLen()) {
    if (mlir::isa<BoxType>(load.getMemref().getType())) {
      // The loaded array is an emboxed value. Get the CHARACTER length from
      // the box value.
      auto eleSzInBytes =
          builder.create<BoxEleSizeOp>(loc, charLenTy, load.getMemref());
      auto kindSize =
          builder.getKindMap().getCharacterBitsize(charTy.getFKind());
      auto kindByteSize =
          builder.createIntegerConstant(loc, charLenTy, kindSize / 8);
      return builder.create<mlir::arith::DivSIOp>(loc, eleSzInBytes,
                                                  kindByteSize);
    }
    // The loaded array is a (set of) unboxed values. If the CHARACTER's
    // length is not a constant, it must be provided as a type parameter to
    // the array_load.
    auto typeparams = load.getTypeparams();
    assert(typeparams.size() > 0 && "expected type parameters on array_load");
    return typeparams.back();
  }
  // The typical case: the length of the CHARACTER is a compile-time
  // constant that is encoded in the type information.
  return builder.createIntegerConstant(loc, charLenTy, charTy.getLen());
}
/// Generate a shallow array copy. This is used for both copy-in and copy-out.
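/// As a sketch (illustrative, not verbatim output), for a rank-2 array the
/// generated copy has the form:
///   fir.do_loop %j = %c0 to %extent2 - 1 {
///     fir.do_loop %i = %c0 to %extent1 - 1 {
///       // fir.array_coor addresses of src and dst at (%i, %j),
///       // then a scalar element assignment
///     }
///   }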
template <bool CopyIn>
void genArrayCopy(mlir::Location loc, mlir::PatternRewriter &rewriter,
                  mlir::Value dst, mlir::Value src, mlir::Value shapeOp,
                  mlir::Value sliceOp, ArrayLoadOp arrLoad) {
  auto insPt = rewriter.saveInsertionPoint();
  llvm::SmallVector<mlir::Value> indices;
  llvm::SmallVector<mlir::Value> extents;
  bool copyUsingSlice =
      getAdjustedExtents(loc, rewriter, arrLoad, extents, shapeOp);
  auto idxTy = rewriter.getIndexType();
  // Build loop nest from column to row.
  for (auto sh : llvm::reverse(extents)) {
    auto ubi = rewriter.create<ConvertOp>(loc, idxTy, sh);
    auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
    auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
    auto ub = rewriter.create<mlir::arith::SubIOp>(loc, idxTy, ubi, one);
    auto loop = rewriter.create<DoLoopOp>(loc, zero, ub, one);
    rewriter.setInsertionPointToStart(loop.getBody());
    indices.push_back(loop.getInductionVar());
  }
  // Reverse the indices so they are in column-major order.
  std::reverse(indices.begin(), indices.end());
  auto module = arrLoad->getParentOfType<mlir::ModuleOp>();
  FirOpBuilder builder(rewriter, module);
  auto fromAddr = rewriter.create<ArrayCoorOp>(
      loc, getEleTy(src.getType()), src, shapeOp,
      CopyIn && copyUsingSlice ? sliceOp : mlir::Value{},
      factory::originateIndices(loc, rewriter, src.getType(), shapeOp, indices),
      getTypeParamsIfRawData(loc, builder, arrLoad, src.getType()));
  auto toAddr = rewriter.create<ArrayCoorOp>(
      loc, getEleTy(dst.getType()), dst, shapeOp,
      !CopyIn && copyUsingSlice ? sliceOp : mlir::Value{},
      factory::originateIndices(loc, rewriter, dst.getType(), shapeOp, indices),
      getTypeParamsIfRawData(loc, builder, arrLoad, dst.getType()));
  auto eleTy = unwrapSequenceType(unwrapPassByRefType(dst.getType()));
  // Copy from (to) object to (from) temp copy of same object.
  if (auto charTy = mlir::dyn_cast<CharacterType>(eleTy)) {
    auto len = getCharacterLen(loc, builder, arrLoad, charTy);
    CharBoxValue toChar(toAddr, len);
    CharBoxValue fromChar(fromAddr, len);
    factory::genScalarAssignment(builder, loc, toChar, fromChar);
  } else {
    if (hasDynamicSize(eleTy))
      TODO(loc, "copy element of dynamic size");
    factory::genScalarAssignment(builder, loc, toAddr, fromAddr);
  }
  rewriter.restoreInsertionPoint(insPt);
}
/// The array load may be either a boxed or unboxed value. If the value is
/// boxed, we read the type parameters from the boxed value.
static llvm::SmallVector<mlir::Value>
genArrayLoadTypeParameters(mlir::Location loc, mlir::PatternRewriter &rewriter,
                           ArrayLoadOp load) {
  if (load.getTypeparams().empty()) {
    auto eleTy =
        unwrapSequenceType(unwrapPassByRefType(load.getMemref().getType()));
    if (hasDynamicSize(eleTy)) {
      if (auto charTy = mlir::dyn_cast<CharacterType>(eleTy)) {
        assert(mlir::isa<BoxType>(load.getMemref().getType()));
        auto module = load->getParentOfType<mlir::ModuleOp>();
        FirOpBuilder builder(rewriter, module);
        return {getCharacterLen(loc, builder, load, charTy)};
      }
      TODO(loc, "unhandled dynamic type parameters");
    }
    return {};
  }
  return load.getTypeparams();
}
static llvm::SmallVector<mlir::Value>
findNonconstantExtents(mlir::Type memrefTy,
                       llvm::ArrayRef<mlir::Value> extents) {
  llvm::SmallVector<mlir::Value> nce;
  auto arrTy = unwrapPassByRefType(memrefTy);
  auto seqTy = mlir::cast<SequenceType>(arrTy);
  for (auto [s, x] : llvm::zip(seqTy.getShape(), extents))
    if (s == SequenceType::getUnknownExtent())
      nce.emplace_back(x);
  if (extents.size() > seqTy.getShape().size())
    for (auto x : extents.drop_front(seqTy.getShape().size()))
      nce.emplace_back(x);
  return nce;
}
/// Allocate temporary storage for an ArrayLoadOp \p load and initialize any
/// allocatable direct components of the array elements with an unallocated
/// status. Returns the temporary address as well as a callback to generate the
/// temporary clean-up once it has been used. The clean-up will take care of
/// deallocating all the element allocatable components that may have been
/// allocated while using the temporary.
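/// Typical use (as in referenceToClone and materializeAssignment below):
///   auto [allocmem, genTempCleanUp] =
///       allocateArrayTemp(loc, rewriter, load, extents, shapeOp);
///   // ... generate copy-in/copy-out through allocmem ...
///   genTempCleanUp(rewriter);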
static std::pair<mlir::Value,
                 std::function<void(mlir::PatternRewriter &rewriter)>>
allocateArrayTemp(mlir::Location loc, mlir::PatternRewriter &rewriter,
                  ArrayLoadOp load, llvm::ArrayRef<mlir::Value> extents,
                  mlir::Value shape) {
  mlir::Type baseType = load.getMemref().getType();
  llvm::SmallVector<mlir::Value> nonconstantExtents =
      findNonconstantExtents(baseType, extents);
  llvm::SmallVector<mlir::Value> typeParams =
      genArrayLoadTypeParameters(loc, rewriter, load);
  mlir::Value allocmem = rewriter.create<AllocMemOp>(
      loc, dyn_cast_ptrOrBoxEleTy(baseType), typeParams, nonconstantExtents);
  mlir::Type eleType =
      fir::unwrapSequenceType(fir::unwrapPassByRefType(baseType));
  if (fir::isRecordWithAllocatableMember(eleType)) {
    // The allocatable component descriptors need to be set to a clean
    // deallocated status before anything is done with them.
    mlir::Value box = rewriter.create<fir::EmboxOp>(
        loc, fir::BoxType::get(allocmem.getType()), allocmem, shape,
        /*slice=*/mlir::Value{}, typeParams);
    auto module = load->getParentOfType<mlir::ModuleOp>();
    FirOpBuilder builder(rewriter, module);
    runtime::genDerivedTypeInitialize(builder, loc, box);
    // Any allocatable component that may have been allocated must be
    // deallocated during the clean-up.
    auto cleanup = [=](mlir::PatternRewriter &r) {
      FirOpBuilder builder(r, module);
      runtime::genDerivedTypeDestroy(builder, loc, box);
      r.create<FreeMemOp>(loc, allocmem);
    };
    return {allocmem, cleanup};
  }
  auto cleanup = [=](mlir::PatternRewriter &r) {
    r.create<FreeMemOp>(loc, allocmem);
  };
  return {allocmem, cleanup};
}
namespace {
/// Conversion of fir.array_update and fir.array_modify Ops.
/// If there is a conflict for the update, then we need to perform a
/// copy-in/copy-out to preserve the original values of the array. If there is
/// no conflict, then it is safe to eschew making any copies.
template <typename ArrayOp>
class ArrayUpdateConversionBase : public mlir::OpRewritePattern<ArrayOp> {
public:
  // TODO: Implement copy/swap semantics?
  explicit ArrayUpdateConversionBase(mlir::MLIRContext *ctx,
                                     const ArrayCopyAnalysisBase &a,
                                     const OperationUseMapT &m)
      : mlir::OpRewritePattern<ArrayOp>{ctx}, analysis{a}, useMap{m} {}

  /// The array_access, \p access, is to be made to a cloned copy due to a
  /// potential conflict. Uses copy-in/copy-out semantics and not copy/swap.
  mlir::Value referenceToClone(mlir::Location loc,
                               mlir::PatternRewriter &rewriter,
                               ArrayOp access) const {
    LLVM_DEBUG(llvm::dbgs()
               << "generating copy-in/copy-out loops for " << access << '\n');
    auto *op = access.getOperation();
    auto *loadOp = useMap.lookup(op);
    auto load = mlir::cast<ArrayLoadOp>(loadOp);
    auto eleTy = access.getType();
    rewriter.setInsertionPoint(loadOp);
    // Copy in.
    llvm::SmallVector<mlir::Value> extents;
    bool copyUsingSlice = false;
    auto shapeOp = getOrReadExtentsAndShapeOp(loc, rewriter, load, extents,
                                              copyUsingSlice);
    auto [allocmem, genTempCleanUp] =
        allocateArrayTemp(loc, rewriter, load, extents, shapeOp);
    genArrayCopy</*copyIn=*/true>(load.getLoc(), rewriter, allocmem,
                                  load.getMemref(), shapeOp, load.getSlice(),
                                  load);
    // Generate the reference for the access.
    rewriter.setInsertionPoint(op);
    auto coor = genCoorOp(
        rewriter, loc, getEleTy(load.getType()), eleTy, allocmem, shapeOp,
        copyUsingSlice ? mlir::Value{} : load.getSlice(), access.getIndices(),
        load, access->hasAttr(factory::attrFortranArrayOffsets()));
    auto *storeOp = useMap.lookup(loadOp);
    auto store = mlir::cast<ArrayMergeStoreOp>(storeOp);
    rewriter.setInsertionPoint(storeOp);
    // Copy out.
    genArrayCopy</*copyIn=*/false>(store.getLoc(), rewriter, store.getMemref(),
                                   allocmem, shapeOp, store.getSlice(), load);
    genTempCleanUp(rewriter);
    return coor;
  }
  /// Copy the RHS element into the LHS and insert copy-in/copy-out between a
  /// temp and the LHS if the analysis found potential overlaps between the RHS
  /// and LHS arrays. The element copy generator must be provided in \p
  /// assignElement. \p update must be the ArrayUpdateOp or the ArrayModifyOp.
  /// Returns the address of the LHS element inside the loop and the LHS
  /// ArrayLoad result.
  std::pair<mlir::Value, mlir::Value>
  materializeAssignment(mlir::Location loc, mlir::PatternRewriter &rewriter,
                        ArrayOp update,
                        const std::function<void(mlir::Value)> &assignElement,
                        mlir::Type lhsEltRefType) const {
    auto *op = update.getOperation();
    auto *loadOp = useMap.lookup(op);
    auto load = mlir::cast<ArrayLoadOp>(loadOp);
    LLVM_DEBUG(llvm::outs() << "does " << load << " have a conflict?\n");
    if (analysis.hasPotentialConflict(loadOp)) {
      // If there is a conflict between the arrays, then we copy the lhs array
      // to a temporary, update the temporary, and copy the temporary back to
      // the lhs array. This yields Fortran's copy-in copy-out array semantics.
      LLVM_DEBUG(llvm::outs() << "Yes, conflict was found\n");
      rewriter.setInsertionPoint(loadOp);
      // Copy in.
      llvm::SmallVector<mlir::Value> extents;
      bool copyUsingSlice = false;
      auto shapeOp = getOrReadExtentsAndShapeOp(loc, rewriter, load, extents,
                                                copyUsingSlice);
      auto [allocmem, genTempCleanUp] =
          allocateArrayTemp(loc, rewriter, load, extents, shapeOp);
      genArrayCopy</*copyIn=*/true>(load.getLoc(), rewriter, allocmem,
                                    load.getMemref(), shapeOp, load.getSlice(),
                                    load);
      rewriter.setInsertionPoint(op);
      auto coor = genCoorOp(
          rewriter, loc, getEleTy(load.getType()), lhsEltRefType, allocmem,
          shapeOp, copyUsingSlice ? mlir::Value{} : load.getSlice(),
          update.getIndices(), load,
          update->hasAttr(factory::attrFortranArrayOffsets()));
      assignElement(coor);
      auto *storeOp = useMap.lookup(loadOp);
      auto store = mlir::cast<ArrayMergeStoreOp>(storeOp);
      rewriter.setInsertionPoint(storeOp);
      // Copy out.
      genArrayCopy</*copyIn=*/false>(store.getLoc(), rewriter,
                                     store.getMemref(), allocmem, shapeOp,
                                     store.getSlice(), load);
      genTempCleanUp(rewriter);
      return {coor, load.getResult()};
    }
    // Otherwise, when there is no conflict (a possible loop-carried
    // dependence), the lhs array can be updated in place.
    LLVM_DEBUG(llvm::outs() << "No, conflict wasn't found\n");
    rewriter.setInsertionPoint(op);
    auto coorTy = getEleTy(load.getType());
    auto coor =
        genCoorOp(rewriter, loc, coorTy, lhsEltRefType, load.getMemref(),
                  load.getShape(), load.getSlice(), update.getIndices(), load,
                  update->hasAttr(factory::attrFortranArrayOffsets()));
    assignElement(coor);
    return {coor, load.getResult()};
  }

protected:
  const ArrayCopyAnalysisBase &analysis;
  const OperationUseMapT &useMap;
};
class ArrayUpdateConversion : public ArrayUpdateConversionBase<ArrayUpdateOp> {
public:
  explicit ArrayUpdateConversion(mlir::MLIRContext *ctx,
                                 const ArrayCopyAnalysisBase &a,
                                 const OperationUseMapT &m)
      : ArrayUpdateConversionBase{ctx, a, m} {}

  mlir::LogicalResult
  matchAndRewrite(ArrayUpdateOp update,
                  mlir::PatternRewriter &rewriter) const override {
    auto loc = update.getLoc();
    auto assignElement = [&](mlir::Value coor) {
      auto input = update.getMerge();
      if (auto inEleTy = dyn_cast_ptrEleTy(input.getType())) {
        emitFatalError(loc, "array_update on references not supported");
      } else {
        rewriter.create<fir::StoreOp>(loc, input, coor);
      }
    };
    auto lhsEltRefType = toRefType(update.getMerge().getType());
    auto [_, lhsLoadResult] = materializeAssignment(
        loc, rewriter, update, assignElement, lhsEltRefType);
    update.replaceAllUsesWith(lhsLoadResult);
    rewriter.replaceOp(update, lhsLoadResult);
    return mlir::success();
  }
};
class ArrayModifyConversion : public ArrayUpdateConversionBase<ArrayModifyOp> {
public:
  explicit ArrayModifyConversion(mlir::MLIRContext *ctx,
                                 const ArrayCopyAnalysisBase &a,
                                 const OperationUseMapT &m)
      : ArrayUpdateConversionBase{ctx, a, m} {}

  mlir::LogicalResult
  matchAndRewrite(ArrayModifyOp modify,
                  mlir::PatternRewriter &rewriter) const override {
    auto loc = modify.getLoc();
    auto assignElement = [](mlir::Value) {
      // Assignment already materialized by lowering using lhs element address.
    };
    auto lhsEltRefType = modify.getResult(0).getType();
    auto [lhsEltCoor, lhsLoadResult] = materializeAssignment(
        loc, rewriter, modify, assignElement, lhsEltRefType);
    modify.replaceAllUsesWith(mlir::ValueRange{lhsEltCoor, lhsLoadResult});
    rewriter.replaceOp(modify, mlir::ValueRange{lhsEltCoor, lhsLoadResult});
    return mlir::success();
  }
};
class ArrayFetchConversion : public mlir::OpRewritePattern<ArrayFetchOp> {
public:
  explicit ArrayFetchConversion(mlir::MLIRContext *ctx,
                                const OperationUseMapT &m)
      : OpRewritePattern{ctx}, useMap{m} {}

  mlir::LogicalResult
  matchAndRewrite(ArrayFetchOp fetch,
                  mlir::PatternRewriter &rewriter) const override {
    auto *op = fetch.getOperation();
    rewriter.setInsertionPoint(op);
    auto load = mlir::cast<ArrayLoadOp>(useMap.lookup(op));
    auto loc = fetch.getLoc();
    auto coor = genCoorOp(
        rewriter, loc, getEleTy(load.getType()), toRefType(fetch.getType()),
        load.getMemref(), load.getShape(), load.getSlice(), fetch.getIndices(),
        load, fetch->hasAttr(factory::attrFortranArrayOffsets()));
    if (isa_ref_type(fetch.getType()))
      rewriter.replaceOp(fetch, coor);
    else
      rewriter.replaceOpWithNewOp<fir::LoadOp>(fetch, coor);
    return mlir::success();
  }

private:
  const OperationUseMapT &useMap;
};
/// An array_access op is like an array_fetch op, except that it does not imply
/// a load op. (It operates in the reference domain.)
class ArrayAccessConversion : public ArrayUpdateConversionBase<ArrayAccessOp> {
public:
  explicit ArrayAccessConversion(mlir::MLIRContext *ctx,
                                 const ArrayCopyAnalysisBase &a,
                                 const OperationUseMapT &m)
      : ArrayUpdateConversionBase{ctx, a, m} {}

  mlir::LogicalResult
  matchAndRewrite(ArrayAccessOp access,
                  mlir::PatternRewriter &rewriter) const override {
    auto *op = access.getOperation();
    auto loc = access.getLoc();
    if (analysis.inAmendAccessSet(op)) {
      // This array_access is associated with an array_amend and there is a
      // conflict. Make a copy to store into.
      auto result = referenceToClone(loc, rewriter, access);
      access.replaceAllUsesWith(result);
      rewriter.replaceOp(access, result);
      return mlir::success();
    }
    rewriter.setInsertionPoint(op);
    auto load = mlir::cast<ArrayLoadOp>(useMap.lookup(op));
    auto coor = genCoorOp(
        rewriter, loc, getEleTy(load.getType()), toRefType(access.getType()),
        load.getMemref(), load.getShape(), load.getSlice(), access.getIndices(),
        load, access->hasAttr(factory::attrFortranArrayOffsets()));
    rewriter.replaceOp(access, coor);
    return mlir::success();
  }
};
/// An array_amend op is a marker to record which array access is being used to
/// update an array value. After this pass runs, an array_amend has no
/// semantics. We rewrite these to undefined values here to remove them while
/// preserving SSA form.
class ArrayAmendConversion : public mlir::OpRewritePattern<ArrayAmendOp> {
public:
  explicit ArrayAmendConversion(mlir::MLIRContext *ctx)
      : OpRewritePattern{ctx} {}

  mlir::LogicalResult
  matchAndRewrite(ArrayAmendOp amend,
                  mlir::PatternRewriter &rewriter) const override {
    auto *op = amend.getOperation();
    rewriter.setInsertionPoint(op);
    auto loc = amend.getLoc();
    auto undef = rewriter.create<UndefOp>(loc, amend.getType());
    rewriter.replaceOp(amend, undef.getResult());
    return mlir::success();
  }
};
class ArrayValueCopyConverter
    : public fir::impl::ArrayValueCopyBase<ArrayValueCopyConverter> {
public:
  ArrayValueCopyConverter() = default;
  ArrayValueCopyConverter(const fir::ArrayValueCopyOptions &options)
      : Base(options) {}

  void runOnOperation() override {
    auto func = getOperation();
    LLVM_DEBUG(llvm::dbgs() << "\n\narray-value-copy pass on function '"
                            << func.getName() << "'\n");
    auto *context = &getContext();

    // Perform the conflict analysis.
    const ArrayCopyAnalysisBase *analysis;
    if (optimizeConflicts)
      analysis = &getAnalysis<ArrayCopyAnalysisOptimized>();
    else
      analysis = &getAnalysis<ArrayCopyAnalysis>();
    const auto &useMap = analysis->getUseMap();

    mlir::RewritePatternSet patterns1(context);
    patterns1.insert<ArrayFetchConversion>(context, useMap);
    patterns1.insert<ArrayUpdateConversion>(context, *analysis, useMap);
    patterns1.insert<ArrayModifyConversion>(context, *analysis, useMap);
    patterns1.insert<ArrayAccessConversion>(context, *analysis, useMap);
    patterns1.insert<ArrayAmendConversion>(context);
    mlir::ConversionTarget target(*context);
    target
        .addLegalDialect<FIROpsDialect, mlir::scf::SCFDialect,
                         mlir::arith::ArithDialect, mlir::func::FuncDialect>();
    target.addIllegalOp<ArrayAccessOp, ArrayAmendOp, ArrayFetchOp,
                        ArrayUpdateOp, ArrayModifyOp>();
    // Rewrite the array fetch and array update ops.
    if (mlir::failed(
            mlir::applyPartialConversion(func, target, std::move(patterns1)))) {
      mlir::emitError(mlir::UnknownLoc::get(context),
                      "failure in array-value-copy pass, phase 1");
      signalPassFailure();
      return;
    }

    mlir::RewritePatternSet patterns2(context);
    patterns2.insert<ArrayLoadConversion>(context);
    patterns2.insert<ArrayMergeStoreConversion>(context);
    target.addIllegalOp<ArrayLoadOp, ArrayMergeStoreOp>();
    if (mlir::failed(
            mlir::applyPartialConversion(func, target, std::move(patterns2)))) {
      mlir::emitError(mlir::UnknownLoc::get(context),
                      "failure in array-value-copy pass, phase 2");
      signalPassFailure();
      return;
    }
  }
};
} // namespace

std::unique_ptr<mlir::Pass>
fir::createArrayValueCopyPass(fir::ArrayValueCopyOptions options) {
  return std::make_unique<ArrayValueCopyConverter>(options);
}