1 //===-- AffineDemotion.cpp -----------------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This transformation is a prototype that demotes affine dialect operations
10 // to FIR loop operations after optimizations.
11 // It is used after the AffinePromotion pass.
12 // It is not part of the production pipeline and would need more work in order
13 // to be used in production.
14 // More information can be found in this presentation:
15 // https://slides.com/rajanwalia/deck
17 //===----------------------------------------------------------------------===//
19 #include "flang/Optimizer/Dialect/FIRDialect.h"
20 #include "flang/Optimizer/Dialect/FIROps.h"
21 #include "flang/Optimizer/Dialect/FIRType.h"
22 #include "flang/Optimizer/Transforms/Passes.h"
23 #include "mlir/Dialect/Affine/IR/AffineOps.h"
24 #include "mlir/Dialect/Affine/Utils.h"
25 #include "mlir/Dialect/Func/IR/FuncOps.h"
26 #include "mlir/Dialect/MemRef/IR/MemRef.h"
27 #include "mlir/Dialect/SCF/IR/SCF.h"
28 #include "mlir/IR/BuiltinAttributes.h"
29 #include "mlir/IR/IntegerSet.h"
30 #include "mlir/IR/Visitors.h"
31 #include "mlir/Pass/Pass.h"
32 #include "mlir/Transforms/DialectConversion.h"
33 #include "llvm/ADT/DenseMap.h"
34 #include "llvm/Support/CommandLine.h"
35 #include "llvm/Support/Debug.h"
38 #define GEN_PASS_DEF_AFFINEDIALECTDEMOTION
39 #include "flang/Optimizer/Transforms/Passes.h.inc"
42 #define DEBUG_TYPE "flang-affine-demotion"
// Conversion pattern that rewrites an mlir::AffineLoadOp into a
// fir::CoordinateOp followed by a fir::LoadOp.
// NOTE(review): this listing omits some of the original lines (for example
// the return type and the statement guarded by `if (!maybeExpandedMap)`), so
// only the visible steps are described here.
49 class AffineLoadConversion
: public OpConversionPattern
<mlir::AffineLoadOp
> {
51 using OpConversionPattern
<mlir::AffineLoadOp
>::OpConversionPattern
;
// Rewrites `op` using the operands provided by `adaptor`.
54 matchAndRewrite(mlir::AffineLoadOp op
, OpAdaptor adaptor
,
55 ConversionPatternRewriter
&rewriter
) const override
{
// Copy the adaptor's index operands and expand the op's affine map over them.
56 SmallVector
<Value
> indices(adaptor
.getIndices());
57 auto maybeExpandedMap
=
58 expandAffineMap(rewriter
, op
.getLoc(), op
.getAffineMap(), indices
);
// Expansion may yield no result; the handling statement is not visible here
// (presumably the pattern reports failure -- TODO confirm).
59 if (!maybeExpandedMap
)
// Build a fir::CoordinateOp from the adaptor's memref and the expanded values;
// its result type is a fir reference to the load's result type.
62 auto coorOp
= rewriter
.create
<fir::CoordinateOp
>(
63 op
.getLoc(), fir::ReferenceType::get(op
.getResult().getType()),
64 adaptor
.getMemref(), *maybeExpandedMap
);
// Replace the affine load with a fir::LoadOp of the computed coordinate.
66 rewriter
.replaceOpWithNewOp
<fir::LoadOp
>(op
, coorOp
.getResult());
// Conversion pattern that rewrites an mlir::AffineStoreOp into a
// fir::CoordinateOp followed by a fir::StoreOp.
// NOTE(review): this listing omits some of the original lines (for example
// the return type, the statement guarded by `if (!maybeExpandedMap)` and the
// tail of the final call), so only the visible steps are described here.
71 class AffineStoreConversion
: public OpConversionPattern
<mlir::AffineStoreOp
> {
73 using OpConversionPattern
<mlir::AffineStoreOp
>::OpConversionPattern
;
// Rewrites `op` using the operands provided by `adaptor`.
76 matchAndRewrite(mlir::AffineStoreOp op
, OpAdaptor adaptor
,
77 ConversionPatternRewriter
&rewriter
) const override
{
// NOTE(review): the indices are read from `op` here, whereas
// AffineLoadConversion in this file reads them from `adaptor` -- confirm the
// difference is intended.
78 SmallVector
<Value
> indices(op
.getIndices());
79 auto maybeExpandedMap
=
80 expandAffineMap(rewriter
, op
.getLoc(), op
.getAffineMap(), indices
);
// Expansion may yield no result; the handling statement is not visible here
// (presumably the pattern reports failure -- TODO confirm).
81 if (!maybeExpandedMap
)
// Build a fir::CoordinateOp from the adaptor's memref and the expanded values;
// its result type is a fir reference to the type of the value being stored.
84 auto coorOp
= rewriter
.create
<fir::CoordinateOp
>(
85 op
.getLoc(), fir::ReferenceType::get(op
.getValueToStore().getType()),
86 adaptor
.getMemref(), *maybeExpandedMap
);
// Replace the affine store with a fir::StoreOp; its first argument is the
// adaptor's value, the remaining arguments are not visible in this listing.
87 rewriter
.replaceOpWithNewOp
<fir::StoreOp
>(op
, adaptor
.getValue(),
// Rewrite pattern for fir::ConvertOp operations whose result type is an
// mlir::MemRefType.
// NOTE(review): this listing omits some of the original lines (return type,
// the tail of the replaceOpWithNewOp call, returns and closing braces), so the
// exact control flow between the two rewrite paths below cannot be confirmed.
93 class ConvertConversion
: public mlir::OpRewritePattern
<fir::ConvertOp
> {
95 using OpRewritePattern::OpRewritePattern
;
97 matchAndRewrite(fir::ConvertOp op
,
98 mlir::PatternRewriter
&rewriter
) const override
{
// Only converts that produce a memref are handled in the visible code.
99 if (op
.getRes().getType().isa
<mlir::MemRefType
>()) {
100 // due to index calculation moving to affine maps we still need to
101 // add converts for sequence types this has a side effect of losing
102 // some information about arrays with known dimensions by creating:
103 // fir.convert %arg0 : (!fir.ref<!fir.array<5xi32>>) ->
104 // !fir.ref<!fir.array<?xi32>>
// Path 1: the operand is a fir reference to a fir sequence type.
105 if (auto refTy
= op
.getValue().getType().dyn_cast
<fir::ReferenceType
>())
106 if (auto arrTy
= refTy
.getEleTy().dyn_cast
<fir::SequenceType
>()) {
// Build a reference to a one-dimensional, unknown-extent sequence with the
// same element type, and replace the op with a convert to that type.
107 fir::SequenceType::Shape flatShape
= {
108 fir::SequenceType::getUnknownExtent()};
109 auto flatArrTy
= fir::SequenceType::get(flatShape
, arrTy
.getEleTy());
110 auto flatTy
= fir::ReferenceType::get(flatArrTy
);
111 rewriter
.replaceOpWithNewOp
<fir::ConvertOp
>(op
, flatTy
,
// Path 2 (presumably the fallback when path 1 does not apply -- TODO confirm):
// forward the operand to every user of the result inside a root update on the
// parent op, then erase the convert.
115 rewriter
.startRootUpdate(op
->getParentOp());
116 op
.getResult().replaceAllUsesWith(op
.getValue());
117 rewriter
.finalizeRootUpdate(op
->getParentOp());
118 rewriter
.eraseOp(op
);
// Maps a memref type to a fir::SequenceType with the same shape and the same
// element type. The shape is copied into a SmallVector<int64_t> first.
124 mlir::Type
convertMemRef(mlir::MemRefType type
) {
125 return fir::SequenceType::get(
126 SmallVector
<int64_t>(type
.getShape().begin(), type
.getShape().end()),
127 type
.getElementType());
// Rewrite pattern that replaces a memref::AllocOp with a fir::AllocaOp.
// NOTE(review): this listing omits some of the original lines (return type,
// the tail of the replaceOpWithNewOp call and the closing braces).
130 class StdAllocConversion
: public mlir::OpRewritePattern
<memref::AllocOp
> {
132 using OpRewritePattern::OpRewritePattern
;
134 matchAndRewrite(memref::AllocOp op
,
135 mlir::PatternRewriter
&rewriter
) const override
{
// The first argument is the alloc's memref type translated by convertMemRef;
// the remaining arguments are not visible in this listing.
136 rewriter
.replaceOpWithNewOp
<fir::AllocaOp
>(op
, convertMemRef(op
.getType()),
// Pass that applies the conversion patterns defined in this file to the
// operation it runs on.
// NOTE(review): this listing omits some of the original lines (the body of
// the legality lambda, the receiver of `.addLegalDialect` and whatever follows
// the emitted error), so only the visible steps are described here.
142 class AffineDialectDemotion
143 : public fir::impl::AffineDialectDemotionBase
<AffineDialectDemotion
> {
145 void runOnOperation() override
{
146 auto *context
= &getContext();
147 auto function
= getOperation();
// Debug-only dump of the operation being processed.
148 LLVM_DEBUG(llvm::dbgs() << "AffineDemotion: running on function:\n";
149 function
.print(llvm::dbgs()););
// Register the four patterns defined in this file.
151 mlir::RewritePatternSet
patterns(context
);
152 patterns
.insert
<ConvertConversion
>(context
);
153 patterns
.insert
<AffineLoadConversion
>(context
);
154 patterns
.insert
<AffineStoreConversion
>(context
);
155 patterns
.insert
<StdAllocConversion
>(context
);
// Conversion target: memref::AllocOp is illegal; the legality of
// fir::ConvertOp is decided by a callback that tests whether the result is a
// memref (the values it returns are not visible in this listing).
156 mlir::ConversionTarget
target(*context
);
157 target
.addIllegalOp
<memref::AllocOp
>();
158 target
.addDynamicallyLegalOp
<fir::ConvertOp
>([](fir::ConvertOp op
) {
159 if (op
.getRes().getType().isa
<mlir::MemRefType
>())
// The FIR, SCF, Arith and Func dialects are marked legal.
164 .addLegalDialect
<FIROpsDialect
, mlir::scf::SCFDialect
,
165 mlir::arith::ArithDialect
, mlir::func::FuncDialect
>();
// Run a partial conversion; on failure emit a diagnostic at an unknown
// location.
167 if (mlir::failed(mlir::applyPartialConversion(function
, target
,
168 std::move(patterns
)))) {
169 mlir::emitError(mlir::UnknownLoc::get(context
),
170 "error in converting affine dialect\n");
// Factory: returns a newly allocated AffineDialectDemotion pass owned by a
// std::unique_ptr<mlir::Pass>.
178 std::unique_ptr
<mlir::Pass
> fir::createAffineDemotionPass() {
179 return std::make_unique
<AffineDialectDemotion
>();