//===-- PreCGRewrite.cpp --------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
//
//===----------------------------------------------------------------------===//
13 #include "flang/Optimizer/CodeGen/CodeGen.h"
15 #include "flang/Optimizer/Builder/Todo.h" // remove when TODO's are done
16 #include "flang/Optimizer/CodeGen/CGOps.h"
17 #include "flang/Optimizer/Dialect/FIRDialect.h"
18 #include "flang/Optimizer/Dialect/FIROps.h"
19 #include "flang/Optimizer/Dialect/FIRType.h"
20 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
21 #include "mlir/IR/Iterators.h"
22 #include "mlir/Transforms/DialectConversion.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/Support/Debug.h"
27 #define GEN_PASS_DEF_CODEGENREWRITE
28 #include "flang/Optimizer/CodeGen/CGPasses.h.inc"
//===----------------------------------------------------------------------===//
// Codegen rewrite: rewriting of subgraphs of ops
//===----------------------------------------------------------------------===//
35 #define DEBUG_TYPE "flang-codegen-rewrite"
37 static void populateShape(llvm::SmallVectorImpl
<mlir::Value
> &vec
,
39 vec
.append(shape
.getExtents().begin(), shape
.getExtents().end());
42 // Operands of fir.shape_shift split into two vectors.
43 static void populateShapeAndShift(llvm::SmallVectorImpl
<mlir::Value
> &shapeVec
,
44 llvm::SmallVectorImpl
<mlir::Value
> &shiftVec
,
45 fir::ShapeShiftOp shift
) {
46 for (auto i
= shift
.getPairs().begin(), endIter
= shift
.getPairs().end();
48 shiftVec
.push_back(*i
++);
49 shapeVec
.push_back(*i
++);
53 static void populateShift(llvm::SmallVectorImpl
<mlir::Value
> &vec
,
55 vec
.append(shift
.getOrigins().begin(), shift
.getOrigins().end());
60 /// Convert fir.embox to the extended form where necessary.
62 /// The embox operation can take arguments that specify multidimensional array
63 /// properties at runtime. These properties may be shared between distinct
64 /// objects that have the same properties. Before we lower these small DAGs to
65 /// LLVM-IR, we gather all the information into a single extended operation. For
68 /// %1 = fir.shape_shift %4, %5 : (index, index) -> !fir.shapeshift<1>
69 /// %2 = fir.slice %6, %7, %8 : (index, index, index) -> !fir.slice<1>
70 /// %3 = fir.embox %0 (%1) [%2] : (!fir.ref<!fir.array<?xi32>>,
71 /// !fir.shapeshift<1>, !fir.slice<1>) -> !fir.box<!fir.array<?xi32>>
73 /// can be rewritten as
75 /// %1 = fircg.ext_embox %0(%5) origin %4[%6, %7, %8] :
76 /// (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index) ->
77 /// !fir.box<!fir.array<?xi32>>
79 class EmboxConversion
: public mlir::OpRewritePattern
<fir::EmboxOp
> {
81 using OpRewritePattern::OpRewritePattern
;
84 matchAndRewrite(fir::EmboxOp embox
,
85 mlir::PatternRewriter
&rewriter
) const override
{
86 // If the embox does not include a shape, then do not convert it
87 if (auto shapeVal
= embox
.getShape())
88 return rewriteDynamicShape(embox
, rewriter
, shapeVal
);
89 if (mlir::isa
<fir::ClassType
>(embox
.getType()))
90 TODO(embox
.getLoc(), "embox conversion for fir.class type");
91 if (auto boxTy
= mlir::dyn_cast
<fir::BoxType
>(embox
.getType()))
92 if (auto seqTy
= mlir::dyn_cast
<fir::SequenceType
>(boxTy
.getEleTy()))
93 if (!seqTy
.hasDynamicExtents())
94 return rewriteStaticShape(embox
, rewriter
, seqTy
);
95 return mlir::failure();
98 llvm::LogicalResult
rewriteStaticShape(fir::EmboxOp embox
,
99 mlir::PatternRewriter
&rewriter
,
100 fir::SequenceType seqTy
) const {
101 auto loc
= embox
.getLoc();
102 llvm::SmallVector
<mlir::Value
> shapeOpers
;
103 auto idxTy
= rewriter
.getIndexType();
104 for (auto ext
: seqTy
.getShape()) {
105 auto iAttr
= rewriter
.getIndexAttr(ext
);
106 auto extVal
= rewriter
.create
<mlir::arith::ConstantOp
>(loc
, idxTy
, iAttr
);
107 shapeOpers
.push_back(extVal
);
109 auto xbox
= rewriter
.create
<fir::cg::XEmboxOp
>(
110 loc
, embox
.getType(), embox
.getMemref(), shapeOpers
, std::nullopt
,
111 std::nullopt
, std::nullopt
, std::nullopt
, embox
.getTypeparams(),
112 embox
.getSourceBox(), embox
.getAllocatorIdxAttr());
113 LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox
<< " to " << xbox
<< '\n');
114 rewriter
.replaceOp(embox
, xbox
.getOperation()->getResults());
115 return mlir::success();
118 llvm::LogicalResult
rewriteDynamicShape(fir::EmboxOp embox
,
119 mlir::PatternRewriter
&rewriter
,
120 mlir::Value shapeVal
) const {
121 auto loc
= embox
.getLoc();
122 llvm::SmallVector
<mlir::Value
> shapeOpers
;
123 llvm::SmallVector
<mlir::Value
> shiftOpers
;
124 if (auto shapeOp
= mlir::dyn_cast
<fir::ShapeOp
>(shapeVal
.getDefiningOp())) {
125 populateShape(shapeOpers
, shapeOp
);
128 mlir::dyn_cast
<fir::ShapeShiftOp
>(shapeVal
.getDefiningOp());
129 assert(shiftOp
&& "shape is neither fir.shape nor fir.shape_shift");
130 populateShapeAndShift(shapeOpers
, shiftOpers
, shiftOp
);
132 llvm::SmallVector
<mlir::Value
> sliceOpers
;
133 llvm::SmallVector
<mlir::Value
> subcompOpers
;
134 llvm::SmallVector
<mlir::Value
> substrOpers
;
135 if (auto s
= embox
.getSlice())
137 mlir::dyn_cast_or_null
<fir::SliceOp
>(s
.getDefiningOp())) {
138 sliceOpers
.assign(sliceOp
.getTriples().begin(),
139 sliceOp
.getTriples().end());
140 subcompOpers
.assign(sliceOp
.getFields().begin(),
141 sliceOp
.getFields().end());
142 substrOpers
.assign(sliceOp
.getSubstr().begin(),
143 sliceOp
.getSubstr().end());
145 auto xbox
= rewriter
.create
<fir::cg::XEmboxOp
>(
146 loc
, embox
.getType(), embox
.getMemref(), shapeOpers
, shiftOpers
,
147 sliceOpers
, subcompOpers
, substrOpers
, embox
.getTypeparams(),
148 embox
.getSourceBox(), embox
.getAllocatorIdxAttr());
149 LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox
<< " to " << xbox
<< '\n');
150 rewriter
.replaceOp(embox
, xbox
.getOperation()->getResults());
151 return mlir::success();
155 /// Convert fir.rebox to the extended form where necessary.
159 /// %5 = fir.rebox %3(%1) : (!fir.box<!fir.array<?xi32>>, !fir.shapeshift<1>) ->
160 /// !fir.box<!fir.array<?xi32>>
164 /// %5 = fircg.ext_rebox %3(%13) origin %12 : (!fir.box<!fir.array<?xi32>>,
165 /// index, index) -> !fir.box<!fir.array<?xi32>>
167 class ReboxConversion
: public mlir::OpRewritePattern
<fir::ReboxOp
> {
169 using OpRewritePattern::OpRewritePattern
;
172 matchAndRewrite(fir::ReboxOp rebox
,
173 mlir::PatternRewriter
&rewriter
) const override
{
174 auto loc
= rebox
.getLoc();
175 llvm::SmallVector
<mlir::Value
> shapeOpers
;
176 llvm::SmallVector
<mlir::Value
> shiftOpers
;
177 if (auto shapeVal
= rebox
.getShape()) {
178 if (auto shapeOp
= mlir::dyn_cast
<fir::ShapeOp
>(shapeVal
.getDefiningOp()))
179 populateShape(shapeOpers
, shapeOp
);
180 else if (auto shiftOp
=
181 mlir::dyn_cast
<fir::ShapeShiftOp
>(shapeVal
.getDefiningOp()))
182 populateShapeAndShift(shapeOpers
, shiftOpers
, shiftOp
);
183 else if (auto shiftOp
=
184 mlir::dyn_cast
<fir::ShiftOp
>(shapeVal
.getDefiningOp()))
185 populateShift(shiftOpers
, shiftOp
);
187 return mlir::failure();
189 llvm::SmallVector
<mlir::Value
> sliceOpers
;
190 llvm::SmallVector
<mlir::Value
> subcompOpers
;
191 llvm::SmallVector
<mlir::Value
> substrOpers
;
192 if (auto s
= rebox
.getSlice())
194 mlir::dyn_cast_or_null
<fir::SliceOp
>(s
.getDefiningOp())) {
195 sliceOpers
.append(sliceOp
.getTriples().begin(),
196 sliceOp
.getTriples().end());
197 subcompOpers
.append(sliceOp
.getFields().begin(),
198 sliceOp
.getFields().end());
199 substrOpers
.append(sliceOp
.getSubstr().begin(),
200 sliceOp
.getSubstr().end());
203 auto xRebox
= rewriter
.create
<fir::cg::XReboxOp
>(
204 loc
, rebox
.getType(), rebox
.getBox(), shapeOpers
, shiftOpers
,
205 sliceOpers
, subcompOpers
, substrOpers
);
206 LLVM_DEBUG(llvm::dbgs()
207 << "rewriting " << rebox
<< " to " << xRebox
<< '\n');
208 rewriter
.replaceOp(rebox
, xRebox
.getOperation()->getResults());
209 return mlir::success();
213 /// Convert all fir.array_coor to the extended form.
217 /// %4 = fir.array_coor %addr (%1) [%2] %0 : (!fir.ref<!fir.array<?xi32>>,
218 /// !fir.shapeshift<1>, !fir.slice<1>, index) -> !fir.ref<i32>
222 /// %40 = fircg.ext_array_coor %addr(%9) origin %8[%4, %5, %6<%39> :
223 /// (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index, index) ->
226 class ArrayCoorConversion
: public mlir::OpRewritePattern
<fir::ArrayCoorOp
> {
228 using OpRewritePattern::OpRewritePattern
;
231 matchAndRewrite(fir::ArrayCoorOp arrCoor
,
232 mlir::PatternRewriter
&rewriter
) const override
{
233 auto loc
= arrCoor
.getLoc();
234 llvm::SmallVector
<mlir::Value
> shapeOpers
;
235 llvm::SmallVector
<mlir::Value
> shiftOpers
;
236 if (auto shapeVal
= arrCoor
.getShape()) {
237 if (auto shapeOp
= mlir::dyn_cast
<fir::ShapeOp
>(shapeVal
.getDefiningOp()))
238 populateShape(shapeOpers
, shapeOp
);
239 else if (auto shiftOp
=
240 mlir::dyn_cast
<fir::ShapeShiftOp
>(shapeVal
.getDefiningOp()))
241 populateShapeAndShift(shapeOpers
, shiftOpers
, shiftOp
);
242 else if (auto shiftOp
=
243 mlir::dyn_cast
<fir::ShiftOp
>(shapeVal
.getDefiningOp()))
244 populateShift(shiftOpers
, shiftOp
);
246 return mlir::failure();
248 llvm::SmallVector
<mlir::Value
> sliceOpers
;
249 llvm::SmallVector
<mlir::Value
> subcompOpers
;
250 if (auto s
= arrCoor
.getSlice())
252 mlir::dyn_cast_or_null
<fir::SliceOp
>(s
.getDefiningOp())) {
253 sliceOpers
.append(sliceOp
.getTriples().begin(),
254 sliceOp
.getTriples().end());
255 subcompOpers
.append(sliceOp
.getFields().begin(),
256 sliceOp
.getFields().end());
257 assert(sliceOp
.getSubstr().empty() &&
258 "Don't allow substring operations on array_coor. This "
259 "restriction may be lifted in the future.");
261 auto xArrCoor
= rewriter
.create
<fir::cg::XArrayCoorOp
>(
262 loc
, arrCoor
.getType(), arrCoor
.getMemref(), shapeOpers
, shiftOpers
,
263 sliceOpers
, subcompOpers
, arrCoor
.getIndices(),
264 arrCoor
.getTypeparams());
265 LLVM_DEBUG(llvm::dbgs()
266 << "rewriting " << arrCoor
<< " to " << xArrCoor
<< '\n');
267 rewriter
.replaceOp(arrCoor
, xArrCoor
.getOperation()->getResults());
268 return mlir::success();
272 class DeclareOpConversion
: public mlir::OpRewritePattern
<fir::DeclareOp
> {
273 bool preserveDeclare
;
276 using OpRewritePattern::OpRewritePattern
;
277 DeclareOpConversion(mlir::MLIRContext
*ctx
, bool preserveDecl
)
278 : OpRewritePattern(ctx
), preserveDeclare(preserveDecl
) {}
281 matchAndRewrite(fir::DeclareOp declareOp
,
282 mlir::PatternRewriter
&rewriter
) const override
{
283 if (!preserveDeclare
) {
284 rewriter
.replaceOp(declareOp
, declareOp
.getMemref());
285 return mlir::success();
287 auto loc
= declareOp
.getLoc();
288 llvm::SmallVector
<mlir::Value
> shapeOpers
;
289 llvm::SmallVector
<mlir::Value
> shiftOpers
;
290 if (auto shapeVal
= declareOp
.getShape()) {
291 if (auto shapeOp
= mlir::dyn_cast
<fir::ShapeOp
>(shapeVal
.getDefiningOp()))
292 populateShape(shapeOpers
, shapeOp
);
293 else if (auto shiftOp
=
294 mlir::dyn_cast
<fir::ShapeShiftOp
>(shapeVal
.getDefiningOp()))
295 populateShapeAndShift(shapeOpers
, shiftOpers
, shiftOp
);
296 else if (auto shiftOp
=
297 mlir::dyn_cast
<fir::ShiftOp
>(shapeVal
.getDefiningOp()))
298 populateShift(shiftOpers
, shiftOp
);
300 return mlir::failure();
302 // FIXME: Add FortranAttrs and CudaAttrs
303 auto xDeclOp
= rewriter
.create
<fir::cg::XDeclareOp
>(
304 loc
, declareOp
.getType(), declareOp
.getMemref(), shapeOpers
, shiftOpers
,
305 declareOp
.getTypeparams(), declareOp
.getDummyScope(),
306 declareOp
.getUniqName());
307 LLVM_DEBUG(llvm::dbgs()
308 << "rewriting " << declareOp
<< " to " << xDeclOp
<< '\n');
309 rewriter
.replaceOp(declareOp
, xDeclOp
.getOperation()->getResults());
310 return mlir::success();
314 class DummyScopeOpConversion
315 : public mlir::OpRewritePattern
<fir::DummyScopeOp
> {
317 using OpRewritePattern::OpRewritePattern
;
320 matchAndRewrite(fir::DummyScopeOp dummyScopeOp
,
321 mlir::PatternRewriter
&rewriter
) const override
{
322 rewriter
.replaceOpWithNewOp
<fir::UndefOp
>(dummyScopeOp
,
323 dummyScopeOp
.getType());
324 return mlir::success();
328 /// Simple DCE to erase fir.shape/shift/slice/unused shape operands after this
329 /// pass (fir.shape and like have no codegen).
330 /// mlir::RegionDCE is expensive and requires running
331 /// mlir::eraseUnreachableBlocks. It does things that are not needed here, like
332 /// removing unused block arguments. fir.shape/shift/slice cannot be block
334 /// This helper does a naive backward walk of the IR. It is not even guaranteed
335 /// to walk blocks according to backward dominance, but that is good enough for
336 /// what is done here, fir.shape/shift/slice have no usages anymore. The
337 /// backward walk allows getting rid of most of the unused operands, it is not a
338 /// problem to leave some in the weird cases.
339 static void simpleDCE(mlir::RewriterBase
&rewriter
, mlir::Operation
*op
) {
340 op
->walk
<mlir::WalkOrder::PostOrder
, mlir::ReverseIterator
>(
341 [&](mlir::Operation
*subOp
) {
342 if (mlir::isOpTriviallyDead(subOp
))
343 rewriter
.eraseOp(subOp
);
347 class CodeGenRewrite
: public fir::impl::CodeGenRewriteBase
<CodeGenRewrite
> {
349 using CodeGenRewriteBase
<CodeGenRewrite
>::CodeGenRewriteBase
;
351 void runOnOperation() override final
{
352 mlir::ModuleOp mod
= getOperation();
354 auto &context
= getContext();
355 mlir::ConversionTarget
target(context
);
356 target
.addLegalDialect
<mlir::arith::ArithDialect
, fir::FIROpsDialect
,
357 fir::FIRCodeGenDialect
, mlir::func::FuncDialect
>();
358 target
.addIllegalOp
<fir::ArrayCoorOp
>();
359 target
.addIllegalOp
<fir::ReboxOp
>();
360 target
.addIllegalOp
<fir::DeclareOp
>();
361 target
.addIllegalOp
<fir::DummyScopeOp
>();
362 target
.addDynamicallyLegalOp
<fir::EmboxOp
>([](fir::EmboxOp embox
) {
363 return !(embox
.getShape() ||
364 mlir::isa
<fir::SequenceType
>(
365 mlir::cast
<fir::BaseBoxType
>(embox
.getType()).getEleTy()));
367 mlir::RewritePatternSet
patterns(&context
);
368 fir::populatePreCGRewritePatterns(patterns
, preserveDeclare
);
370 mlir::applyPartialConversion(mod
, target
, std::move(patterns
)))) {
371 mlir::emitError(mlir::UnknownLoc::get(&context
),
372 "error in running the pre-codegen conversions");
376 // Erase any residual (fir.shape, fir.slice...).
377 mlir::IRRewriter
rewriter(&context
);
378 simpleDCE(rewriter
, mod
.getOperation());
384 void fir::populatePreCGRewritePatterns(mlir::RewritePatternSet
&patterns
,
385 bool preserveDeclare
) {
386 patterns
.insert
<EmboxConversion
, ArrayCoorConversion
, ReboxConversion
,
387 DummyScopeOpConversion
>(patterns
.getContext());
388 patterns
.add
<DeclareOpConversion
>(patterns
.getContext(), preserveDeclare
);