1 //===-- AffineDemotion.cpp -----------------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This transformation is a prototype that demote affine dialects operations
10 // after optimizations to FIR loops operations.
11 // It is used after the AffinePromotion pass.
12 // It is not part of the production pipeline and would need more work in order
13 // to be used in production.
14 // More information can be found in this presentation:
15 // https://slides.com/rajanwalia/deck
17 //===----------------------------------------------------------------------===//
19 #include "flang/Optimizer/Dialect/FIRDialect.h"
20 #include "flang/Optimizer/Dialect/FIROps.h"
21 #include "flang/Optimizer/Dialect/FIRType.h"
22 #include "flang/Optimizer/Transforms/Passes.h"
23 #include "mlir/Dialect/Affine/IR/AffineOps.h"
24 #include "mlir/Dialect/Affine/Utils.h"
25 #include "mlir/Dialect/Func/IR/FuncOps.h"
26 #include "mlir/Dialect/MemRef/IR/MemRef.h"
27 #include "mlir/Dialect/SCF/IR/SCF.h"
28 #include "mlir/IR/BuiltinAttributes.h"
29 #include "mlir/IR/IntegerSet.h"
30 #include "mlir/IR/Visitors.h"
31 #include "mlir/Pass/Pass.h"
32 #include "mlir/Transforms/DialectConversion.h"
33 #include "llvm/ADT/DenseMap.h"
34 #include "llvm/Support/CommandLine.h"
35 #include "llvm/Support/Debug.h"
38 #define GEN_PASS_DEF_AFFINEDIALECTDEMOTION
39 #include "flang/Optimizer/Transforms/Passes.h.inc"
42 #define DEBUG_TYPE "flang-affine-demotion"
49 class AffineLoadConversion
50 : public OpConversionPattern
<mlir::affine::AffineLoadOp
> {
52 using OpConversionPattern
<mlir::affine::AffineLoadOp
>::OpConversionPattern
;
55 matchAndRewrite(mlir::affine::AffineLoadOp op
, OpAdaptor adaptor
,
56 ConversionPatternRewriter
&rewriter
) const override
{
57 SmallVector
<Value
> indices(adaptor
.getIndices());
58 auto maybeExpandedMap
= affine::expandAffineMap(rewriter
, op
.getLoc(),
59 op
.getAffineMap(), indices
);
60 if (!maybeExpandedMap
)
63 auto coorOp
= rewriter
.create
<fir::CoordinateOp
>(
64 op
.getLoc(), fir::ReferenceType::get(op
.getResult().getType()),
65 adaptor
.getMemref(), *maybeExpandedMap
);
67 rewriter
.replaceOpWithNewOp
<fir::LoadOp
>(op
, coorOp
.getResult());
72 class AffineStoreConversion
73 : public OpConversionPattern
<mlir::affine::AffineStoreOp
> {
75 using OpConversionPattern
<mlir::affine::AffineStoreOp
>::OpConversionPattern
;
78 matchAndRewrite(mlir::affine::AffineStoreOp op
, OpAdaptor adaptor
,
79 ConversionPatternRewriter
&rewriter
) const override
{
80 SmallVector
<Value
> indices(op
.getIndices());
81 auto maybeExpandedMap
= affine::expandAffineMap(rewriter
, op
.getLoc(),
82 op
.getAffineMap(), indices
);
83 if (!maybeExpandedMap
)
86 auto coorOp
= rewriter
.create
<fir::CoordinateOp
>(
87 op
.getLoc(), fir::ReferenceType::get(op
.getValueToStore().getType()),
88 adaptor
.getMemref(), *maybeExpandedMap
);
89 rewriter
.replaceOpWithNewOp
<fir::StoreOp
>(op
, adaptor
.getValue(),
95 class ConvertConversion
: public mlir::OpRewritePattern
<fir::ConvertOp
> {
97 using OpRewritePattern::OpRewritePattern
;
99 matchAndRewrite(fir::ConvertOp op
,
100 mlir::PatternRewriter
&rewriter
) const override
{
101 if (mlir::isa
<mlir::MemRefType
>(op
.getRes().getType())) {
102 // due to index calculation moving to affine maps we still need to
103 // add converts for sequence types this has a side effect of losing
104 // some information about arrays with known dimensions by creating:
105 // fir.convert %arg0 : (!fir.ref<!fir.array<5xi32>>) ->
106 // !fir.ref<!fir.array<?xi32>>
108 mlir::dyn_cast
<fir::ReferenceType
>(op
.getValue().getType()))
109 if (auto arrTy
= mlir::dyn_cast
<fir::SequenceType
>(refTy
.getEleTy())) {
110 fir::SequenceType::Shape flatShape
= {
111 fir::SequenceType::getUnknownExtent()};
112 auto flatArrTy
= fir::SequenceType::get(flatShape
, arrTy
.getEleTy());
113 auto flatTy
= fir::ReferenceType::get(flatArrTy
);
114 rewriter
.replaceOpWithNewOp
<fir::ConvertOp
>(op
, flatTy
,
118 rewriter
.startOpModification(op
->getParentOp());
119 op
.getResult().replaceAllUsesWith(op
.getValue());
120 rewriter
.finalizeOpModification(op
->getParentOp());
121 rewriter
.eraseOp(op
);
127 mlir::Type
convertMemRef(mlir::MemRefType type
) {
128 return fir::SequenceType::get(SmallVector
<int64_t>(type
.getShape()),
129 type
.getElementType());
132 class StdAllocConversion
: public mlir::OpRewritePattern
<memref::AllocOp
> {
134 using OpRewritePattern::OpRewritePattern
;
136 matchAndRewrite(memref::AllocOp op
,
137 mlir::PatternRewriter
&rewriter
) const override
{
138 rewriter
.replaceOpWithNewOp
<fir::AllocaOp
>(op
, convertMemRef(op
.getType()),
144 class AffineDialectDemotion
145 : public fir::impl::AffineDialectDemotionBase
<AffineDialectDemotion
> {
147 void runOnOperation() override
{
148 auto *context
= &getContext();
149 auto function
= getOperation();
150 LLVM_DEBUG(llvm::dbgs() << "AffineDemotion: running on function:\n";
151 function
.print(llvm::dbgs()););
153 mlir::RewritePatternSet
patterns(context
);
154 patterns
.insert
<ConvertConversion
>(context
);
155 patterns
.insert
<AffineLoadConversion
>(context
);
156 patterns
.insert
<AffineStoreConversion
>(context
);
157 patterns
.insert
<StdAllocConversion
>(context
);
158 mlir::ConversionTarget
target(*context
);
159 target
.addIllegalOp
<memref::AllocOp
>();
160 target
.addDynamicallyLegalOp
<fir::ConvertOp
>([](fir::ConvertOp op
) {
161 if (mlir::isa
<mlir::MemRefType
>(op
.getRes().getType()))
166 .addLegalDialect
<FIROpsDialect
, mlir::scf::SCFDialect
,
167 mlir::arith::ArithDialect
, mlir::func::FuncDialect
>();
169 if (mlir::failed(mlir::applyPartialConversion(function
, target
,
170 std::move(patterns
)))) {
171 mlir::emitError(mlir::UnknownLoc::get(context
),
172 "error in converting affine dialect\n");
180 std::unique_ptr
<mlir::Pass
> fir::createAffineDemotionPass() {
181 return std::make_unique
<AffineDialectDemotion
>();