[rtsan] Remove mkfifoat interceptor (#116997)
[llvm-project.git] / mlir / lib / Conversion / LLVMCommon / Pattern.cpp
blobd551506485a454f37ca7e18c46e6fb1b29fac88a
1 //===- Pattern.cpp - Conversion pattern to the LLVM dialect ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Conversion/LLVMCommon/Pattern.h"
10 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
11 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
12 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
13 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/BuiltinAttributes.h"
16 using namespace mlir;
18 //===----------------------------------------------------------------------===//
19 // ConvertToLLVMPattern
20 //===----------------------------------------------------------------------===//
22 ConvertToLLVMPattern::ConvertToLLVMPattern(
23 StringRef rootOpName, MLIRContext *context,
24 const LLVMTypeConverter &typeConverter, PatternBenefit benefit)
25 : ConversionPattern(typeConverter, rootOpName, benefit, context) {}
27 const LLVMTypeConverter *ConvertToLLVMPattern::getTypeConverter() const {
28 return static_cast<const LLVMTypeConverter *>(
29 ConversionPattern::getTypeConverter());
32 LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const {
33 return *getTypeConverter()->getDialect();
36 Type ConvertToLLVMPattern::getIndexType() const {
37 return getTypeConverter()->getIndexType();
40 Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const {
41 return IntegerType::get(&getTypeConverter()->getContext(),
42 getTypeConverter()->getPointerBitwidth(addressSpace));
45 Type ConvertToLLVMPattern::getVoidType() const {
46 return LLVM::LLVMVoidType::get(&getTypeConverter()->getContext());
49 Type ConvertToLLVMPattern::getVoidPtrType() const {
50 return LLVM::LLVMPointerType::get(&getTypeConverter()->getContext());
53 Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder,
54 Location loc,
55 Type resultType,
56 int64_t value) {
57 return builder.create<LLVM::ConstantOp>(loc, resultType,
58 builder.getIndexAttr(value));
61 Value ConvertToLLVMPattern::getStridedElementPtr(
62 Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
63 ConversionPatternRewriter &rewriter) const {
65 auto [strides, offset] = getStridesAndOffset(type);
67 MemRefDescriptor memRefDescriptor(memRefDesc);
68 // Use a canonical representation of the start address so that later
69 // optimizations have a longer sequence of instructions to CSE.
70 // If we don't do that we would sprinkle the memref.offset in various
71 // position of the different address computations.
72 Value base =
73 memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(), type);
75 Type indexType = getIndexType();
76 Value index;
77 for (int i = 0, e = indices.size(); i < e; ++i) {
78 Value increment = indices[i];
79 if (strides[i] != 1) { // Skip if stride is 1.
80 Value stride =
81 ShapedType::isDynamic(strides[i])
82 ? memRefDescriptor.stride(rewriter, loc, i)
83 : createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
84 increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
86 index =
87 index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
90 Type elementPtrType = memRefDescriptor.getElementPtrType();
91 return index ? rewriter.create<LLVM::GEPOp>(
92 loc, elementPtrType,
93 getTypeConverter()->convertType(type.getElementType()),
94 base, index)
95 : base;
98 // Check if the MemRefType `type` is supported by the lowering. We currently
99 // only support memrefs with identity maps.
100 bool ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps(
101 MemRefType type) const {
102 if (!typeConverter->convertType(type.getElementType()))
103 return false;
104 return type.getLayout().isIdentity();
107 Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {
108 auto addressSpace = getTypeConverter()->getMemRefAddressSpace(type);
109 if (failed(addressSpace))
110 return {};
111 return LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
114 void ConvertToLLVMPattern::getMemRefDescriptorSizes(
115 Location loc, MemRefType memRefType, ValueRange dynamicSizes,
116 ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes,
117 SmallVectorImpl<Value> &strides, Value &size, bool sizeInBytes) const {
118 assert(isConvertibleAndHasIdentityMaps(memRefType) &&
119 "layout maps must have been normalized away");
120 assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
121 static_cast<ssize_t>(dynamicSizes.size()) &&
122 "dynamicSizes size doesn't match dynamic sizes count in memref shape");
124 sizes.reserve(memRefType.getRank());
125 unsigned dynamicIndex = 0;
126 Type indexType = getIndexType();
127 for (int64_t size : memRefType.getShape()) {
128 sizes.push_back(
129 size == ShapedType::kDynamic
130 ? dynamicSizes[dynamicIndex++]
131 : createIndexAttrConstant(rewriter, loc, indexType, size));
134 // Strides: iterate sizes in reverse order and multiply.
135 int64_t stride = 1;
136 Value runningStride = createIndexAttrConstant(rewriter, loc, indexType, 1);
137 strides.resize(memRefType.getRank());
138 for (auto i = memRefType.getRank(); i-- > 0;) {
139 strides[i] = runningStride;
141 int64_t staticSize = memRefType.getShape()[i];
142 bool useSizeAsStride = stride == 1;
143 if (staticSize == ShapedType::kDynamic)
144 stride = ShapedType::kDynamic;
145 if (stride != ShapedType::kDynamic)
146 stride *= staticSize;
148 if (useSizeAsStride)
149 runningStride = sizes[i];
150 else if (stride == ShapedType::kDynamic)
151 runningStride =
152 rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]);
153 else
154 runningStride = createIndexAttrConstant(rewriter, loc, indexType, stride);
156 if (sizeInBytes) {
157 // Buffer size in bytes.
158 Type elementType = typeConverter->convertType(memRefType.getElementType());
159 auto elementPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
160 Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType);
161 Value gepPtr = rewriter.create<LLVM::GEPOp>(
162 loc, elementPtrType, elementType, nullPtr, runningStride);
163 size = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
164 } else {
165 size = runningStride;
169 Value ConvertToLLVMPattern::getSizeInBytes(
170 Location loc, Type type, ConversionPatternRewriter &rewriter) const {
171 // Compute the size of an individual element. This emits the MLIR equivalent
172 // of the following sizeof(...) implementation in LLVM IR:
173 // %0 = getelementptr %elementType* null, %indexType 1
174 // %1 = ptrtoint %elementType* %0 to %indexType
175 // which is a common pattern of getting the size of a type in bytes.
176 Type llvmType = typeConverter->convertType(type);
177 auto convertedPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
178 auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, convertedPtrType);
179 auto gep = rewriter.create<LLVM::GEPOp>(loc, convertedPtrType, llvmType,
180 nullPtr, ArrayRef<LLVM::GEPArg>{1});
181 return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
184 Value ConvertToLLVMPattern::getNumElements(
185 Location loc, MemRefType memRefType, ValueRange dynamicSizes,
186 ConversionPatternRewriter &rewriter) const {
187 assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
188 static_cast<ssize_t>(dynamicSizes.size()) &&
189 "dynamicSizes size doesn't match dynamic sizes count in memref shape");
191 Type indexType = getIndexType();
192 Value numElements = memRefType.getRank() == 0
193 ? createIndexAttrConstant(rewriter, loc, indexType, 1)
194 : nullptr;
195 unsigned dynamicIndex = 0;
197 // Compute the total number of memref elements.
198 for (int64_t staticSize : memRefType.getShape()) {
199 if (numElements) {
200 Value size =
201 staticSize == ShapedType::kDynamic
202 ? dynamicSizes[dynamicIndex++]
203 : createIndexAttrConstant(rewriter, loc, indexType, staticSize);
204 numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
205 } else {
206 numElements =
207 staticSize == ShapedType::kDynamic
208 ? dynamicSizes[dynamicIndex++]
209 : createIndexAttrConstant(rewriter, loc, indexType, staticSize);
212 return numElements;
215 /// Creates and populates the memref descriptor struct given all its fields.
216 MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor(
217 Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr,
218 ArrayRef<Value> sizes, ArrayRef<Value> strides,
219 ConversionPatternRewriter &rewriter) const {
220 auto structType = typeConverter->convertType(memRefType);
221 auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType);
223 // Field 1: Allocated pointer, used for malloc/free.
224 memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr);
226 // Field 2: Actual aligned pointer to payload.
227 memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr);
229 // Field 3: Offset in aligned pointer.
230 Type indexType = getIndexType();
231 memRefDescriptor.setOffset(
232 rewriter, loc, createIndexAttrConstant(rewriter, loc, indexType, 0));
234 // Fields 4: Sizes.
235 for (const auto &en : llvm::enumerate(sizes))
236 memRefDescriptor.setSize(rewriter, loc, en.index(), en.value());
238 // Field 5: Strides.
239 for (const auto &en : llvm::enumerate(strides))
240 memRefDescriptor.setStride(rewriter, loc, en.index(), en.value());
242 return memRefDescriptor;
245 LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
246 OpBuilder &builder, Location loc, TypeRange origTypes,
247 SmallVectorImpl<Value> &operands, bool toDynamic) const {
248 assert(origTypes.size() == operands.size() &&
249 "expected as may original types as operands");
251 // Find operands of unranked memref type and store them.
252 SmallVector<UnrankedMemRefDescriptor> unrankedMemrefs;
253 SmallVector<unsigned> unrankedAddressSpaces;
254 for (unsigned i = 0, e = operands.size(); i < e; ++i) {
255 if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) {
256 unrankedMemrefs.emplace_back(operands[i]);
257 FailureOr<unsigned> addressSpace =
258 getTypeConverter()->getMemRefAddressSpace(memRefType);
259 if (failed(addressSpace))
260 return failure();
261 unrankedAddressSpaces.emplace_back(*addressSpace);
265 if (unrankedMemrefs.empty())
266 return success();
268 // Compute allocation sizes.
269 SmallVector<Value> sizes;
270 UnrankedMemRefDescriptor::computeSizes(builder, loc, *getTypeConverter(),
271 unrankedMemrefs, unrankedAddressSpaces,
272 sizes);
274 // Get frequently used types.
275 Type indexType = getTypeConverter()->getIndexType();
277 // Find the malloc and free, or declare them if necessary.
278 auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
279 LLVM::LLVMFuncOp freeFunc, mallocFunc;
280 if (toDynamic)
281 mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType);
282 if (!toDynamic)
283 freeFunc = LLVM::lookupOrCreateFreeFn(module);
285 unsigned unrankedMemrefPos = 0;
286 for (unsigned i = 0, e = operands.size(); i < e; ++i) {
287 Type type = origTypes[i];
288 if (!isa<UnrankedMemRefType>(type))
289 continue;
290 Value allocationSize = sizes[unrankedMemrefPos++];
291 UnrankedMemRefDescriptor desc(operands[i]);
293 // Allocate memory, copy, and free the source if necessary.
294 Value memory =
295 toDynamic
296 ? builder.create<LLVM::CallOp>(loc, mallocFunc, allocationSize)
297 .getResult()
298 : builder.create<LLVM::AllocaOp>(loc, getVoidPtrType(),
299 IntegerType::get(getContext(), 8),
300 allocationSize,
301 /*alignment=*/0);
302 Value source = desc.memRefDescPtr(builder, loc);
303 builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, false);
304 if (!toDynamic)
305 builder.create<LLVM::CallOp>(loc, freeFunc, source);
307 // Create a new descriptor. The same descriptor can be returned multiple
308 // times, attempting to modify its pointer can lead to memory leaks
309 // (allocated twice and overwritten) or double frees (the caller does not
310 // know if the descriptor points to the same memory).
311 Type descriptorType = getTypeConverter()->convertType(type);
312 if (!descriptorType)
313 return failure();
314 auto updatedDesc =
315 UnrankedMemRefDescriptor::undef(builder, loc, descriptorType);
316 Value rank = desc.rank(builder, loc);
317 updatedDesc.setRank(builder, loc, rank);
318 updatedDesc.setMemRefDescPtr(builder, loc, memory);
320 operands[i] = updatedDesc;
323 return success();
326 //===----------------------------------------------------------------------===//
327 // Detail methods
328 //===----------------------------------------------------------------------===//
330 void LLVM::detail::setNativeProperties(Operation *op,
331 IntegerOverflowFlags overflowFlags) {
332 if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op))
333 iface.setOverflowFlags(overflowFlags);
336 /// Replaces the given operation "op" with a new operation of type "targetOp"
337 /// and given operands.
338 LogicalResult LLVM::detail::oneToOneRewrite(
339 Operation *op, StringRef targetOp, ValueRange operands,
340 ArrayRef<NamedAttribute> targetAttrs,
341 const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
342 IntegerOverflowFlags overflowFlags) {
343 unsigned numResults = op->getNumResults();
345 SmallVector<Type> resultTypes;
346 if (numResults != 0) {
347 resultTypes.push_back(
348 typeConverter.packOperationResults(op->getResultTypes()));
349 if (!resultTypes.back())
350 return failure();
353 // Create the operation through state since we don't know its C++ type.
354 Operation *newOp =
355 rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
356 resultTypes, targetAttrs);
358 setNativeProperties(newOp, overflowFlags);
360 // If the operation produced 0 or 1 result, return them immediately.
361 if (numResults == 0)
362 return rewriter.eraseOp(op), success();
363 if (numResults == 1)
364 return rewriter.replaceOp(op, newOp->getResult(0)), success();
366 // Otherwise, it had been converted to an operation producing a structure.
367 // Extract individual results from the structure and return them as list.
368 SmallVector<Value, 4> results;
369 results.reserve(numResults);
370 for (unsigned i = 0; i < numResults; ++i) {
371 results.push_back(rewriter.create<LLVM::ExtractValueOp>(
372 op->getLoc(), newOp->getResult(0), i));
374 rewriter.replaceOp(op, results);
375 return success();