//===- Pattern.cpp - Conversion pattern to the LLVM dialect --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// ConvertToLLVMPattern
//===----------------------------------------------------------------------===//

ConvertToLLVMPattern::ConvertToLLVMPattern(
    StringRef rootOpName, MLIRContext *context,
    const LLVMTypeConverter &typeConverter, PatternBenefit benefit)
    : ConversionPattern(typeConverter, rootOpName, benefit, context) {}

const LLVMTypeConverter *ConvertToLLVMPattern::getTypeConverter() const {
  return static_cast<const LLVMTypeConverter *>(
      ConversionPattern::getTypeConverter());
}

LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const {
  return *getTypeConverter()->getDialect();
}

Type ConvertToLLVMPattern::getIndexType() const {
  return getTypeConverter()->getIndexType();
}

Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const {
  return IntegerType::get(&getTypeConverter()->getContext(),
                          getTypeConverter()->getPointerBitwidth(addressSpace));
}

Type ConvertToLLVMPattern::getVoidType() const {
  return LLVM::LLVMVoidType::get(&getTypeConverter()->getContext());
}

Type ConvertToLLVMPattern::getVoidPtrType() const {
  return LLVM::LLVMPointerType::get(&getTypeConverter()->getContext());
}

Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder,
                                                    Location loc,
                                                    Type resultType,
                                                    int64_t value) {
  return builder.create<LLVM::ConstantOp>(loc, resultType,
                                          builder.getIndexAttr(value));
}
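
/// Computes the address of the element at `indices` in the memref described
/// by `memRefDesc`: the indices are linearized against the memref strides
/// (emitting LLVM::MulOp/LLVM::AddOp where needed) and the result is used as
/// a GEP offset from the descriptor's buffer pointer, into which the memref
/// offset has already been folded.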
Value ConvertToLLVMPattern::getStridedElementPtr(
    Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
    ConversionPatternRewriter &rewriter) const {
  auto [strides, offset] = getStridesAndOffset(type);

  MemRefDescriptor memRefDescriptor(memRefDesc);
  // Use a canonical representation of the start address so that later
  // optimizations have a longer sequence of instructions to CSE.
  // If we don't do that, we would sprinkle the memref.offset in various
  // positions of the different address computations.
  Value base =
      memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(), type);

  Type indexType = getIndexType();
  Value index;
  for (int i = 0, e = indices.size(); i < e; ++i) {
    Value increment = indices[i];
    if (strides[i] != 1) { // Skip if stride is 1.
      Value stride =
          ShapedType::isDynamic(strides[i])
              ? memRefDescriptor.stride(rewriter, loc, i)
              : createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
      increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
    }
    index =
        index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
  }

  Type elementPtrType = memRefDescriptor.getElementPtrType();
  return index ? rewriter.create<LLVM::GEPOp>(
                     loc, elementPtrType,
                     getTypeConverter()->convertType(type.getElementType()),
                     base, index)
               : base;
}

// Check if the MemRefType `type` is supported by the lowering. We currently
// only support memrefs with identity maps.
bool ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps(
    MemRefType type) const {
  if (!typeConverter->convertType(type.getElementType()))
    return false;
  return type.getLayout().isIdentity();
}

Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {
  auto addressSpace = getTypeConverter()->getMemRefAddressSpace(type);
  if (failed(addressSpace))
    return {};
  return LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
}
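
/// Materializes the sizes and strides of `memRefType` as index-typed SSA
/// values, taking dynamic extents from `dynamicSizes`. Strides are the suffix
/// products of the sizes (identity layout). When `sizeInBytes` is set, `size`
/// is the buffer size in bytes computed with the null-pointer GEP idiom;
/// otherwise it is the total number of elements.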
void ConvertToLLVMPattern::getMemRefDescriptorSizes(
    Location loc, MemRefType memRefType, ValueRange dynamicSizes,
    ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes,
    SmallVectorImpl<Value> &strides, Value &size, bool sizeInBytes) const {
  assert(isConvertibleAndHasIdentityMaps(memRefType) &&
         "layout maps must have been normalized away");
  assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
             static_cast<ssize_t>(dynamicSizes.size()) &&
         "dynamicSizes size doesn't match dynamic sizes count in memref shape");

  sizes.reserve(memRefType.getRank());
  unsigned dynamicIndex = 0;
  Type indexType = getIndexType();
  for (int64_t size : memRefType.getShape()) {
    sizes.push_back(
        size == ShapedType::kDynamic
            ? dynamicSizes[dynamicIndex++]
            : createIndexAttrConstant(rewriter, loc, indexType, size));
  }

  // Strides: iterate sizes in reverse order and multiply.
  int64_t stride = 1;
  Value runningStride = createIndexAttrConstant(rewriter, loc, indexType, 1);
  strides.resize(memRefType.getRank());
  for (auto i = memRefType.getRank(); i-- > 0;) {
    strides[i] = runningStride;

    int64_t staticSize = memRefType.getShape()[i];
    bool useSizeAsStride = stride == 1;
    if (staticSize == ShapedType::kDynamic)
      stride = ShapedType::kDynamic;
    if (stride != ShapedType::kDynamic)
      stride *= staticSize;

    if (useSizeAsStride)
      runningStride = sizes[i];
    else if (stride == ShapedType::kDynamic)
      runningStride =
          rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]);
    else
      runningStride = createIndexAttrConstant(rewriter, loc, indexType, stride);
  }

  if (sizeInBytes) {
    // Buffer size in bytes.
    Type elementType = typeConverter->convertType(memRefType.getElementType());
    auto elementPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
    Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType);
    Value gepPtr = rewriter.create<LLVM::GEPOp>(
        loc, elementPtrType, elementType, nullPtr, runningStride);
    size = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
  } else {
    size = runningStride;
  }
}

Value ConvertToLLVMPattern::getSizeInBytes(
    Location loc, Type type, ConversionPatternRewriter &rewriter) const {
  // Compute the size of an individual element. This emits the MLIR equivalent
  // of the following sizeof(...) implementation in LLVM IR:
  //   %0 = getelementptr %elementType* null, %indexType 1
  //   %1 = ptrtoint %elementType* %0 to %indexType
  // which is a common pattern of getting the size of a type in bytes.
  Type llvmType = typeConverter->convertType(type);
  auto convertedPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
  auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, convertedPtrType);
  auto gep = rewriter.create<LLVM::GEPOp>(loc, convertedPtrType, llvmType,
                                          nullPtr, ArrayRef<LLVM::GEPArg>{1});
  return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
}
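
/// Returns the total number of elements of `memRefType` as a single
/// index-typed value, i.e. the product of its sizes, taking dynamic extents
/// from `dynamicSizes`. A zero-ranked memref yields the constant 1.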
Value ConvertToLLVMPattern::getNumElements(
    Location loc, MemRefType memRefType, ValueRange dynamicSizes,
    ConversionPatternRewriter &rewriter) const {
  assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
             static_cast<ssize_t>(dynamicSizes.size()) &&
         "dynamicSizes size doesn't match dynamic sizes count in memref shape");

  Type indexType = getIndexType();
  Value numElements = memRefType.getRank() == 0
                          ? createIndexAttrConstant(rewriter, loc, indexType, 1)
                          : nullptr;
  unsigned dynamicIndex = 0;

  // Compute the total number of memref elements.
  for (int64_t staticSize : memRefType.getShape()) {
    if (numElements) {
      Value size =
          staticSize == ShapedType::kDynamic
              ? dynamicSizes[dynamicIndex++]
              : createIndexAttrConstant(rewriter, loc, indexType, staticSize);
      numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
    } else {
      numElements =
          staticSize == ShapedType::kDynamic
              ? dynamicSizes[dynamicIndex++]
              : createIndexAttrConstant(rewriter, loc, indexType, staticSize);
    }
  }
  return numElements;
}

/// Creates and populates the memref descriptor struct given all its fields.
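/// The converted descriptor is an LLVM struct holding, in order, the
/// allocated pointer, the aligned pointer, the offset, and per-dimension
/// arrays of sizes and strides (all of index type), matching the fields set
/// below.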
MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor(
    Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr,
    ArrayRef<Value> sizes, ArrayRef<Value> strides,
    ConversionPatternRewriter &rewriter) const {
  auto structType = typeConverter->convertType(memRefType);
  auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType);

  // Field 1: Allocated pointer, used for malloc/free.
  memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr);

  // Field 2: Actual aligned pointer to payload.
  memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr);

  // Field 3: Offset in aligned pointer.
  Type indexType = getIndexType();
  memRefDescriptor.setOffset(
      rewriter, loc, createIndexAttrConstant(rewriter, loc, indexType, 0));

  // Field 4: Sizes.
  for (const auto &en : llvm::enumerate(sizes))
    memRefDescriptor.setSize(rewriter, loc, en.index(), en.value());

  // Field 5: Strides.
  for (const auto &en : llvm::enumerate(strides))
    memRefDescriptor.setStride(rewriter, loc, en.index(), en.value());

  return memRefDescriptor;
}
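
/// Copies the ranked descriptors pointed to by unranked memref operands into
/// freshly allocated memory (malloc'ed when `toDynamic` is set, stack
/// allocated otherwise) and rebuilds each unranked descriptor around its
/// copy, so later uses never alias or double-free the original storage.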
LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
    OpBuilder &builder, Location loc, TypeRange origTypes,
    SmallVectorImpl<Value> &operands, bool toDynamic) const {
  assert(origTypes.size() == operands.size() &&
         "expected as many original types as operands");

  // Find operands of unranked memref type and store them.
  SmallVector<UnrankedMemRefDescriptor> unrankedMemrefs;
  SmallVector<unsigned> unrankedAddressSpaces;
  for (unsigned i = 0, e = operands.size(); i < e; ++i) {
    if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) {
      unrankedMemrefs.emplace_back(operands[i]);
      FailureOr<unsigned> addressSpace =
          getTypeConverter()->getMemRefAddressSpace(memRefType);
      if (failed(addressSpace))
        return failure();
      unrankedAddressSpaces.emplace_back(*addressSpace);
    }
  }

  if (unrankedMemrefs.empty())
    return success();

  // Compute allocation sizes.
  SmallVector<Value> sizes;
  UnrankedMemRefDescriptor::computeSizes(builder, loc, *getTypeConverter(),
                                         unrankedMemrefs, unrankedAddressSpaces,
                                         sizes);

  // Get frequently used types.
  Type indexType = getTypeConverter()->getIndexType();

  // Find the malloc and free, or declare them if necessary.
  auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
  LLVM::LLVMFuncOp freeFunc, mallocFunc;
  if (toDynamic)
    mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType);
  if (!toDynamic)
    freeFunc = LLVM::lookupOrCreateFreeFn(module);

  unsigned unrankedMemrefPos = 0;
  for (unsigned i = 0, e = operands.size(); i < e; ++i) {
    Type type = origTypes[i];
    if (!isa<UnrankedMemRefType>(type))
      continue;
    Value allocationSize = sizes[unrankedMemrefPos++];
    UnrankedMemRefDescriptor desc(operands[i]);

    // Allocate memory, copy, and free the source if necessary.
    Value memory =
        toDynamic
            ? builder.create<LLVM::CallOp>(loc, mallocFunc, allocationSize)
                  .getResult()
            : builder.create<LLVM::AllocaOp>(loc, getVoidPtrType(),
                                             IntegerType::get(getContext(), 8),
                                             allocationSize,
                                             /*alignment=*/0);
    Value source = desc.memRefDescPtr(builder, loc);
    builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, false);
    if (!toDynamic)
      builder.create<LLVM::CallOp>(loc, freeFunc, source);

    // Create a new descriptor. The same descriptor can be returned multiple
    // times, attempting to modify its pointer can lead to memory leaks
    // (allocated twice and overwritten) or double frees (the caller does not
    // know if the descriptor points to the same memory).
    Type descriptorType = getTypeConverter()->convertType(type);
    if (!descriptorType)
      return failure();
    auto updatedDesc =
        UnrankedMemRefDescriptor::undef(builder, loc, descriptorType);
    Value rank = desc.rank(builder, loc);
    updatedDesc.setRank(builder, loc, rank);
    updatedDesc.setMemRefDescPtr(builder, loc, memory);

    operands[i] = updatedDesc;
  }

  return success();
}

//===----------------------------------------------------------------------===//
// Detail methods
//===----------------------------------------------------------------------===//
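
/// Forwards the integer overflow flags to `op` if it implements
/// IntegerOverflowFlagsInterface; otherwise this is a no-op.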
void LLVM::detail::setNativeProperties(Operation *op,
                                       IntegerOverflowFlags overflowFlags) {
  if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op))
    iface.setOverflowFlags(overflowFlags);
}

/// Replaces the given operation "op" with a new operation of type "targetOp"
/// and given operands.
LogicalResult LLVM::detail::oneToOneRewrite(
    Operation *op, StringRef targetOp, ValueRange operands,
    ArrayRef<NamedAttribute> targetAttrs,
    const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
    IntegerOverflowFlags overflowFlags) {
  unsigned numResults = op->getNumResults();

  SmallVector<Type> resultTypes;
  if (numResults != 0) {
    resultTypes.push_back(
        typeConverter.packOperationResults(op->getResultTypes()));
    if (!resultTypes.back())
      return failure();
  }

  // Create the operation through state since we don't know its C++ type.
  Operation *newOp =
      rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
                      resultTypes, targetAttrs);

  setNativeProperties(newOp, overflowFlags);

  // If the operation produced 0 or 1 result, return them immediately.
  if (numResults == 0)
    return rewriter.eraseOp(op), success();
  if (numResults == 1)
    return rewriter.replaceOp(op, newOp->getResult(0)), success();

  // Otherwise, it had been converted to an operation producing a structure.
  // Extract individual results from the structure and return them as a list.
  SmallVector<Value, 4> results;
  results.reserve(numResults);
  for (unsigned i = 0; i < numResults; ++i) {
    results.push_back(rewriter.create<LLVM::ExtractValueOp>(
        op->getLoc(), newOp->getResult(0), i));
  }
  rewriter.replaceOp(op, results);
  return success();
}