[RISCV] Change func to funct in RISCVInstrInfoXqci.td. NFC (#119669)
[llvm-project.git] / flang / lib / Optimizer / Transforms / ArrayValueCopy.cpp
blob8544d17f62248dfb4d71927a14b44f6380d73a64
1 //===-- ArrayValueCopy.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "flang/Optimizer/Builder/BoxValue.h"
10 #include "flang/Optimizer/Builder/FIRBuilder.h"
11 #include "flang/Optimizer/Builder/Factory.h"
12 #include "flang/Optimizer/Builder/Runtime/Derived.h"
13 #include "flang/Optimizer/Builder/Todo.h"
14 #include "flang/Optimizer/Dialect/FIRDialect.h"
15 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
16 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
17 #include "flang/Optimizer/Transforms/Passes.h"
18 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
19 #include "mlir/Dialect/SCF/IR/SCF.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "llvm/Support/Debug.h"
23 namespace fir {
24 #define GEN_PASS_DEF_ARRAYVALUECOPY
25 #include "flang/Optimizer/Transforms/Passes.h.inc"
26 } // namespace fir
28 #define DEBUG_TYPE "flang-array-value-copy"
30 using namespace fir;
31 using namespace mlir;
33 using OperationUseMapT = llvm::DenseMap<mlir::Operation *, mlir::Operation *>;
35 namespace {
37 /// Array copy analysis.
38 /// Perform an interference analysis between array values.
39 ///
40 /// Lowering will generate a sequence of the following form.
41 /// ```mlir
42 /// %a_1 = fir.array_load %array_1(%shape) : ...
43 /// ...
44 /// %a_j = fir.array_load %array_j(%shape) : ...
45 /// ...
46 /// %a_n = fir.array_load %array_n(%shape) : ...
47 /// ...
48 /// %v_i = fir.array_fetch %a_i, ...
49 /// %a_j1 = fir.array_update %a_j, ...
50 /// ...
51 /// fir.array_merge_store %a_j, %a_jn to %array_j : ...
52 /// ```
53 ///
54 /// The analysis is to determine if there are any conflicts. A conflict is when
55 /// one of the following cases occurs.
56 ///
57 /// 1. There is an `array_update` to an array value, a_j, such that a_j was
58 /// loaded from the same array memory reference (array_j) but with a different
59 /// shape as the other array values a_i, where i != j. [Possible overlapping
60 /// arrays.]
61 ///
62 /// 2. There is either an array_fetch or array_update of a_j with a different
63 /// set of index values. [Possible loop-carried dependence.]
64 ///
65 /// If none of the array values overlap in storage and the accesses are not
66 /// loop-carried, then the arrays are conflict-free and no copies are required.
67 class ArrayCopyAnalysisBase {
68 public:
69 using ConflictSetT = llvm::SmallPtrSet<mlir::Operation *, 16>;
70 using UseSetT = llvm::SmallPtrSet<mlir::OpOperand *, 8>;
71 using LoadMapSetsT = llvm::DenseMap<mlir::Operation *, UseSetT>;
72 using AmendAccessSetT = llvm::SmallPtrSet<mlir::Operation *, 4>;
74 ArrayCopyAnalysisBase(mlir::Operation *op, bool optimized)
75 : operation{op}, optimizeConflicts(optimized) {
76 construct(op);
78 virtual ~ArrayCopyAnalysisBase() = default;
80 mlir::Operation *getOperation() const { return operation; }
82 /// Return true iff the `array_merge_store` has potential conflicts.
83 bool hasPotentialConflict(mlir::Operation *op) const {
84 LLVM_DEBUG(llvm::dbgs()
85 << "looking for a conflict on " << *op
86 << " and the set has a total of " << conflicts.size() << '\n');
87 return conflicts.contains(op);
90 /// Return the use map.
91 /// The use map maps array access, amend, fetch and update operations back to
92 /// the array load that is the original source of the array value.
93 /// It maps an array_load to an array_merge_store, if and only if the loaded
94 /// array value has pending modifications to be merged.
95 const OperationUseMapT &getUseMap() const { return useMap; }
97 /// Return the set of array_access ops directly associated with array_amend
98 /// ops.
99 bool inAmendAccessSet(mlir::Operation *op) const {
100 return amendAccesses.count(op);
103 /// For ArrayLoad `load`, return the transitive set of all OpOperands.
104 UseSetT getLoadUseSet(mlir::Operation *load) const {
105 assert(loadMapSets.count(load) && "analysis missed an array load?");
106 return loadMapSets.lookup(load);
109 void arrayMentions(llvm::SmallVectorImpl<mlir::Operation *> &mentions,
110 ArrayLoadOp load);
112 private:
113 void construct(mlir::Operation *topLevelOp);
115 mlir::Operation *operation; // operation that analysis ran upon
116 ConflictSetT conflicts; // set of conflicts (loads and merge stores)
117 OperationUseMapT useMap;
118 LoadMapSetsT loadMapSets;
119 // Set of array_access ops associated with array_amend ops.
120 AmendAccessSetT amendAccesses;
121 bool optimizeConflicts;
124 // Optimized array copy analysis that takes into account Fortran
125 // variable attributes to prove that no conflict is possible
126 // and reduce the number of temporary arrays.
127 class ArrayCopyAnalysisOptimized : public ArrayCopyAnalysisBase {
128 public:
129 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ArrayCopyAnalysisOptimized)
131 ArrayCopyAnalysisOptimized(mlir::Operation *op)
132 : ArrayCopyAnalysisBase(op, /*optimized=*/true) {}
135 // Unoptimized array copy analysis used at O0.
136 class ArrayCopyAnalysis : public ArrayCopyAnalysisBase {
137 public:
138 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ArrayCopyAnalysis)
140 ArrayCopyAnalysis(mlir::Operation *op)
141 : ArrayCopyAnalysisBase(op, /*optimized=*/false) {}
143 } // namespace
145 namespace {
146 /// Helper class to collect all array operations that produced an array value.
147 class ReachCollector {
148 public:
149 ReachCollector(llvm::SmallVectorImpl<mlir::Operation *> &reach,
150 mlir::Region *loopRegion)
151 : reach{reach}, loopRegion{loopRegion} {}
153 void collectArrayMentionFrom(mlir::Operation *op, mlir::ValueRange range) {
154 if (range.empty()) {
155 collectArrayMentionFrom(op, mlir::Value{});
156 return;
158 for (mlir::Value v : range)
159 collectArrayMentionFrom(v);
162 // Collect all the array_access ops in `block`. This recursively looks into
163 // blocks in ops with regions.
164 // FIXME: This is temporarily relying on the array_amend appearing in a
165 // do_loop Region. This phase ordering assumption can be eliminated by using
166 // dominance information to find the array_access ops or by scanning the
167 // transitive closure of the amending array_access's users and the defs that
168 // reach them.
169 void collectAccesses(llvm::SmallVector<ArrayAccessOp> &result,
170 mlir::Block *block) {
171 for (auto &op : *block) {
172 if (auto access = mlir::dyn_cast<ArrayAccessOp>(op)) {
173 LLVM_DEBUG(llvm::dbgs() << "adding access: " << access << '\n');
174 result.push_back(access);
175 continue;
177 for (auto &region : op.getRegions())
178 for (auto &bb : region.getBlocks())
179 collectAccesses(result, &bb);
183 void collectArrayMentionFrom(mlir::Operation *op, mlir::Value val) {
184 // `val` is defined by an Op, process the defining Op.
185 // If `val` is defined by a region containing Op, we want to drill down
186 // and through that Op's region(s).
187 LLVM_DEBUG(llvm::dbgs() << "popset: " << *op << '\n');
188 auto popFn = [&](auto rop) {
189 assert(val && "op must have a result value");
190 auto resNum = mlir::cast<mlir::OpResult>(val).getResultNumber();
191 llvm::SmallVector<mlir::Value> results;
192 rop.resultToSourceOps(results, resNum);
193 for (auto u : results)
194 collectArrayMentionFrom(u);
196 if (auto rop = mlir::dyn_cast<DoLoopOp>(op)) {
197 popFn(rop);
198 return;
200 if (auto rop = mlir::dyn_cast<IterWhileOp>(op)) {
201 popFn(rop);
202 return;
204 if (auto rop = mlir::dyn_cast<fir::IfOp>(op)) {
205 popFn(rop);
206 return;
208 if (auto box = mlir::dyn_cast<EmboxOp>(op)) {
209 for (auto *user : box.getMemref().getUsers())
210 if (user != op)
211 collectArrayMentionFrom(user, user->getResults());
212 return;
214 if (auto mergeStore = mlir::dyn_cast<ArrayMergeStoreOp>(op)) {
215 if (opIsInsideLoops(mergeStore))
216 collectArrayMentionFrom(mergeStore.getSequence());
217 return;
220 if (mlir::isa<AllocaOp, AllocMemOp>(op)) {
221 // Look for any stores inside the loops, and collect an array operation
222 // that produced the value being stored to it.
223 for (auto *user : op->getUsers())
224 if (auto store = mlir::dyn_cast<fir::StoreOp>(user))
225 if (opIsInsideLoops(store))
226 collectArrayMentionFrom(store.getValue());
227 return;
230 // Scan the uses of amend's memref
231 if (auto amend = mlir::dyn_cast<ArrayAmendOp>(op)) {
232 reach.push_back(op);
233 llvm::SmallVector<ArrayAccessOp> accesses;
234 collectAccesses(accesses, op->getBlock());
235 for (auto access : accesses)
236 collectArrayMentionFrom(access.getResult());
239 // Otherwise, Op does not contain a region so just chase its operands.
240 if (mlir::isa<ArrayAccessOp, ArrayLoadOp, ArrayUpdateOp, ArrayModifyOp,
241 ArrayFetchOp>(op)) {
242 LLVM_DEBUG(llvm::dbgs() << "add " << *op << " to reachable set\n");
243 reach.push_back(op);
246 // Include all array_access ops using an array_load.
247 if (auto arrLd = mlir::dyn_cast<ArrayLoadOp>(op))
248 for (auto *user : arrLd.getResult().getUsers())
249 if (mlir::isa<ArrayAccessOp>(user)) {
250 LLVM_DEBUG(llvm::dbgs() << "add " << *user << " to reachable set\n");
251 reach.push_back(user);
254 // Array modify assignment is performed on the result. So the analysis must
255 // look at the what is done with the result.
256 if (mlir::isa<ArrayModifyOp>(op))
257 for (auto *user : op->getResult(0).getUsers())
258 followUsers(user);
260 if (mlir::isa<fir::CallOp>(op)) {
261 LLVM_DEBUG(llvm::dbgs() << "add " << *op << " to reachable set\n");
262 reach.push_back(op);
265 for (auto u : op->getOperands())
266 collectArrayMentionFrom(u);
269 void collectArrayMentionFrom(mlir::BlockArgument ba) {
270 auto *parent = ba.getOwner()->getParentOp();
271 // If inside an Op holding a region, the block argument corresponds to an
272 // argument passed to the containing Op.
273 auto popFn = [&](auto rop) {
274 collectArrayMentionFrom(rop.blockArgToSourceOp(ba.getArgNumber()));
276 if (auto rop = mlir::dyn_cast<DoLoopOp>(parent)) {
277 popFn(rop);
278 return;
280 if (auto rop = mlir::dyn_cast<IterWhileOp>(parent)) {
281 popFn(rop);
282 return;
284 // Otherwise, a block argument is provided via the pred blocks.
285 for (auto *pred : ba.getOwner()->getPredecessors()) {
286 auto u = pred->getTerminator()->getOperand(ba.getArgNumber());
287 collectArrayMentionFrom(u);
291 // Recursively trace operands to find all array operations relating to the
292 // values merged.
293 void collectArrayMentionFrom(mlir::Value val) {
294 if (!val || visited.contains(val))
295 return;
296 visited.insert(val);
298 // Process a block argument.
299 if (auto ba = mlir::dyn_cast<mlir::BlockArgument>(val)) {
300 collectArrayMentionFrom(ba);
301 return;
304 // Process an Op.
305 if (auto *op = val.getDefiningOp()) {
306 collectArrayMentionFrom(op, val);
307 return;
310 emitFatalError(val.getLoc(), "unhandled value");
313 /// Return all ops that produce the array value that is stored into the
314 /// `array_merge_store`.
315 static void reachingValues(llvm::SmallVectorImpl<mlir::Operation *> &reach,
316 mlir::Value seq) {
317 reach.clear();
318 mlir::Region *loopRegion = nullptr;
319 if (auto doLoop = mlir::dyn_cast_or_null<DoLoopOp>(seq.getDefiningOp()))
320 loopRegion = &doLoop->getRegion(0);
321 ReachCollector collector(reach, loopRegion);
322 collector.collectArrayMentionFrom(seq);
325 private:
326 /// Is \op inside the loop nest region ?
327 /// FIXME: replace this structural dependence with graph properties.
328 bool opIsInsideLoops(mlir::Operation *op) const {
329 auto *region = op->getParentRegion();
330 while (region) {
331 if (region == loopRegion)
332 return true;
333 region = region->getParentRegion();
335 return false;
338 /// Recursively trace the use of an operation results, calling
339 /// collectArrayMentionFrom on the direct and indirect user operands.
340 void followUsers(mlir::Operation *op) {
341 for (auto userOperand : op->getOperands())
342 collectArrayMentionFrom(userOperand);
343 // Go through potential converts/coordinate_op.
344 for (auto indirectUser : op->getUsers())
345 followUsers(indirectUser);
348 llvm::SmallVectorImpl<mlir::Operation *> &reach;
349 llvm::SmallPtrSet<mlir::Value, 16> visited;
350 /// Region of the loops nest that produced the array value.
351 mlir::Region *loopRegion;
353 } // namespace
355 /// Find all the array operations that access the array value that is loaded by
356 /// the array load operation, `load`.
357 void ArrayCopyAnalysisBase::arrayMentions(
358 llvm::SmallVectorImpl<mlir::Operation *> &mentions, ArrayLoadOp load) {
359 mentions.clear();
360 auto lmIter = loadMapSets.find(load);
361 if (lmIter != loadMapSets.end()) {
362 for (auto *opnd : lmIter->second) {
363 auto *owner = opnd->getOwner();
364 if (mlir::isa<ArrayAccessOp, ArrayAmendOp, ArrayFetchOp, ArrayUpdateOp,
365 ArrayModifyOp>(owner))
366 mentions.push_back(owner);
368 return;
371 UseSetT visited;
372 llvm::SmallVector<mlir::OpOperand *> queue; // uses of ArrayLoad[orig]
374 auto appendToQueue = [&](mlir::Value val) {
375 for (auto &use : val.getUses())
376 if (!visited.count(&use)) {
377 visited.insert(&use);
378 queue.push_back(&use);
382 // Build the set of uses of `original`.
383 // let USES = { uses of original fir.load }
384 appendToQueue(load);
386 // Process the worklist until done.
387 while (!queue.empty()) {
388 mlir::OpOperand *operand = queue.pop_back_val();
389 mlir::Operation *owner = operand->getOwner();
390 if (!owner)
391 continue;
392 auto structuredLoop = [&](auto ro) {
393 if (auto blockArg = ro.iterArgToBlockArg(operand->get())) {
394 int64_t arg = blockArg.getArgNumber();
395 mlir::Value output = ro.getResult(ro.getFinalValue() ? arg : arg - 1);
396 appendToQueue(output);
397 appendToQueue(blockArg);
400 // TODO: this need to be updated to use the control-flow interface.
401 auto branchOp = [&](mlir::Block *dest, OperandRange operands) {
402 if (operands.empty())
403 return;
405 // Check if this operand is within the range.
406 unsigned operandIndex = operand->getOperandNumber();
407 unsigned operandsStart = operands.getBeginOperandIndex();
408 if (operandIndex < operandsStart ||
409 operandIndex >= (operandsStart + operands.size()))
410 return;
412 // Index the successor.
413 unsigned argIndex = operandIndex - operandsStart;
414 appendToQueue(dest->getArgument(argIndex));
416 // Thread uses into structured loop bodies and return value uses.
417 if (auto ro = mlir::dyn_cast<DoLoopOp>(owner)) {
418 structuredLoop(ro);
419 } else if (auto ro = mlir::dyn_cast<IterWhileOp>(owner)) {
420 structuredLoop(ro);
421 } else if (auto rs = mlir::dyn_cast<ResultOp>(owner)) {
422 // Thread any uses of fir.if that return the marked array value.
423 mlir::Operation *parent = rs->getParentRegion()->getParentOp();
424 if (auto ifOp = mlir::dyn_cast<fir::IfOp>(parent))
425 appendToQueue(ifOp.getResult(operand->getOperandNumber()));
426 } else if (mlir::isa<ArrayFetchOp>(owner)) {
427 // Keep track of array value fetches.
428 LLVM_DEBUG(llvm::dbgs()
429 << "add fetch {" << *owner << "} to array value set\n");
430 mentions.push_back(owner);
431 } else if (auto update = mlir::dyn_cast<ArrayUpdateOp>(owner)) {
432 // Keep track of array value updates and thread the return value uses.
433 LLVM_DEBUG(llvm::dbgs()
434 << "add update {" << *owner << "} to array value set\n");
435 mentions.push_back(owner);
436 appendToQueue(update.getResult());
437 } else if (auto update = mlir::dyn_cast<ArrayModifyOp>(owner)) {
438 // Keep track of array value modification and thread the return value
439 // uses.
440 LLVM_DEBUG(llvm::dbgs()
441 << "add modify {" << *owner << "} to array value set\n");
442 mentions.push_back(owner);
443 appendToQueue(update.getResult(1));
444 } else if (auto mention = mlir::dyn_cast<ArrayAccessOp>(owner)) {
445 mentions.push_back(owner);
446 } else if (auto amend = mlir::dyn_cast<ArrayAmendOp>(owner)) {
447 mentions.push_back(owner);
448 appendToQueue(amend.getResult());
449 } else if (auto br = mlir::dyn_cast<mlir::cf::BranchOp>(owner)) {
450 branchOp(br.getDest(), br.getDestOperands());
451 } else if (auto br = mlir::dyn_cast<mlir::cf::CondBranchOp>(owner)) {
452 branchOp(br.getTrueDest(), br.getTrueOperands());
453 branchOp(br.getFalseDest(), br.getFalseOperands());
454 } else if (mlir::isa<ArrayMergeStoreOp>(owner)) {
455 // do nothing
456 } else {
457 llvm::report_fatal_error("array value reached unexpected op");
460 loadMapSets.insert({load, visited});
463 static bool hasPointerType(mlir::Type type) {
464 if (auto boxTy = mlir::dyn_cast<BoxType>(type))
465 type = boxTy.getEleTy();
466 return mlir::isa<fir::PointerType>(type);
469 // This is a NF performance hack. It makes a simple test that the slices of the
470 // load, \p ld, and the merge store, \p st, are trivially mutually exclusive.
471 static bool mutuallyExclusiveSliceRange(ArrayLoadOp ld, ArrayMergeStoreOp st) {
472 // If the same array_load, then no further testing is warranted.
473 if (ld.getResult() == st.getOriginal())
474 return false;
476 auto getSliceOp = [](mlir::Value val) -> SliceOp {
477 if (!val)
478 return {};
479 auto sliceOp = mlir::dyn_cast_or_null<SliceOp>(val.getDefiningOp());
480 if (!sliceOp)
481 return {};
482 return sliceOp;
485 auto ldSlice = getSliceOp(ld.getSlice());
486 auto stSlice = getSliceOp(st.getSlice());
487 if (!ldSlice || !stSlice)
488 return false;
490 // Resign on subobject slices.
491 if (!ldSlice.getFields().empty() || !stSlice.getFields().empty() ||
492 !ldSlice.getSubstr().empty() || !stSlice.getSubstr().empty())
493 return false;
495 // Crudely test that the two slices do not overlap by looking for the
496 // following general condition. If the slices look like (i:j) and (j+1:k) then
497 // these ranges do not overlap. The addend must be a constant.
498 auto ldTriples = ldSlice.getTriples();
499 auto stTriples = stSlice.getTriples();
500 const auto size = ldTriples.size();
501 if (size != stTriples.size())
502 return false;
504 auto displacedByConstant = [](mlir::Value v1, mlir::Value v2) {
505 auto removeConvert = [](mlir::Value v) -> mlir::Operation * {
506 auto *op = v.getDefiningOp();
507 while (auto conv = mlir::dyn_cast_or_null<ConvertOp>(op))
508 op = conv.getValue().getDefiningOp();
509 return op;
512 auto isPositiveConstant = [](mlir::Value v) -> bool {
513 if (auto conOp =
514 mlir::dyn_cast<mlir::arith::ConstantOp>(v.getDefiningOp()))
515 if (auto iattr = mlir::dyn_cast<mlir::IntegerAttr>(conOp.getValue()))
516 return iattr.getInt() > 0;
517 return false;
520 auto *op1 = removeConvert(v1);
521 auto *op2 = removeConvert(v2);
522 if (!op1 || !op2)
523 return false;
524 if (auto addi = mlir::dyn_cast<mlir::arith::AddIOp>(op2))
525 if ((addi.getLhs().getDefiningOp() == op1 &&
526 isPositiveConstant(addi.getRhs())) ||
527 (addi.getRhs().getDefiningOp() == op1 &&
528 isPositiveConstant(addi.getLhs())))
529 return true;
530 if (auto subi = mlir::dyn_cast<mlir::arith::SubIOp>(op1))
531 if (subi.getLhs().getDefiningOp() == op2 &&
532 isPositiveConstant(subi.getRhs()))
533 return true;
534 return false;
537 for (std::remove_const_t<decltype(size)> i = 0; i < size; i += 3) {
538 // If both are loop invariant, skip to the next triple.
539 if (mlir::isa_and_nonnull<fir::UndefOp>(ldTriples[i + 1].getDefiningOp()) &&
540 mlir::isa_and_nonnull<fir::UndefOp>(stTriples[i + 1].getDefiningOp())) {
541 // Unless either is a vector index, then be conservative.
542 if (mlir::isa_and_nonnull<fir::UndefOp>(ldTriples[i].getDefiningOp()) ||
543 mlir::isa_and_nonnull<fir::UndefOp>(stTriples[i].getDefiningOp()))
544 return false;
545 continue;
547 // If identical, skip to the next triple.
548 if (ldTriples[i] == stTriples[i] && ldTriples[i + 1] == stTriples[i + 1] &&
549 ldTriples[i + 2] == stTriples[i + 2])
550 continue;
551 // If ubound and lbound are the same with a constant offset, skip to the
552 // next triple.
553 if (displacedByConstant(ldTriples[i + 1], stTriples[i]) ||
554 displacedByConstant(stTriples[i + 1], ldTriples[i]))
555 continue;
556 return false;
558 LLVM_DEBUG(llvm::dbgs() << "detected non-overlapping slice ranges on " << ld
559 << " and " << st << ", which is not a conflict\n");
560 return true;
563 /// Is there a conflict between the array value that was updated and to be
564 /// stored to `st` and the set of arrays loaded (`reach`) and used to compute
565 /// the updated value?
566 /// If `optimize` is true, use the variable attributes to prove that
567 /// there is no conflict.
568 static bool conflictOnLoad(llvm::ArrayRef<mlir::Operation *> reach,
569 ArrayMergeStoreOp st, bool optimize) {
570 mlir::Value load;
571 mlir::Value addr = st.getMemref();
572 const bool storeHasPointerType = hasPointerType(addr.getType());
573 for (auto *op : reach)
574 if (auto ld = mlir::dyn_cast<ArrayLoadOp>(op)) {
575 mlir::Type ldTy = ld.getMemref().getType();
576 auto globalOpName = mlir::OperationName(fir::GlobalOp::getOperationName(),
577 ld.getContext());
578 if (ld.getMemref() == addr) {
579 if (mutuallyExclusiveSliceRange(ld, st))
580 continue;
581 if (ld.getResult() != st.getOriginal())
582 return true;
583 if (load) {
584 // TODO: extend this to allow checking if the first `load` and this
585 // `ld` are mutually exclusive accesses but not identical.
586 return true;
588 load = ld;
589 } else if (storeHasPointerType) {
590 if (optimize && !hasPointerType(ldTy) &&
591 !valueMayHaveFirAttributes(
592 ld.getMemref(),
593 {getTargetAttrName(),
594 fir::GlobalOp::getTargetAttrName(globalOpName).strref()}))
595 continue;
597 return true;
598 } else if (hasPointerType(ldTy)) {
599 if (optimize && !storeHasPointerType &&
600 !valueMayHaveFirAttributes(
601 addr,
602 {getTargetAttrName(),
603 fir::GlobalOp::getTargetAttrName(globalOpName).strref()}))
604 continue;
606 return true;
608 // TODO: Check if types can also allow ruling out some cases. For now,
609 // the fact that equivalences is using pointer attribute to enforce
610 // aliasing is preventing any attempt to do so, and in general, it may
611 // be wrong to use this if any of the types is a complex or a derived
612 // for which it is possible to create a pointer to a part with a
613 // different type than the whole, although this deserve some more
614 // investigation because existing compiler behavior seem to diverge
615 // here.
617 return false;
620 /// Is there an access vector conflict on the array being merged into? If the
621 /// access vectors diverge, then assume that there are potentially overlapping
622 /// loop-carried references.
623 static bool conflictOnMerge(llvm::ArrayRef<mlir::Operation *> mentions) {
624 if (mentions.size() < 2)
625 return false;
626 llvm::SmallVector<mlir::Value> indices;
627 LLVM_DEBUG(llvm::dbgs() << "check merge conflict on with " << mentions.size()
628 << " mentions on the list\n");
629 bool valSeen = false;
630 bool refSeen = false;
631 for (auto *op : mentions) {
632 llvm::SmallVector<mlir::Value> compareVector;
633 if (auto u = mlir::dyn_cast<ArrayUpdateOp>(op)) {
634 valSeen = true;
635 if (indices.empty()) {
636 indices = u.getIndices();
637 continue;
639 compareVector = u.getIndices();
640 } else if (auto f = mlir::dyn_cast<ArrayModifyOp>(op)) {
641 valSeen = true;
642 if (indices.empty()) {
643 indices = f.getIndices();
644 continue;
646 compareVector = f.getIndices();
647 } else if (auto f = mlir::dyn_cast<ArrayFetchOp>(op)) {
648 valSeen = true;
649 if (indices.empty()) {
650 indices = f.getIndices();
651 continue;
653 compareVector = f.getIndices();
654 } else if (auto f = mlir::dyn_cast<ArrayAccessOp>(op)) {
655 refSeen = true;
656 if (indices.empty()) {
657 indices = f.getIndices();
658 continue;
660 compareVector = f.getIndices();
661 } else if (mlir::isa<ArrayAmendOp>(op)) {
662 refSeen = true;
663 continue;
664 } else {
665 mlir::emitError(op->getLoc(), "unexpected operation in analysis");
667 if (compareVector.size() != indices.size() ||
668 llvm::any_of(llvm::zip(compareVector, indices), [&](auto pair) {
669 return std::get<0>(pair) != std::get<1>(pair);
671 return true;
672 LLVM_DEBUG(llvm::dbgs() << "vectors compare equal\n");
674 return valSeen && refSeen;
677 /// With element-by-reference semantics, an amended array with more than once
678 /// access to the same loaded array are conservatively considered a conflict.
679 /// Note: the array copy can still be eliminated in subsequent optimizations.
680 static bool conflictOnReference(llvm::ArrayRef<mlir::Operation *> mentions) {
681 LLVM_DEBUG(llvm::dbgs() << "checking reference semantics " << mentions.size()
682 << '\n');
683 if (mentions.size() < 3)
684 return false;
685 unsigned amendCount = 0;
686 unsigned accessCount = 0;
687 for (auto *op : mentions) {
688 if (mlir::isa<ArrayAmendOp>(op) && ++amendCount > 1) {
689 LLVM_DEBUG(llvm::dbgs() << "conflict: multiple amends of array value\n");
690 return true;
692 if (mlir::isa<ArrayAccessOp>(op) && ++accessCount > 1) {
693 LLVM_DEBUG(llvm::dbgs()
694 << "conflict: multiple accesses of array value\n");
695 return true;
697 if (mlir::isa<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>(op)) {
698 LLVM_DEBUG(llvm::dbgs()
699 << "conflict: array value has both uses by-value and uses "
700 "by-reference. conservative assumption.\n");
701 return true;
704 return false;
707 static mlir::Operation *
708 amendingAccess(llvm::ArrayRef<mlir::Operation *> mentions) {
709 for (auto *op : mentions)
710 if (auto amend = mlir::dyn_cast<ArrayAmendOp>(op))
711 return amend.getMemref().getDefiningOp();
712 return {};
715 // Are any conflicts present? The conflicts detected here are described above.
716 static bool conflictDetected(llvm::ArrayRef<mlir::Operation *> reach,
717 llvm::ArrayRef<mlir::Operation *> mentions,
718 ArrayMergeStoreOp st, bool optimize) {
719 return conflictOnLoad(reach, st, optimize) || conflictOnMerge(mentions);
722 // Assume that any call to a function that uses host-associations will be
723 // modifying the output array.
724 static bool
725 conservativeCallConflict(llvm::ArrayRef<mlir::Operation *> reaches) {
726 return llvm::any_of(reaches, [](mlir::Operation *op) {
727 if (auto call = mlir::dyn_cast<fir::CallOp>(op))
728 if (auto callee = mlir::dyn_cast<mlir::SymbolRefAttr>(
729 call.getCallableForCallee())) {
730 auto module = op->getParentOfType<mlir::ModuleOp>();
731 return isInternalProcedure(
732 module.lookupSymbol<mlir::func::FuncOp>(callee));
734 return false;
738 /// Constructor of the array copy analysis.
739 /// This performs the analysis and saves the intermediate results.
740 void ArrayCopyAnalysisBase::construct(mlir::Operation *topLevelOp) {
741 topLevelOp->walk([&](Operation *op) {
742 if (auto st = mlir::dyn_cast<fir::ArrayMergeStoreOp>(op)) {
743 llvm::SmallVector<mlir::Operation *> values;
744 ReachCollector::reachingValues(values, st.getSequence());
745 bool callConflict = conservativeCallConflict(values);
746 llvm::SmallVector<mlir::Operation *> mentions;
747 arrayMentions(mentions,
748 mlir::cast<ArrayLoadOp>(st.getOriginal().getDefiningOp()));
749 bool conflict = conflictDetected(values, mentions, st, optimizeConflicts);
750 bool refConflict = conflictOnReference(mentions);
751 if (callConflict || conflict || refConflict) {
752 LLVM_DEBUG(llvm::dbgs()
753 << "CONFLICT: copies required for " << st << '\n'
754 << " adding conflicts on: " << *op << " and "
755 << st.getOriginal() << '\n');
756 conflicts.insert(op);
757 conflicts.insert(st.getOriginal().getDefiningOp());
758 if (auto *access = amendingAccess(mentions))
759 amendAccesses.insert(access);
761 auto *ld = st.getOriginal().getDefiningOp();
762 LLVM_DEBUG(llvm::dbgs()
763 << "map: adding {" << *ld << " -> " << st << "}\n");
764 useMap.insert({ld, op});
765 } else if (auto load = mlir::dyn_cast<ArrayLoadOp>(op)) {
766 llvm::SmallVector<mlir::Operation *> mentions;
767 arrayMentions(mentions, load);
768 LLVM_DEBUG(llvm::dbgs() << "process load: " << load
769 << ", mentions: " << mentions.size() << '\n');
770 for (auto *acc : mentions) {
771 LLVM_DEBUG(llvm::dbgs() << " mention: " << *acc << '\n');
772 if (mlir::isa<ArrayAccessOp, ArrayAmendOp, ArrayFetchOp, ArrayUpdateOp,
773 ArrayModifyOp>(acc)) {
774 if (useMap.count(acc)) {
775 mlir::emitError(
776 load.getLoc(),
777 "The parallel semantics of multiple array_merge_stores per "
778 "array_load are not supported.");
779 continue;
781 LLVM_DEBUG(llvm::dbgs()
782 << "map: adding {" << *acc << "} -> {" << load << "}\n");
783 useMap.insert({acc, op});
790 //===----------------------------------------------------------------------===//
791 // Conversions for converting out of array value form.
792 //===----------------------------------------------------------------------===//
794 namespace {
795 class ArrayLoadConversion : public mlir::OpRewritePattern<ArrayLoadOp> {
796 public:
797 using OpRewritePattern::OpRewritePattern;
799 llvm::LogicalResult
800 matchAndRewrite(ArrayLoadOp load,
801 mlir::PatternRewriter &rewriter) const override {
802 LLVM_DEBUG(llvm::dbgs() << "replace load " << load << " with undef.\n");
803 rewriter.replaceOpWithNewOp<UndefOp>(load, load.getType());
804 return mlir::success();
808 class ArrayMergeStoreConversion
809 : public mlir::OpRewritePattern<ArrayMergeStoreOp> {
810 public:
811 using OpRewritePattern::OpRewritePattern;
813 llvm::LogicalResult
814 matchAndRewrite(ArrayMergeStoreOp store,
815 mlir::PatternRewriter &rewriter) const override {
816 LLVM_DEBUG(llvm::dbgs() << "marking store " << store << " as dead.\n");
817 rewriter.eraseOp(store);
818 return mlir::success();
821 } // namespace
823 static mlir::Type getEleTy(mlir::Type ty) {
824 auto eleTy = unwrapSequenceType(unwrapPassByRefType(ty));
825 // FIXME: keep ptr/heap/ref information.
826 return ReferenceType::get(eleTy);
829 // This is an unsafe way to deduce this (won't be true in internal
830 // procedure or inside select-rank for assumed-size). Only here to satisfy
831 // legacy code until removed.
832 static bool isAssumedSize(llvm::SmallVectorImpl<mlir::Value> &extents) {
833 if (extents.empty())
834 return false;
835 auto cstLen = fir::getIntIfConstant(extents.back());
836 return cstLen.has_value() && *cstLen == -1;
839 // Extract extents from the ShapeOp/ShapeShiftOp into the result vector.
840 static bool getAdjustedExtents(mlir::Location loc,
841 mlir::PatternRewriter &rewriter,
842 ArrayLoadOp arrLoad,
843 llvm::SmallVectorImpl<mlir::Value> &result,
844 mlir::Value shape) {
845 bool copyUsingSlice = false;
846 auto *shapeOp = shape.getDefiningOp();
847 if (auto s = mlir::dyn_cast_or_null<ShapeOp>(shapeOp)) {
848 auto e = s.getExtents();
849 result.insert(result.end(), e.begin(), e.end());
850 } else if (auto s = mlir::dyn_cast_or_null<ShapeShiftOp>(shapeOp)) {
851 auto e = s.getExtents();
852 result.insert(result.end(), e.begin(), e.end());
853 } else {
854 emitFatalError(loc, "not a fir.shape/fir.shape_shift op");
856 auto idxTy = rewriter.getIndexType();
857 if (isAssumedSize(result)) {
858 // Use slice information to compute the extent of the column.
859 auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
860 mlir::Value size = one;
861 if (mlir::Value sliceArg = arrLoad.getSlice()) {
862 if (auto sliceOp =
863 mlir::dyn_cast_or_null<SliceOp>(sliceArg.getDefiningOp())) {
864 auto triples = sliceOp.getTriples();
865 const std::size_t tripleSize = triples.size();
866 auto module = arrLoad->getParentOfType<mlir::ModuleOp>();
867 FirOpBuilder builder(rewriter, module);
868 size = builder.genExtentFromTriplet(loc, triples[tripleSize - 3],
869 triples[tripleSize - 2],
870 triples[tripleSize - 1], idxTy);
871 copyUsingSlice = true;
874 result[result.size() - 1] = size;
876 return copyUsingSlice;
/// Place the extents of the array load, \p arrLoad, into \p result and
/// return a ShapeOp or ShapeShiftOp with the same extents. If \p arrLoad is
/// loading a `!fir.box`, code will be generated to read the extents from the
/// boxed value, and the returned shape Op will be built with the extents read
/// from the box. Otherwise, the extents will be extracted from the ShapeOp (or
/// ShapeShiftOp) argument of \p arrLoad. \p copyUsingSlice will be set to true
/// if slicing of the output array is to be done in the copy-in/copy-out rather
/// than in the elemental computation step.
static mlir::Value getOrReadExtentsAndShapeOp(
    mlir::Location loc, mlir::PatternRewriter &rewriter, ArrayLoadOp arrLoad,
    llvm::SmallVectorImpl<mlir::Value> &result, bool &copyUsingSlice) {
  assert(result.empty());
  if (arrLoad->hasAttr(fir::getOptionalAttrName()))
    fir::emitFatalError(
        loc, "shapes from array load of OPTIONAL arrays must not be used");
  if (auto boxTy = mlir::dyn_cast<BoxType>(arrLoad.getMemref().getType())) {
    auto rank =
        mlir::cast<SequenceType>(dyn_cast_ptrOrBoxEleTy(boxTy)).getDimension();
    auto idxTy = rewriter.getIndexType();
    // Generate box_dims reads for each dimension; result #1 is the extent.
    for (decltype(rank) dim = 0; dim < rank; ++dim) {
      auto dimVal = rewriter.create<mlir::arith::ConstantIndexOp>(loc, dim);
      auto dimInfo = rewriter.create<BoxDimsOp>(loc, idxTy, idxTy, idxTy,
                                                arrLoad.getMemref(), dimVal);
      result.emplace_back(dimInfo.getResult(1));
    }
    if (!arrLoad.getShape()) {
      // No origin information: build a plain shape from the box extents.
      auto shapeType = ShapeType::get(rewriter.getContext(), rank);
      return rewriter.create<ShapeOp>(loc, shapeType, result);
    }
    // A fir.shift provides the lower bounds; interleave them with the extents
    // read from the box to build a shape_shift.
    auto shiftOp = arrLoad.getShape().getDefiningOp<ShiftOp>();
    auto shapeShiftType = ShapeShiftType::get(rewriter.getContext(), rank);
    llvm::SmallVector<mlir::Value> shapeShiftOperands;
    for (auto [lb, extent] : llvm::zip(shiftOp.getOrigins(), result)) {
      shapeShiftOperands.push_back(lb);
      shapeShiftOperands.push_back(extent);
    }
    return rewriter.create<ShapeShiftOp>(loc, shapeShiftType,
                                         shapeShiftOperands);
  }
  // Unboxed case: reuse the existing shape op and extract its extents.
  copyUsingSlice =
      getAdjustedExtents(loc, rewriter, arrLoad, result, arrLoad.getShape());
  return arrLoad.getShape();
}
923 static mlir::Type toRefType(mlir::Type ty) {
924 if (fir::isa_ref_type(ty))
925 return ty;
926 return fir::ReferenceType::get(ty);
929 static llvm::SmallVector<mlir::Value>
930 getTypeParamsIfRawData(mlir::Location loc, FirOpBuilder &builder,
931 ArrayLoadOp arrLoad, mlir::Type ty) {
932 if (mlir::isa<BoxType>(ty))
933 return {};
934 return fir::factory::getTypeParams(loc, builder, arrLoad);
/// Generate the address of an element for \p load. An array_coor consumes the
/// first `rank` indices; if more indices are present, a coordinate_of is
/// chained on the element address with the remaining indices. Unless
/// \p skipOrig is set, the indices are first rebased with the array origin
/// from \p shape.
static mlir::Value genCoorOp(mlir::PatternRewriter &rewriter,
                             mlir::Location loc, mlir::Type eleTy,
                             mlir::Type resTy, mlir::Value alloc,
                             mlir::Value shape, mlir::Value slice,
                             mlir::ValueRange indices, ArrayLoadOp load,
                             bool skipOrig = false) {
  llvm::SmallVector<mlir::Value> originated;
  if (skipOrig)
    originated.assign(indices.begin(), indices.end());
  else
    originated = factory::originateIndices(loc, rewriter, alloc.getType(),
                                           shape, indices);
  auto seqTy = dyn_cast_ptrOrBoxEleTy(alloc.getType());
  assert(seqTy && mlir::isa<SequenceType>(seqTy));
  const auto dimension = mlir::cast<SequenceType>(seqTy).getDimension();
  auto module = load->getParentOfType<mlir::ModuleOp>();
  FirOpBuilder builder(rewriter, module);
  auto typeparams = getTypeParamsIfRawData(loc, builder, load, alloc.getType());
  // Address the array element with the first `dimension` indices.
  mlir::Value result = rewriter.create<ArrayCoorOp>(
      loc, eleTy, alloc, shape, slice,
      llvm::ArrayRef<mlir::Value>{originated}.take_front(dimension),
      typeparams);
  // Remaining indices (if any) address inside the element itself.
  if (dimension < originated.size())
    result = rewriter.create<fir::CoordinateOp>(
        loc, resTy, result,
        llvm::ArrayRef<mlir::Value>{originated}.drop_front(dimension));
  return result;
}
/// Compute the CHARACTER length of the elements of \p load. The length comes
/// from the box element size (dynamic length, boxed array), from the
/// array_load type parameters (dynamic length, unboxed array), or from the
/// character type itself (compile-time constant length).
static mlir::Value getCharacterLen(mlir::Location loc, FirOpBuilder &builder,
                                   ArrayLoadOp load, CharacterType charTy) {
  auto charLenTy = builder.getCharacterLengthType();
  if (charTy.hasDynamicLen()) {
    if (mlir::isa<BoxType>(load.getMemref().getType())) {
      // The loaded array is an emboxed value. Get the CHARACTER length from
      // the box value.
      auto eleSzInBytes =
          builder.create<BoxEleSizeOp>(loc, charLenTy, load.getMemref());
      auto kindSize =
          builder.getKindMap().getCharacterBitsize(charTy.getFKind());
      auto kindByteSize =
          builder.createIntegerConstant(loc, charLenTy, kindSize / 8);
      // Length in characters = element size in bytes / bytes per character.
      return builder.create<mlir::arith::DivSIOp>(loc, eleSzInBytes,
                                                  kindByteSize);
    }
    // The loaded array is a (set of) unboxed values. If the CHARACTER's
    // length is not a constant, it must be provided as a type parameter to
    // the array_load.
    auto typeparams = load.getTypeparams();
    assert(typeparams.size() > 0 && "expected type parameters on array_load");
    return typeparams.back();
  }
  // The typical case: the length of the CHARACTER is a compile-time
  // constant that is encoded in the type information.
  return builder.createIntegerConstant(loc, charLenTy, charTy.getLen());
}
/// Generate a shallow array copy. This is used for both copy-in and copy-out.
/// A loop nest over the adjusted extents is built (innermost loop on the
/// leftmost dimension, i.e. column-major), and each element of \p src is
/// assigned to the matching element of \p dst. When the copy must honor a
/// slice (assumed-size case), \p sliceOp is applied to the source on copy-in
/// and to the destination on copy-out.
template <bool CopyIn>
void genArrayCopy(mlir::Location loc, mlir::PatternRewriter &rewriter,
                  mlir::Value dst, mlir::Value src, mlir::Value shapeOp,
                  mlir::Value sliceOp, ArrayLoadOp arrLoad) {
  // The loop nest moves the insertion point; restore it when done.
  auto insPt = rewriter.saveInsertionPoint();
  llvm::SmallVector<mlir::Value> indices;
  llvm::SmallVector<mlir::Value> extents;
  bool copyUsingSlice =
      getAdjustedExtents(loc, rewriter, arrLoad, extents, shapeOp);
  auto idxTy = rewriter.getIndexType();
  // Build loop nest from column to row.
  for (auto sh : llvm::reverse(extents)) {
    auto ubi = rewriter.create<ConvertOp>(loc, idxTy, sh);
    auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
    auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
    // fir.do_loop bounds are inclusive: iterate 0 .. extent-1.
    auto ub = rewriter.create<mlir::arith::SubIOp>(loc, idxTy, ubi, one);
    auto loop = rewriter.create<DoLoopOp>(loc, zero, ub, one);
    rewriter.setInsertionPointToStart(loop.getBody());
    indices.push_back(loop.getInductionVar());
  }
  // Reverse the indices so they are in column-major order.
  std::reverse(indices.begin(), indices.end());
  auto module = arrLoad->getParentOfType<mlir::ModuleOp>();
  FirOpBuilder builder(rewriter, module);
  auto fromAddr = rewriter.create<ArrayCoorOp>(
      loc, getEleTy(src.getType()), src, shapeOp,
      CopyIn && copyUsingSlice ? sliceOp : mlir::Value{},
      factory::originateIndices(loc, rewriter, src.getType(), shapeOp, indices),
      getTypeParamsIfRawData(loc, builder, arrLoad, src.getType()));
  auto toAddr = rewriter.create<ArrayCoorOp>(
      loc, getEleTy(dst.getType()), dst, shapeOp,
      !CopyIn && copyUsingSlice ? sliceOp : mlir::Value{},
      factory::originateIndices(loc, rewriter, dst.getType(), shapeOp, indices),
      getTypeParamsIfRawData(loc, builder, arrLoad, dst.getType()));
  auto eleTy = unwrapSequenceType(unwrapPassByRefType(dst.getType()));
  // Copy from (to) object to (from) temp copy of same object.
  if (auto charTy = mlir::dyn_cast<CharacterType>(eleTy)) {
    // CHARACTER elements carry a length that drives the assignment.
    auto len = getCharacterLen(loc, builder, arrLoad, charTy);
    CharBoxValue toChar(toAddr, len);
    CharBoxValue fromChar(fromAddr, len);
    factory::genScalarAssignment(builder, loc, toChar, fromChar);
  } else {
    if (hasDynamicSize(eleTy))
      TODO(loc, "copy element of dynamic size");
    factory::genScalarAssignment(builder, loc, toAddr, fromAddr);
  }
  rewriter.restoreInsertionPoint(insPt);
}
/// The array load may be either a boxed or unboxed value. If the value is
/// boxed, we read the type parameters from the boxed value.
static llvm::SmallVector<mlir::Value>
genArrayLoadTypeParameters(mlir::Location loc, mlir::PatternRewriter &rewriter,
                           ArrayLoadOp load) {
  if (load.getTypeparams().empty()) {
    auto eleTy =
        unwrapSequenceType(unwrapPassByRefType(load.getMemref().getType()));
    if (hasDynamicSize(eleTy)) {
      if (auto charTy = mlir::dyn_cast<CharacterType>(eleTy)) {
        // A dynamic-length CHARACTER with no explicit type parameters is
        // expected to be boxed; read the length from the box.
        assert(mlir::isa<BoxType>(load.getMemref().getType()));
        auto module = load->getParentOfType<mlir::ModuleOp>();
        FirOpBuilder builder(rewriter, module);
        return {getCharacterLen(loc, builder, load, charTy)};
      }
      TODO(loc, "unhandled dynamic type parameters");
    }
    // Statically sized element type: no type parameters are needed.
    return {};
  }
  return load.getTypeparams();
}
1065 static llvm::SmallVector<mlir::Value>
1066 findNonconstantExtents(mlir::Type memrefTy,
1067 llvm::ArrayRef<mlir::Value> extents) {
1068 llvm::SmallVector<mlir::Value> nce;
1069 auto arrTy = unwrapPassByRefType(memrefTy);
1070 auto seqTy = mlir::cast<SequenceType>(arrTy);
1071 for (auto [s, x] : llvm::zip(seqTy.getShape(), extents))
1072 if (s == SequenceType::getUnknownExtent())
1073 nce.emplace_back(x);
1074 if (extents.size() > seqTy.getShape().size())
1075 for (auto x : extents.drop_front(seqTy.getShape().size()))
1076 nce.emplace_back(x);
1077 return nce;
/// Allocate temporary storage for an ArrayLoadOp \p load and initialize any
/// allocatable direct components of the array elements with an unallocated
/// status. Returns the temporary address as well as a callback to generate the
/// temporary clean-up once it has been used. The clean-up will take care of
/// deallocating all the element allocatable components that may have been
/// allocated while using the temporary.
static std::pair<mlir::Value,
                 std::function<void(mlir::PatternRewriter &rewriter)>>
allocateArrayTemp(mlir::Location loc, mlir::PatternRewriter &rewriter,
                  ArrayLoadOp load, llvm::ArrayRef<mlir::Value> extents,
                  mlir::Value shape) {
  mlir::Type baseType = load.getMemref().getType();
  // Only extents not already captured in the static type are passed to the
  // allocation.
  llvm::SmallVector<mlir::Value> nonconstantExtents =
      findNonconstantExtents(baseType, extents);
  llvm::SmallVector<mlir::Value> typeParams =
      genArrayLoadTypeParameters(loc, rewriter, load);
  mlir::Value allocmem = rewriter.create<AllocMemOp>(
      loc, dyn_cast_ptrOrBoxEleTy(baseType), typeParams, nonconstantExtents);
  mlir::Type eleType =
      fir::unwrapSequenceType(fir::unwrapPassByRefType(baseType));
  if (fir::isRecordWithAllocatableMember(eleType)) {
    // The allocatable component descriptors need to be set to a clean
    // deallocated status before anything is done with them.
    mlir::Value box = rewriter.create<fir::EmboxOp>(
        loc, fir::BoxType::get(allocmem.getType()), allocmem, shape,
        /*slice=*/mlir::Value{}, typeParams);
    auto module = load->getParentOfType<mlir::ModuleOp>();
    FirOpBuilder builder(rewriter, module);
    runtime::genDerivedTypeInitialize(builder, loc, box);
    // Any allocatable component that may have been allocated must be
    // deallocated during the clean-up.
    auto cleanup = [=](mlir::PatternRewriter &r) {
      FirOpBuilder builder(r, module);
      runtime::genDerivedTypeDestroy(builder, loc, box);
      r.create<FreeMemOp>(loc, allocmem);
    };
    return {allocmem, cleanup};
  }
  // No allocatable components: the clean-up only frees the temp storage.
  auto cleanup = [=](mlir::PatternRewriter &r) {
    r.create<FreeMemOp>(loc, allocmem);
  };
  return {allocmem, cleanup};
}
1124 namespace {
/// Conversion of fir.array_update and fir.array_modify Ops.
/// If there is a conflict for the update, then we need to perform a
/// copy-in/copy-out to preserve the original values of the array. If there is
/// no conflict, then it is safe to eschew making any copies.
template <typename ArrayOp>
class ArrayUpdateConversionBase : public mlir::OpRewritePattern<ArrayOp> {
public:
  // TODO: Implement copy/swap semantics?
  explicit ArrayUpdateConversionBase(mlir::MLIRContext *ctx,
                                     const ArrayCopyAnalysisBase &a,
                                     const OperationUseMapT &m)
      : mlir::OpRewritePattern<ArrayOp>{ctx}, analysis{a}, useMap{m} {}

  /// The array_access, \p access, is to be made to a cloned copy due to a
  /// potential conflict. Uses copy-in/copy-out semantics and not copy/swap.
  /// Returns the address of the accessed element inside the temporary copy.
  mlir::Value referenceToClone(mlir::Location loc,
                               mlir::PatternRewriter &rewriter,
                               ArrayOp access) const {
    LLVM_DEBUG(llvm::dbgs()
               << "generating copy-in/copy-out loops for " << access << '\n');
    auto *op = access.getOperation();
    auto *loadOp = useMap.lookup(op);
    auto load = mlir::cast<ArrayLoadOp>(loadOp);
    auto eleTy = access.getType();
    rewriter.setInsertionPoint(loadOp);
    // Copy in.
    llvm::SmallVector<mlir::Value> extents;
    bool copyUsingSlice = false;
    auto shapeOp = getOrReadExtentsAndShapeOp(loc, rewriter, load, extents,
                                              copyUsingSlice);
    auto [allocmem, genTempCleanUp] =
        allocateArrayTemp(loc, rewriter, load, extents, shapeOp);
    genArrayCopy</*copyIn=*/true>(load.getLoc(), rewriter, allocmem,
                                  load.getMemref(), shapeOp, load.getSlice(),
                                  load);
    // Generate the reference for the access.
    rewriter.setInsertionPoint(op);
    auto coor = genCoorOp(
        rewriter, loc, getEleTy(load.getType()), eleTy, allocmem, shapeOp,
        copyUsingSlice ? mlir::Value{} : load.getSlice(), access.getIndices(),
        load, access->hasAttr(factory::attrFortranArrayOffsets()));
    // Locate the array_merge_store that consumes the load.
    auto *storeOp = useMap.lookup(loadOp);
    auto store = mlir::cast<ArrayMergeStoreOp>(storeOp);
    rewriter.setInsertionPoint(storeOp);
    // Copy out.
    genArrayCopy</*copyIn=*/false>(store.getLoc(), rewriter, store.getMemref(),
                                   allocmem, shapeOp, store.getSlice(), load);
    genTempCleanUp(rewriter);
    return coor;
  }

  /// Copy the RHS element into the LHS and insert copy-in/copy-out between a
  /// temp and the LHS if the analysis found potential overlaps between the RHS
  /// and LHS arrays. The element copy generator must be provided in \p
  /// assignElement. \p update must be the ArrayUpdateOp or the ArrayModifyOp.
  /// Returns the address of the LHS element inside the loop and the LHS
  /// ArrayLoad result.
  std::pair<mlir::Value, mlir::Value>
  materializeAssignment(mlir::Location loc, mlir::PatternRewriter &rewriter,
                        ArrayOp update,
                        const std::function<void(mlir::Value)> &assignElement,
                        mlir::Type lhsEltRefType) const {
    auto *op = update.getOperation();
    auto *loadOp = useMap.lookup(op);
    auto load = mlir::cast<ArrayLoadOp>(loadOp);
    LLVM_DEBUG(llvm::outs() << "does " << load << " have a conflict?\n");
    if (analysis.hasPotentialConflict(loadOp)) {
      // If there is a conflict between the arrays, then we copy the lhs array
      // to a temporary, update the temporary, and copy the temporary back to
      // the lhs array. This yields Fortran's copy-in copy-out array semantics.
      LLVM_DEBUG(llvm::outs() << "Yes, conflict was found\n");
      rewriter.setInsertionPoint(loadOp);
      // Copy in.
      llvm::SmallVector<mlir::Value> extents;
      bool copyUsingSlice = false;
      auto shapeOp = getOrReadExtentsAndShapeOp(loc, rewriter, load, extents,
                                                copyUsingSlice);
      auto [allocmem, genTempCleanUp] =
          allocateArrayTemp(loc, rewriter, load, extents, shapeOp);
      genArrayCopy</*copyIn=*/true>(load.getLoc(), rewriter, allocmem,
                                    load.getMemref(), shapeOp, load.getSlice(),
                                    load);
      // Assign the element in the temporary instead of the original array.
      rewriter.setInsertionPoint(op);
      auto coor = genCoorOp(
          rewriter, loc, getEleTy(load.getType()), lhsEltRefType, allocmem,
          shapeOp, copyUsingSlice ? mlir::Value{} : load.getSlice(),
          update.getIndices(), load,
          update->hasAttr(factory::attrFortranArrayOffsets()));
      assignElement(coor);
      auto *storeOp = useMap.lookup(loadOp);
      auto store = mlir::cast<ArrayMergeStoreOp>(storeOp);
      rewriter.setInsertionPoint(storeOp);
      // Copy out.
      genArrayCopy</*copyIn=*/false>(store.getLoc(), rewriter,
                                     store.getMemref(), allocmem, shapeOp,
                                     store.getSlice(), load);
      genTempCleanUp(rewriter);
      return {coor, load.getResult()};
    }
    // Otherwise, when there is no conflict (a possible loop-carried
    // dependence), the lhs array can be updated in place.
    LLVM_DEBUG(llvm::outs() << "No, conflict wasn't found\n");
    rewriter.setInsertionPoint(op);
    auto coorTy = getEleTy(load.getType());
    auto coor =
        genCoorOp(rewriter, loc, coorTy, lhsEltRefType, load.getMemref(),
                  load.getShape(), load.getSlice(), update.getIndices(), load,
                  update->hasAttr(factory::attrFortranArrayOffsets()));
    assignElement(coor);
    return {coor, load.getResult()};
  }

protected:
  const ArrayCopyAnalysisBase &analysis;
  const OperationUseMapT &useMap;
};
/// Convert fir.array_update: store the merged element value through the LHS
/// element address, with copy-in/copy-out materialized by the base class
/// when the conflict analysis requires it.
class ArrayUpdateConversion : public ArrayUpdateConversionBase<ArrayUpdateOp> {
public:
  explicit ArrayUpdateConversion(mlir::MLIRContext *ctx,
                                 const ArrayCopyAnalysisBase &a,
                                 const OperationUseMapT &m)
      : ArrayUpdateConversionBase{ctx, a, m} {}

  llvm::LogicalResult
  matchAndRewrite(ArrayUpdateOp update,
                  mlir::PatternRewriter &rewriter) const override {
    auto loc = update.getLoc();
    // Element copy: a plain store of the merge value into the LHS address.
    auto assignElement = [&](mlir::Value coor) {
      auto input = update.getMerge();
      if (auto inEleTy = dyn_cast_ptrEleTy(input.getType())) {
        emitFatalError(loc, "array_update on references not supported");
      } else {
        rewriter.create<fir::StoreOp>(loc, input, coor);
      }
    };
    auto lhsEltRefType = toRefType(update.getMerge().getType());
    auto [_, lhsLoadResult] = materializeAssignment(
        loc, rewriter, update, assignElement, lhsEltRefType);
    // The array_update result is replaced by the LHS array_load value.
    update.replaceAllUsesWith(lhsLoadResult);
    rewriter.replaceOp(update, lhsLoadResult);
    return mlir::success();
  }
};
/// Convert fir.array_modify. Lowering has already materialized the element
/// assignment through the returned address, so this pattern only produces the
/// LHS element address (with copy-in/copy-out when required).
class ArrayModifyConversion : public ArrayUpdateConversionBase<ArrayModifyOp> {
public:
  explicit ArrayModifyConversion(mlir::MLIRContext *ctx,
                                 const ArrayCopyAnalysisBase &a,
                                 const OperationUseMapT &m)
      : ArrayUpdateConversionBase{ctx, a, m} {}

  llvm::LogicalResult
  matchAndRewrite(ArrayModifyOp modify,
                  mlir::PatternRewriter &rewriter) const override {
    auto loc = modify.getLoc();
    auto assignElement = [](mlir::Value) {
      // Assignment already materialized by lowering using lhs element address.
    };
    auto lhsEltRefType = modify.getResult(0).getType();
    auto [lhsEltCoor, lhsLoadResult] = materializeAssignment(
        loc, rewriter, modify, assignElement, lhsEltRefType);
    // array_modify has two results: the element address and the array value.
    modify.replaceAllUsesWith(mlir::ValueRange{lhsEltCoor, lhsLoadResult});
    rewriter.replaceOp(modify, mlir::ValueRange{lhsEltCoor, lhsLoadResult});
    return mlir::success();
  }
};
/// Convert fir.array_fetch into an element address computation followed by a
/// load, unless the fetch result is itself a reference type.
class ArrayFetchConversion : public mlir::OpRewritePattern<ArrayFetchOp> {
public:
  explicit ArrayFetchConversion(mlir::MLIRContext *ctx,
                                const OperationUseMapT &m)
      : OpRewritePattern{ctx}, useMap{m} {}

  llvm::LogicalResult
  matchAndRewrite(ArrayFetchOp fetch,
                  mlir::PatternRewriter &rewriter) const override {
    auto *op = fetch.getOperation();
    rewriter.setInsertionPoint(op);
    auto load = mlir::cast<ArrayLoadOp>(useMap.lookup(op));
    auto loc = fetch.getLoc();
    auto coor = genCoorOp(
        rewriter, loc, getEleTy(load.getType()), toRefType(fetch.getType()),
        load.getMemref(), load.getShape(), load.getSlice(), fetch.getIndices(),
        load, fetch->hasAttr(factory::attrFortranArrayOffsets()));
    if (isa_ref_type(fetch.getType()))
      rewriter.replaceOp(fetch, coor);
    else
      rewriter.replaceOpWithNewOp<fir::LoadOp>(fetch, coor);
    return mlir::success();
  }

private:
  // Maps each array op to its defining array_load.
  const OperationUseMapT &useMap;
};
/// An array_access op is like an array_fetch op, except that it does not imply
/// a load op. (It operates in the reference domain.)
class ArrayAccessConversion : public ArrayUpdateConversionBase<ArrayAccessOp> {
public:
  explicit ArrayAccessConversion(mlir::MLIRContext *ctx,
                                 const ArrayCopyAnalysisBase &a,
                                 const OperationUseMapT &m)
      : ArrayUpdateConversionBase{ctx, a, m} {}

  llvm::LogicalResult
  matchAndRewrite(ArrayAccessOp access,
                  mlir::PatternRewriter &rewriter) const override {
    auto *op = access.getOperation();
    auto loc = access.getLoc();
    if (analysis.inAmendAccessSet(op)) {
      // This array_access is associated with an array_amend and there is a
      // conflict. Make a copy to store into.
      auto result = referenceToClone(loc, rewriter, access);
      access.replaceAllUsesWith(result);
      rewriter.replaceOp(access, result);
      return mlir::success();
    }
    // No conflict: address the original array directly.
    rewriter.setInsertionPoint(op);
    auto load = mlir::cast<ArrayLoadOp>(useMap.lookup(op));
    auto coor = genCoorOp(
        rewriter, loc, getEleTy(load.getType()), toRefType(access.getType()),
        load.getMemref(), load.getShape(), load.getSlice(), access.getIndices(),
        load, access->hasAttr(factory::attrFortranArrayOffsets()));
    rewriter.replaceOp(access, coor);
    return mlir::success();
  }
};
1356 /// An array_amend op is a marker to record which array access is being used to
1357 /// update an array value. After this pass runs, an array_amend has no
1358 /// semantics. We rewrite these to undefined values here to remove them while
1359 /// preserving SSA form.
1360 class ArrayAmendConversion : public mlir::OpRewritePattern<ArrayAmendOp> {
1361 public:
1362 explicit ArrayAmendConversion(mlir::MLIRContext *ctx)
1363 : OpRewritePattern{ctx} {}
1365 llvm::LogicalResult
1366 matchAndRewrite(ArrayAmendOp amend,
1367 mlir::PatternRewriter &rewriter) const override {
1368 auto *op = amend.getOperation();
1369 rewriter.setInsertionPoint(op);
1370 auto loc = amend.getLoc();
1371 auto undef = rewriter.create<UndefOp>(loc, amend.getType());
1372 rewriter.replaceOp(amend, undef.getResult());
1373 return mlir::success();
/// Driver for the array-value-copy pass: runs the conflict analysis on the
/// function, then rewrites the array ops in two phases.
class ArrayValueCopyConverter
    : public fir::impl::ArrayValueCopyBase<ArrayValueCopyConverter> {
public:
  ArrayValueCopyConverter() = default;
  ArrayValueCopyConverter(const fir::ArrayValueCopyOptions &options)
      : Base(options) {}

  void runOnOperation() override {
    auto func = getOperation();
    LLVM_DEBUG(llvm::dbgs() << "\n\narray-value-copy pass on function '"
                            << func.getName() << "'\n");
    auto *context = &getContext();

    // Perform the conflict analysis.
    const ArrayCopyAnalysisBase *analysis;
    if (optimizeConflicts)
      analysis = &getAnalysis<ArrayCopyAnalysisOptimized>();
    else
      analysis = &getAnalysis<ArrayCopyAnalysis>();

    const auto &useMap = analysis->getUseMap();

    // Phase 1: convert the element-level ops (fetch, update, modify, access,
    // amend) using the analysis results.
    mlir::RewritePatternSet patterns1(context);
    patterns1.insert<ArrayFetchConversion>(context, useMap);
    patterns1.insert<ArrayUpdateConversion>(context, *analysis, useMap);
    patterns1.insert<ArrayModifyConversion>(context, *analysis, useMap);
    patterns1.insert<ArrayAccessConversion>(context, *analysis, useMap);
    patterns1.insert<ArrayAmendConversion>(context);
    mlir::ConversionTarget target(*context);
    target
        .addLegalDialect<FIROpsDialect, mlir::scf::SCFDialect,
                         mlir::arith::ArithDialect, mlir::func::FuncDialect>();
    target.addIllegalOp<ArrayAccessOp, ArrayAmendOp, ArrayFetchOp,
                        ArrayUpdateOp, ArrayModifyOp>();
    // Rewrite the array fetch and array update ops.
    if (mlir::failed(
            mlir::applyPartialConversion(func, target, std::move(patterns1)))) {
      mlir::emitError(mlir::UnknownLoc::get(context),
                      "failure in array-value-copy pass, phase 1");
      signalPassFailure();
    }

    // Phase 2: convert the remaining array_load/array_merge_store ops
    // (patterns defined earlier in this file).
    mlir::RewritePatternSet patterns2(context);
    patterns2.insert<ArrayLoadConversion>(context);
    patterns2.insert<ArrayMergeStoreConversion>(context);
    target.addIllegalOp<ArrayLoadOp, ArrayMergeStoreOp>();
    if (mlir::failed(
            mlir::applyPartialConversion(func, target, std::move(patterns2)))) {
      mlir::emitError(mlir::UnknownLoc::get(context),
                      "failure in array-value-copy pass, phase 2");
      signalPassFailure();
    }
  }
};
1431 } // namespace
1433 std::unique_ptr<mlir::Pass>
1434 fir::createArrayValueCopyPass(fir::ArrayValueCopyOptions options) {
1435 return std::make_unique<ArrayValueCopyConverter>(options);