1 //===-- AssumedRankOpConversion.cpp ---------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "flang/Common/Fortran.h"
10 #include "flang/Lower/BuiltinModules.h"
11 #include "flang/Optimizer/Builder/FIRBuilder.h"
12 #include "flang/Optimizer/Builder/Runtime/Support.h"
13 #include "flang/Optimizer/Builder/Todo.h"
14 #include "flang/Optimizer/Dialect/FIRDialect.h"
15 #include "flang/Optimizer/Dialect/FIROps.h"
16 #include "flang/Optimizer/Support/TypeCode.h"
17 #include "flang/Optimizer/Support/Utils.h"
18 #include "flang/Optimizer/Transforms/Passes.h"
19 #include "flang/Runtime/support.h"
20 #include "mlir/Dialect/Func/IR/FuncOps.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Transforms/DialectConversion.h"
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26 #define GEN_PASS_DEF_ASSUMEDRANKOPCONVERSION
27 #include "flang/Optimizer/Transforms/Passes.h.inc"
35 static int getCFIAttribute(mlir::Type boxType
) {
36 if (fir::isAllocatableType(boxType
))
37 return CFI_attribute_allocatable
;
38 if (fir::isPointerType(boxType
))
39 return CFI_attribute_pointer
;
40 return CFI_attribute_other
;
43 static Fortran::runtime::LowerBoundModifier
44 getLowerBoundModifier(fir::LowerBoundModifierAttribute modifier
) {
46 case fir::LowerBoundModifierAttribute::Preserve
:
47 return Fortran::runtime::LowerBoundModifier::Preserve
;
48 case fir::LowerBoundModifierAttribute::SetToOnes
:
49 return Fortran::runtime::LowerBoundModifier::SetToOnes
;
50 case fir::LowerBoundModifierAttribute::SetToZeroes
:
51 return Fortran::runtime::LowerBoundModifier::SetToZeroes
;
53 llvm_unreachable("bad modifier code");
56 class ReboxAssumedRankConv
57 : public mlir::OpRewritePattern
<fir::ReboxAssumedRankOp
> {
59 using OpRewritePattern::OpRewritePattern
;
61 ReboxAssumedRankConv(mlir::MLIRContext
*context
,
62 mlir::SymbolTable
*symbolTable
, fir::KindMapping kindMap
)
63 : mlir::OpRewritePattern
<fir::ReboxAssumedRankOp
>(context
),
64 symbolTable
{symbolTable
}, kindMap
{kindMap
} {};
67 matchAndRewrite(fir::ReboxAssumedRankOp rebox
,
68 mlir::PatternRewriter
&rewriter
) const override
{
69 fir::FirOpBuilder builder
{rewriter
, kindMap
, symbolTable
};
70 mlir::Location loc
= rebox
.getLoc();
71 auto newBoxType
= mlir::cast
<fir::BaseBoxType
>(rebox
.getType());
72 mlir::Type newMaxRankBoxType
=
73 newBoxType
.getBoxTypeWithNewShape(Fortran::common::maxRank
);
74 // CopyAndUpdateDescriptor FIR interface requires loading
75 // !fir.ref<fir.box> input which is expensive with assumed-rank. It could
76 // be best to add an entry point that takes a non "const" from to cover
77 // this case, but it would be good to indicate to LLVM that from does not
79 if (fir::isBoxAddress(rebox
.getBox().getType()))
80 TODO(loc
, "fir.rebox_assumed_rank codegen with fir.ref<fir.box<>> input");
81 mlir::Value tempDesc
= builder
.createTemporary(loc
, newMaxRankBoxType
);
83 mlir::Type newEleType
= newBoxType
.unwrapInnerType();
84 auto oldBoxType
= mlir::cast
<fir::BaseBoxType
>(
85 fir::unwrapRefType(rebox
.getBox().getType()));
86 auto newDerivedType
= mlir::dyn_cast
<fir::RecordType
>(newEleType
);
87 if (newDerivedType
&& !fir::isPolymorphicType(newBoxType
) &&
88 (fir::isPolymorphicType(oldBoxType
) ||
89 (newEleType
!= oldBoxType
.unwrapInnerType())) &&
90 !fir::isPolymorphicType(newBoxType
)) {
91 newDtype
= builder
.create
<fir::TypeDescOp
>(
92 loc
, mlir::TypeAttr::get(newDerivedType
));
94 newDtype
= builder
.createNullConstant(loc
);
96 mlir::Value newAttribute
= builder
.createIntegerConstant(
97 loc
, builder
.getIntegerType(8), getCFIAttribute(newBoxType
));
99 static_cast<int>(getLowerBoundModifier(rebox
.getLbsModifier()));
100 mlir::Value lowerBoundModifier
= builder
.createIntegerConstant(
101 loc
, builder
.getIntegerType(32), lbsModifierCode
);
102 fir::runtime::genCopyAndUpdateDescriptor(builder
, loc
, tempDesc
,
103 rebox
.getBox(), newDtype
,
104 newAttribute
, lowerBoundModifier
);
106 mlir::Value descValue
= builder
.create
<fir::LoadOp
>(loc
, tempDesc
);
107 mlir::Value castDesc
= builder
.createConvert(loc
, newBoxType
, descValue
);
108 rewriter
.replaceOp(rebox
, castDesc
);
109 return mlir::success();
113 mlir::SymbolTable
*symbolTable
= nullptr;
114 fir::KindMapping kindMap
;
117 class IsAssumedSizeConv
: public mlir::OpRewritePattern
<fir::IsAssumedSizeOp
> {
119 using OpRewritePattern::OpRewritePattern
;
121 IsAssumedSizeConv(mlir::MLIRContext
*context
, mlir::SymbolTable
*symbolTable
,
122 fir::KindMapping kindMap
)
123 : mlir::OpRewritePattern
<fir::IsAssumedSizeOp
>(context
),
124 symbolTable
{symbolTable
}, kindMap
{kindMap
} {};
127 matchAndRewrite(fir::IsAssumedSizeOp isAssumedSizeOp
,
128 mlir::PatternRewriter
&rewriter
) const override
{
129 fir::FirOpBuilder builder
{rewriter
, kindMap
, symbolTable
};
130 mlir::Location loc
= isAssumedSizeOp
.getLoc();
132 fir::runtime::genIsAssumedSize(builder
, loc
, isAssumedSizeOp
.getVal());
133 rewriter
.replaceOp(isAssumedSizeOp
, result
);
134 return mlir::success();
138 mlir::SymbolTable
*symbolTable
= nullptr;
139 fir::KindMapping kindMap
;
/// Lower fir.rebox_assumed_rank and fir.is_assumed_size operations to Fortran runtime calls.
143 class AssumedRankOpConversion
144 : public fir::impl::AssumedRankOpConversionBase
<AssumedRankOpConversion
> {
146 void runOnOperation() override
{
147 auto *context
= &getContext();
148 mlir::ModuleOp mod
= getOperation();
149 mlir::SymbolTable
symbolTable(mod
);
150 fir::KindMapping kindMap
= fir::getKindMapping(mod
);
151 mlir::RewritePatternSet
patterns(context
);
152 patterns
.insert
<ReboxAssumedRankConv
>(context
, &symbolTable
, kindMap
);
153 patterns
.insert
<IsAssumedSizeConv
>(context
, &symbolTable
, kindMap
);
154 mlir::GreedyRewriteConfig config
;
155 config
.enableRegionSimplification
=
156 mlir::GreedySimplifyRegionLevel::Disabled
;
157 (void)applyPatternsGreedily(mod
, std::move(patterns
), config
);