//===- BufferizationToMemRef.cpp - Bufferization to MemRef conversion -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert Bufferization dialect to MemRef
// dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

#include <cstdint>
#include <memory>
#include <utility>

// The generated pass base class is referenced below as
// `impl::ConvertBufferizationToMemRefBase`, so the generated definitions are
// pulled in inside the `mlir` namespace.
namespace mlir {
#define GEN_PASS_DEF_CONVERTBUFFERIZATIONTOMEMREF
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

// The definitions in this file use MLIR names unqualified.
using namespace mlir;

34 /// The CloneOpConversion transforms all bufferization clone operations into
35 /// memref alloc and memref copy operations. In the dynamic-shape case, it also
36 /// emits additional dim and constant operations to determine the shape. This
37 /// conversion does not resolve memory leaks if it is used alone.
38 struct CloneOpConversion
: public OpConversionPattern
<bufferization::CloneOp
> {
39 using OpConversionPattern
<bufferization::CloneOp
>::OpConversionPattern
;
42 matchAndRewrite(bufferization::CloneOp op
, OpAdaptor adaptor
,
43 ConversionPatternRewriter
&rewriter
) const override
{
44 Location loc
= op
->getLoc();
46 Type type
= op
.getType();
49 if (auto unrankedType
= dyn_cast
<UnrankedMemRefType
>(type
)) {
51 Value zero
= rewriter
.create
<arith::ConstantIndexOp
>(loc
, 0);
52 Value one
= rewriter
.create
<arith::ConstantIndexOp
>(loc
, 1);
54 // Dynamically evaluate the size and shape of the unranked memref
55 Value rank
= rewriter
.create
<memref::RankOp
>(loc
, op
.getInput());
56 MemRefType allocType
=
57 MemRefType::get({ShapedType::kDynamic
}, rewriter
.getIndexType());
58 Value shape
= rewriter
.create
<memref::AllocaOp
>(loc
, allocType
, rank
);
60 // Create a loop to query dimension sizes, store them as a shape, and
61 // compute the total size of the memref
62 auto loopBody
= [&](OpBuilder
&builder
, Location loc
, Value i
,
64 auto acc
= args
.front();
65 auto dim
= rewriter
.create
<memref::DimOp
>(loc
, op
.getInput(), i
);
67 rewriter
.create
<memref::StoreOp
>(loc
, dim
, shape
, i
);
68 acc
= rewriter
.create
<arith::MulIOp
>(loc
, acc
, dim
);
70 rewriter
.create
<scf::YieldOp
>(loc
, acc
);
73 .create
<scf::ForOp
>(loc
, zero
, rank
, one
, ValueRange(one
),
77 MemRefType memrefType
= MemRefType::get({ShapedType::kDynamic
},
78 unrankedType
.getElementType());
80 // Allocate new memref with 1D dynamic shape, then reshape into the
81 // shape of the original unranked memref
82 alloc
= rewriter
.create
<memref::AllocOp
>(loc
, memrefType
, size
);
84 rewriter
.create
<memref::ReshapeOp
>(loc
, unrankedType
, alloc
, shape
);
86 MemRefType memrefType
= cast
<MemRefType
>(type
);
87 MemRefLayoutAttrInterface layout
;
89 MemRefType::get(memrefType
.getShape(), memrefType
.getElementType(),
90 layout
, memrefType
.getMemorySpace());
91 // Since this implementation always allocates, certain result types of
92 // the clone op cannot be lowered.
93 if (!memref::CastOp::areCastCompatible({allocType
}, {memrefType
}))
96 // Transform a clone operation into alloc + copy operation and pay
97 // attention to the shape dimensions.
98 SmallVector
<Value
, 4> dynamicOperands
;
99 for (int i
= 0; i
< memrefType
.getRank(); ++i
) {
100 if (!memrefType
.isDynamicDim(i
))
102 Value dim
= rewriter
.createOrFold
<memref::DimOp
>(loc
, op
.getInput(), i
);
103 dynamicOperands
.push_back(dim
);
106 // Allocate a memref with identity layout.
107 alloc
= rewriter
.create
<memref::AllocOp
>(loc
, allocType
, dynamicOperands
);
108 // Cast the allocation to the specified type if needed.
109 if (memrefType
!= allocType
)
111 rewriter
.create
<memref::CastOp
>(op
->getLoc(), memrefType
, alloc
);
114 rewriter
.replaceOp(op
, alloc
);
115 rewriter
.create
<memref::CopyOp
>(loc
, op
.getInput(), alloc
);
123 struct BufferizationToMemRefPass
124 : public impl::ConvertBufferizationToMemRefBase
<BufferizationToMemRefPass
> {
125 BufferizationToMemRefPass() = default;
127 void runOnOperation() override
{
128 if (!isa
<ModuleOp
, FunctionOpInterface
>(getOperation())) {
129 emitError(getOperation()->getLoc(),
130 "root operation must be a builtin.module or a function");
135 bufferization::DeallocHelperMap deallocHelperFuncMap
;
136 if (auto module
= dyn_cast
<ModuleOp
>(getOperation())) {
137 OpBuilder builder
= OpBuilder::atBlockBegin(module
.getBody());
139 // Build dealloc helper function if there are deallocs.
140 getOperation()->walk([&](bufferization::DeallocOp deallocOp
) {
141 Operation
*symtableOp
=
142 deallocOp
->getParentWithTrait
<OpTrait::SymbolTable
>();
143 if (deallocOp
.getMemrefs().size() > 1 &&
144 !deallocHelperFuncMap
.contains(symtableOp
)) {
145 SymbolTable
symbolTable(symtableOp
);
146 func::FuncOp helperFuncOp
=
147 bufferization::buildDeallocationLibraryFunction(
148 builder
, getOperation()->getLoc(), symbolTable
);
149 deallocHelperFuncMap
[symtableOp
] = helperFuncOp
;
154 RewritePatternSet
patterns(&getContext());
155 patterns
.add
<CloneOpConversion
>(patterns
.getContext());
156 bufferization::populateBufferizationDeallocLoweringPattern(
157 patterns
, deallocHelperFuncMap
);
159 ConversionTarget
target(getContext());
160 target
.addLegalDialect
<memref::MemRefDialect
, arith::ArithDialect
,
161 scf::SCFDialect
, func::FuncDialect
>();
162 target
.addIllegalDialect
<bufferization::BufferizationDialect
>();
164 if (failed(applyPartialConversion(getOperation(), target
,
165 std::move(patterns
))))
171 std::unique_ptr
<Pass
> mlir::createBufferizationToMemRefPass() {
172 return std::make_unique
<BufferizationToMemRefPass
>();