1 //===- Pattern.cpp - Conversion pattern to the LLVM dialect ---------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Conversion/LLVMCommon/Pattern.h"
10 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
11 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
12 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
13 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/BuiltinAttributes.h"
18 //===----------------------------------------------------------------------===//
19 // ConvertToLLVMPattern
20 //===----------------------------------------------------------------------===//
// Constructs a ConvertToLLVMPattern rooted at operations named `rootOpName`
// by forwarding `typeConverter`, `rootOpName`, `benefit` and `context` to the
// ConversionPattern base constructor.
// NOTE(review): this text looks like a damaged extraction: lines carry stray
// numeric prefixes (e.g. `22`), and `context` is forwarded below although no
// `context` parameter is declared in the visible parameter list. A parameter
// line appears to have been lost; confirm the signature against Pattern.h.
22 ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName
,
24 LLVMTypeConverter
&typeConverter
,
25 PatternBenefit benefit
)
26 : ConversionPattern(typeConverter
, rootOpName
, benefit
, context
) {}
28 LLVMTypeConverter
*ConvertToLLVMPattern::getTypeConverter() const {
29 return static_cast<LLVMTypeConverter
*>(
30 ConversionPattern::getTypeConverter());
33 LLVM::LLVMDialect
&ConvertToLLVMPattern::getDialect() const {
34 return *getTypeConverter()->getDialect();
37 Type
ConvertToLLVMPattern::getIndexType() const {
38 return getTypeConverter()->getIndexType();
41 Type
ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace
) const {
42 return IntegerType::get(&getTypeConverter()->getContext(),
43 getTypeConverter()->getPointerBitwidth(addressSpace
));
46 Type
ConvertToLLVMPattern::getVoidType() const {
47 return LLVM::LLVMVoidType::get(&getTypeConverter()->getContext());
50 Type
ConvertToLLVMPattern::getVoidPtrType() const {
51 return LLVM::LLVMPointerType::get(
52 IntegerType::get(&getTypeConverter()->getContext(), 8));
// Creates an LLVM::ConstantOp of type `resultType` at `loc` whose value is the
// index attribute built from `value`.
55 Value
ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder
&builder
,
// NOTE(review): the declarations of the `loc`, `resultType` and `value`
// parameters, and the function's closing brace, are missing from this
// extracted text. createIndexConstant calls this as
// (builder, loc, getIndexType(), value); confirm the exact signature
// (including the integer type of `value`) against Pattern.h.
59 return builder
.create
<LLVM::ConstantOp
>(loc
, resultType
,
60 builder
.getIndexAttr(value
));
63 Value
ConvertToLLVMPattern::createIndexConstant(
64 ConversionPatternRewriter
&builder
, Location loc
, uint64_t value
) const {
65 return createIndexAttrConstant(builder
, loc
, getIndexType(), value
);
// Returns a pointer to the element of the strided memref described by
// `memRefDesc` (whose type is `type`) at position `indices`.
// The linear element index is accumulated as
//   offset + sum_i(indices[i] * strides[i])
// with the offset and strides taken from getStridesAndOffset(type): entries
// the type marks as dynamic are read from the descriptor, and static ones
// become index constants. A zero offset is skipped, and a unit stride skips the
// multiplication. The result is a GEP from the descriptor's aligned pointer.
// Asserts that `type` has a strided layout.
// NOTE(review): this text looks like a damaged extraction (stray numeric
// line prefixes, missing lines); see the notes below.
68 Value
ConvertToLLVMPattern::getStridedElementPtr(
69 Location loc
, MemRefType type
, Value memRefDesc
, ValueRange indices
,
70 ConversionPatternRewriter
&rewriter
) const {
// NOTE(review): `offset` is passed to getStridesAndOffset below as an
// out-parameter, but no declaration of it is visible here; presumably an
// `int64_t offset;` line was lost - confirm against upstream.
73 SmallVector
<int64_t, 4> strides
;
74 auto successStrides
= getStridesAndOffset(type
, strides
, offset
);
75 assert(succeeded(successStrides
) && "unexpected non-strided memref");
78 MemRefDescriptor
memRefDescriptor(memRefDesc
);
79 Value base
= memRefDescriptor
.alignedPtr(rewriter
, loc
);
// NOTE(review): `index` is assigned and tested (`index ? ... : ...`) below
// without a visible declaration; presumably a `Value index;` declaration
// (null until assigned) was lost here - confirm.
82 if (offset
!= 0) // Skip if offset is zero.
83 index
= ShapedType::isDynamic(offset
)
84 ? memRefDescriptor
.offset(rewriter
, loc
)
85 : createIndexConstant(rewriter
, loc
, offset
);
87 for (int i
= 0, e
= indices
.size(); i
< e
; ++i
) {
88 Value increment
= indices
[i
];
89 if (strides
[i
] != 1) { // Skip if stride is 1.
90 Value stride
= ShapedType::isDynamic(strides
[i
])
91 ? memRefDescriptor
.stride(rewriter
, loc
, i
)
92 : createIndexConstant(rewriter
, loc
, strides
[i
]);
93 increment
= rewriter
.create
<LLVM::MulOp
>(loc
, increment
, stride
);
// NOTE(review): the result of the conditional below is not visibly assigned,
// and the closing braces of the `if`/`for` bodies above are not visible;
// presumably the sum is stored back into `index` - confirm.
96 index
? rewriter
.create
<LLVM::AddOp
>(loc
, index
, increment
) : increment
;
99 Type elementPtrType
= memRefDescriptor
.getElementPtrType();
// NOTE(review): the conditional below has no visible `:` branch, closing `;`
// or closing brace; presumably `base` is returned unchanged when no index was
// accumulated - confirm.
100 return index
? rewriter
.create
<LLVM::GEPOp
>(loc
, elementPtrType
, base
, index
)
104 // Check if the MemRefType `type` is supported by the lowering. We currently
105 // only support memrefs with identity maps.
106 bool ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps(
107 MemRefType type
) const {
108 if (!typeConverter
->convertType(type
.getElementType()))
110 return type
.getLayout().isIdentity();
113 Type
ConvertToLLVMPattern::getElementPtrType(MemRefType type
) const {
114 auto elementType
= type
.getElementType();
115 auto structElementType
= typeConverter
->convertType(elementType
);
116 return LLVM::LLVMPointerType::get(structElementType
,
117 type
.getMemorySpaceAsInt());
// Computes, for a memref of type `memRefType` with an identity layout:
//  - `sizes`: one value per entry of the shape; dynamic dimensions take the
//    next value from `dynamicSizes` (in order), and static ones become index
//    constants;
//  - `strides`: resized to the rank and filled from the last dimension to the
//    first with a running stride that starts at the index constant 1;
//  - `sizeBytes`: the buffer size in bytes, obtained by a GEP of a null
//    pointer of the element pointer type by `runningStride`, converted to the
//    index type with ptrtoint.
// Asserts that `memRefType` is convertible with identity layout and that
// `dynamicSizes` holds exactly one value per dynamic dimension.
// NOTE(review): this text looks like a damaged extraction (stray numeric
// line prefixes, missing lines and braces); see the notes below.
120 void ConvertToLLVMPattern::getMemRefDescriptorSizes(
121 Location loc
, MemRefType memRefType
, ValueRange dynamicSizes
,
122 ConversionPatternRewriter
&rewriter
, SmallVectorImpl
<Value
> &sizes
,
123 SmallVectorImpl
<Value
> &strides
, Value
&sizeBytes
) const {
124 assert(isConvertibleAndHasIdentityMaps(memRefType
) &&
125 "layout maps must have been normalized away");
126 assert(count(memRefType
.getShape(), ShapedType::kDynamic
) ==
127 static_cast<ssize_t
>(dynamicSizes
.size()) &&
128 "dynamicSizes size doesn't match dynamic sizes count in memref shape");
130 sizes
.reserve(memRefType
.getRank());
131 unsigned dynamicIndex
= 0;
132 for (int64_t size
: memRefType
.getShape()) {
133 sizes
.push_back(size
== ShapedType::kDynamic
134 ? dynamicSizes
[dynamicIndex
++]
135 : createIndexConstant(rewriter
, loc
, size
));
138 // Strides: iterate sizes in reverse order and multiply.
// NOTE(review): `stride` is read and assigned in the loop below but has no
// visible declaration; presumably an `int64_t stride = 1;` line was lost.
140 Value runningStride
= createIndexConstant(rewriter
, loc
, 1);
141 strides
.resize(memRefType
.getRank());
142 for (auto i
= memRefType
.getRank(); i
-- > 0;) {
143 strides
[i
] = runningStride
;
145 int64_t size
= memRefType
.getShape()[i
];
// NOTE(review): as extracted, the statements below do not form a coherent
// stride computation: `stride` is never multiplied by `size`,
// `useSizeAsStride` is computed but never read, the MulOp result is not
// assigned, and the final `runningStride = createIndexConstant(...)` runs on
// every iteration. Lines (e.g. a `stride *= size;` update, an
// `if (useSizeAsStride)` guard, an `else`, and the loop's closing brace)
// appear to be missing - confirm against upstream before relying on this.
148 bool useSizeAsStride
= stride
== 1;
149 if (size
== ShapedType::kDynamic
)
150 stride
= ShapedType::kDynamic
;
151 if (stride
!= ShapedType::kDynamic
)
155 runningStride
= sizes
[i
];
156 else if (stride
== ShapedType::kDynamic
)
158 rewriter
.create
<LLVM::MulOp
>(loc
, runningStride
, sizes
[i
]);
160 runningStride
= createIndexConstant(rewriter
, loc
, stride
);
163 // Buffer size in bytes.
164 Type elementPtrType
= getElementPtrType(memRefType
);
165 Value nullPtr
= rewriter
.create
<LLVM::NullOp
>(loc
, elementPtrType
);
// NOTE(review): `gepPtr` (passed to PtrToIntOp below) has no visible
// declaration; presumably it holds the result of this GEP.
167 rewriter
.create
<LLVM::GEPOp
>(loc
, elementPtrType
, nullPtr
, runningStride
);
168 sizeBytes
= rewriter
.create
<LLVM::PtrToIntOp
>(loc
, getIndexType(), gepPtr
);
171 Value
ConvertToLLVMPattern::getSizeInBytes(
172 Location loc
, Type type
, ConversionPatternRewriter
&rewriter
) const {
173 // Compute the size of an individual element. This emits the MLIR equivalent
174 // of the following sizeof(...) implementation in LLVM IR:
175 // %0 = getelementptr %elementType* null, %indexType 1
176 // %1 = ptrtoint %elementType* %0 to %indexType
177 // which is a common pattern of getting the size of a type in bytes.
178 auto convertedPtrType
=
179 LLVM::LLVMPointerType::get(typeConverter
->convertType(type
));
180 auto nullPtr
= rewriter
.create
<LLVM::NullOp
>(loc
, convertedPtrType
);
181 auto gep
= rewriter
.create
<LLVM::GEPOp
>(loc
, convertedPtrType
, nullPtr
,
182 ArrayRef
<LLVM::GEPArg
>{1});
183 return rewriter
.create
<LLVM::PtrToIntOp
>(loc
, getIndexType(), gep
);
186 Value
ConvertToLLVMPattern::getNumElements(
187 Location loc
, ArrayRef
<Value
> shape
,
188 ConversionPatternRewriter
&rewriter
) const {
189 // Compute the total number of memref elements.
191 shape
.empty() ? createIndexConstant(rewriter
, loc
, 1) : shape
.front();
192 for (unsigned i
= 1, e
= shape
.size(); i
< e
; ++i
)
193 numElements
= rewriter
.create
<LLVM::MulOp
>(loc
, numElements
, shape
[i
]);
197 /// Creates and populates the memref descriptor struct given all its fields.
198 MemRefDescriptor
ConvertToLLVMPattern::createMemRefDescriptor(
199 Location loc
, MemRefType memRefType
, Value allocatedPtr
, Value alignedPtr
,
200 ArrayRef
<Value
> sizes
, ArrayRef
<Value
> strides
,
201 ConversionPatternRewriter
&rewriter
) const {
202 auto structType
= typeConverter
->convertType(memRefType
);
203 auto memRefDescriptor
= MemRefDescriptor::undef(rewriter
, loc
, structType
);
205 // Field 1: Allocated pointer, used for malloc/free.
206 memRefDescriptor
.setAllocatedPtr(rewriter
, loc
, allocatedPtr
);
208 // Field 2: Actual aligned pointer to payload.
209 memRefDescriptor
.setAlignedPtr(rewriter
, loc
, alignedPtr
);
211 // Field 3: Offset in aligned pointer.
212 memRefDescriptor
.setOffset(rewriter
, loc
,
213 createIndexConstant(rewriter
, loc
, 0));
216 for (const auto &en
: llvm::enumerate(sizes
))
217 memRefDescriptor
.setSize(rewriter
, loc
, en
.index(), en
.value());
220 for (const auto &en
: llvm::enumerate(strides
))
221 memRefDescriptor
.setStride(rewriter
, loc
, en
.index(), en
.value());
223 return memRefDescriptor
;
// For each operand whose original type in `origTypes` is an
// UnrankedMemRefType, copies the memory its descriptor points to into a new
// buffer, then replaces the operand with a fresh unranked descriptor that has
// the same rank and points to the new buffer.
//  - Allocation sizes come from UnrankedMemRefDescriptor::computeSizes.
//  - The new buffer comes either from a call to the malloc function returned
//    by LLVM::lookupOrCreateMallocFn or from an LLVM::AllocaOp. The copy is an
//    LLVM::MemcpyOp, and the source may then be passed to the free function
//    returned by LLVM::lookupOrCreateFreeFn.
//  - `toDynamic`: presumably selects the malloc/free path over alloca, but
//    it is not read anywhere in this extracted text - confirm.
// Asserts that `origTypes` and `operands` have the same size.
// NOTE(review): this text looks like a damaged extraction (stray numeric
// line prefixes, missing lines); see the notes below.
226 LogicalResult
ConvertToLLVMPattern::copyUnrankedDescriptors(
227 OpBuilder
&builder
, Location loc
, TypeRange origTypes
,
228 SmallVectorImpl
<Value
> &operands
, bool toDynamic
) const {
229 assert(origTypes
.size() == operands
.size() &&
230 "expected as may original types as operands");
232 // Find operands of unranked memref type and store them.
233 SmallVector
<UnrankedMemRefDescriptor
, 4> unrankedMemrefs
;
234 for (unsigned i
= 0, e
= operands
.size(); i
< e
; ++i
)
235 if (origTypes
[i
].isa
<UnrankedMemRefType
>())
236 unrankedMemrefs
.emplace_back(operands
[i
]);
238 if (unrankedMemrefs
.empty())
// NOTE(review): the statement guarded by the `if` above is missing from the
// extracted text (presumably an early successful return when there are no
// unranked operands) - confirm.
241 // Compute allocation sizes.
242 SmallVector
<Value
, 4> sizes
;
243 UnrankedMemRefDescriptor::computeSizes(builder
, loc
, *getTypeConverter(),
244 unrankedMemrefs
, sizes
);
246 // Get frequently used types.
247 MLIRContext
*context
= builder
.getContext();
248 Type voidPtrType
= LLVM::LLVMPointerType::get(IntegerType::get(context
, 8));
249 auto i1Type
= IntegerType::get(context
, 1);
250 Type indexType
= getTypeConverter()->getIndexType();
252 // Find the malloc and free, or declare them if necessary.
253 auto module
= builder
.getInsertionPoint()->getParentOfType
<ModuleOp
>();
254 LLVM::LLVMFuncOp freeFunc
, mallocFunc
;
// NOTE(review): as extracted, malloc and free are both looked up/declared
// unconditionally; any guards on `toDynamic` appear to have been lost.
256 mallocFunc
= LLVM::lookupOrCreateMallocFn(module
, indexType
);
258 freeFunc
= LLVM::lookupOrCreateFreeFn(module
);
260 // Initialize shared constants.
// NOTE(review): the i1 `false` constant below is not visibly bound to a name,
// while `zero` is passed to MemcpyOp later without a visible declaration;
// presumably they are the same value (the memcpy volatile flag) - confirm.
262 builder
.create
<LLVM::ConstantOp
>(loc
, i1Type
, builder
.getBoolAttr(false));
264 unsigned unrankedMemrefPos
= 0;
265 for (unsigned i
= 0, e
= operands
.size(); i
< e
; ++i
) {
266 Type type
= origTypes
[i
];
267 if (!type
.isa
<UnrankedMemRefType
>())
// NOTE(review): the body of the `if` above is not visible; presumably
// operands that are not unranked memrefs are skipped (a lost `continue;`).
269 Value allocationSize
= sizes
[unrankedMemrefPos
++];
270 UnrankedMemRefDescriptor
desc(operands
[i
]);
272 // Allocate memory, copy, and free the source if necessary.
// NOTE(review): the condition that selects between the malloc call and the
// AllocaOp, the declaration of `memory` (used below), and the end of the
// AllocaOp argument list are missing from the extracted text.
275 ? builder
.create
<LLVM::CallOp
>(loc
, mallocFunc
, allocationSize
)
277 : builder
.create
<LLVM::AllocaOp
>(loc
, voidPtrType
, allocationSize
,
279 Value source
= desc
.memRefDescPtr(builder
, loc
);
280 builder
.create
<LLVM::MemcpyOp
>(loc
, memory
, source
, allocationSize
, zero
);
// NOTE(review): no guard is visible on the free call below, although the
// comment above says the source is freed only "if necessary" - confirm.
282 builder
.create
<LLVM::CallOp
>(loc
, freeFunc
, source
);
284 // Create a new descriptor. The same descriptor can be returned multiple
285 // times, attempting to modify its pointer can lead to memory leaks
286 // (allocated twice and overwritten) or double frees (the caller does not
287 // know if the descriptor points to the same memory).
288 Type descriptorType
= getTypeConverter()->convertType(type
);
// NOTE(review): the declaration of `updatedDesc`, which presumably receives
// the undef descriptor below, is not visible.
292 UnrankedMemRefDescriptor::undef(builder
, loc
, descriptorType
);
293 Value rank
= desc
.rank(builder
, loc
);
294 updatedDesc
.setRank(builder
, loc
, rank
);
295 updatedDesc
.setMemRefDescPtr(builder
, loc
, memory
);
297 operands
[i
] = updatedDesc
;
// NOTE(review): the loop's closing brace, the function's final return
// statement, and its closing brace are not visible in this extracted text.
303 //===----------------------------------------------------------------------===//
305 //===----------------------------------------------------------------------===//
307 /// Replaces the given operation "op" with a new operation of type "targetOp"
308 /// and given operands.
309 LogicalResult
LLVM::detail::oneToOneRewrite(
310 Operation
*op
, StringRef targetOp
, ValueRange operands
,
311 ArrayRef
<NamedAttribute
> targetAttrs
, LLVMTypeConverter
&typeConverter
,
312 ConversionPatternRewriter
&rewriter
) {
313 unsigned numResults
= op
->getNumResults();
315 SmallVector
<Type
> resultTypes
;
316 if (numResults
!= 0) {
317 resultTypes
.push_back(
318 typeConverter
.packFunctionResults(op
->getResultTypes()));
319 if (!resultTypes
.back())
323 // Create the operation through state since we don't know its C++ type.
325 rewriter
.create(op
->getLoc(), rewriter
.getStringAttr(targetOp
), operands
,
326 resultTypes
, targetAttrs
);
328 // If the operation produced 0 or 1 result, return them immediately.
330 return rewriter
.eraseOp(op
), success();
332 return rewriter
.replaceOp(op
, newOp
->getResult(0)), success();
334 // Otherwise, it had been converted to an operation producing a structure.
335 // Extract individual results from the structure and return them as list.
336 SmallVector
<Value
, 4> results
;
337 results
.reserve(numResults
);
338 for (unsigned i
= 0; i
< numResults
; ++i
) {
339 results
.push_back(rewriter
.create
<LLVM::ExtractValueOp
>(
340 op
->getLoc(), newOp
->getResult(0), i
));
342 rewriter
.replaceOp(op
, results
);