LAA: improve code in getStrideFromPointer (NFC) (#124780)
[llvm-project.git] / flang / lib / Optimizer / HLFIR / Transforms / InlineHLFIRAssign.cpp
blob249976d5509b0c0e8f6c411f3996a11cc2657735
1 //===- InlineHLFIRAssign.cpp - Inline hlfir.assign ops --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // Transform hlfir.assign array operations into loop nests performing element
9 // per element assignments. The inlining is done for trivial data types always,
10 // though, we may add performance/code-size heuristics in future.
11 //===----------------------------------------------------------------------===//
13 #include "flang/Optimizer/Analysis/AliasAnalysis.h"
14 #include "flang/Optimizer/Builder/FIRBuilder.h"
15 #include "flang/Optimizer/Builder/HLFIRTools.h"
16 #include "flang/Optimizer/HLFIR/HLFIROps.h"
17 #include "flang/Optimizer/HLFIR/Passes.h"
18 #include "flang/Optimizer/OpenMP/Passes.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Support/LLVM.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 namespace hlfir {
25 #define GEN_PASS_DEF_INLINEHLFIRASSIGN
26 #include "flang/Optimizer/HLFIR/Passes.h.inc"
27 } // namespace hlfir
29 #define DEBUG_TYPE "inline-hlfir-assign"
31 namespace {
32 /// Expand hlfir.assign of array RHS to array LHS into a loop nest
33 /// of element-by-element assignments:
34 /// hlfir.assign %4 to %5 : !fir.ref<!fir.array<3x3xf32>>,
35 /// !fir.ref<!fir.array<3x3xf32>>
36 /// into:
37 /// fir.do_loop %arg1 = %c1 to %c3 step %c1 unordered {
38 /// fir.do_loop %arg2 = %c1 to %c3 step %c1 unordered {
39 /// %6 = hlfir.designate %4 (%arg2, %arg1) :
40 /// (!fir.ref<!fir.array<3x3xf32>>, index, index) -> !fir.ref<f32>
41 /// %7 = fir.load %6 : !fir.ref<f32>
42 /// %8 = hlfir.designate %5 (%arg2, %arg1) :
43 /// (!fir.ref<!fir.array<3x3xf32>>, index, index) -> !fir.ref<f32>
44 /// hlfir.assign %7 to %8 : f32, !fir.ref<f32>
45 /// }
46 /// }
47 ///
48 /// The transformation is correct only when LHS and RHS do not alias.
49 /// When RHS is an array expression, then there is no aliasing.
50 /// This transformation does not support runtime checking for
51 /// non-conforming LHS/RHS arrays' shapes currently.
52 class InlineHLFIRAssignConversion
53 : public mlir::OpRewritePattern<hlfir::AssignOp> {
54 public:
55 using mlir::OpRewritePattern<hlfir::AssignOp>::OpRewritePattern;
57 llvm::LogicalResult
58 matchAndRewrite(hlfir::AssignOp assign,
59 mlir::PatternRewriter &rewriter) const override {
60 if (assign.isAllocatableAssignment())
61 return rewriter.notifyMatchFailure(assign,
62 "AssignOp may imply allocation");
64 hlfir::Entity rhs{assign.getRhs()};
66 if (!rhs.isArray())
67 return rewriter.notifyMatchFailure(assign,
68 "AssignOp's RHS is not an array");
70 mlir::Type rhsEleTy = rhs.getFortranElementType();
71 if (!fir::isa_trivial(rhsEleTy))
72 return rewriter.notifyMatchFailure(
73 assign, "AssignOp's RHS data type is not trivial");
75 hlfir::Entity lhs{assign.getLhs()};
76 if (!lhs.isArray())
77 return rewriter.notifyMatchFailure(assign,
78 "AssignOp's LHS is not an array");
80 mlir::Type lhsEleTy = lhs.getFortranElementType();
81 if (!fir::isa_trivial(lhsEleTy))
82 return rewriter.notifyMatchFailure(
83 assign, "AssignOp's LHS data type is not trivial");
85 if (lhsEleTy != rhsEleTy)
86 return rewriter.notifyMatchFailure(assign,
87 "RHS/LHS element types mismatch");
89 if (!mlir::isa<hlfir::ExprType>(rhs.getType())) {
90 // If RHS is not an hlfir.expr, then we should prove that
91 // LHS and RHS do not alias.
92 // TODO: if they may alias, we can insert hlfir.as_expr for RHS,
93 // and proceed with the inlining.
94 fir::AliasAnalysis aliasAnalysis;
95 mlir::AliasResult aliasRes = aliasAnalysis.alias(lhs, rhs);
96 // TODO: use areIdenticalOrDisjointSlices() from
97 // OptimizedBufferization.cpp to check if we can still do the expansion.
98 if (!aliasRes.isNo()) {
99 LLVM_DEBUG(llvm::dbgs() << "InlineHLFIRAssign:\n"
100 << "\tLHS: " << lhs << "\n"
101 << "\tRHS: " << rhs << "\n"
102 << "\tALIAS: " << aliasRes << "\n");
103 return rewriter.notifyMatchFailure(assign, "RHS/LHS may alias");
107 mlir::Location loc = assign->getLoc();
108 fir::FirOpBuilder builder(rewriter, assign.getOperation());
109 builder.setInsertionPoint(assign);
110 rhs = hlfir::derefPointersAndAllocatables(loc, builder, rhs);
111 lhs = hlfir::derefPointersAndAllocatables(loc, builder, lhs);
112 mlir::Value shape = hlfir::genShape(loc, builder, lhs);
113 llvm::SmallVector<mlir::Value> extents =
114 hlfir::getIndexExtents(loc, builder, shape);
115 hlfir::LoopNest loopNest =
116 hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true,
117 flangomp::shouldUseWorkshareLowering(assign));
118 builder.setInsertionPointToStart(loopNest.body);
119 auto rhsArrayElement =
120 hlfir::getElementAt(loc, builder, rhs, loopNest.oneBasedIndices);
121 rhsArrayElement = hlfir::loadTrivialScalar(loc, builder, rhsArrayElement);
122 auto lhsArrayElement =
123 hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices);
124 builder.create<hlfir::AssignOp>(loc, rhsArrayElement, lhsArrayElement);
125 rewriter.eraseOp(assign);
126 return mlir::success();
130 class InlineHLFIRAssignPass
131 : public hlfir::impl::InlineHLFIRAssignBase<InlineHLFIRAssignPass> {
132 public:
133 void runOnOperation() override {
134 mlir::MLIRContext *context = &getContext();
136 mlir::GreedyRewriteConfig config;
137 // Prevent the pattern driver from merging blocks.
138 config.enableRegionSimplification =
139 mlir::GreedySimplifyRegionLevel::Disabled;
141 mlir::RewritePatternSet patterns(context);
142 patterns.insert<InlineHLFIRAssignConversion>(context);
144 if (mlir::failed(mlir::applyPatternsGreedily(
145 getOperation(), std::move(patterns), config))) {
146 mlir::emitError(getOperation()->getLoc(),
147 "failure in hlfir.assign inlining");
148 signalPassFailure();
152 } // namespace