[rtsan] Remove mkfifoat interceptor (#116997)
[llvm-project.git] / mlir / lib / Conversion / MemRefToLLVM / AllocLikeConversion.cpp
bloba6408391b1330c6fb1db782e6c42dc4e4cc69eeb
1 //===- AllocLikeConversion.cpp - LLVM conversion for alloc operations -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
10 #include "mlir/Analysis/DataLayoutAnalysis.h"
11 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
12 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13 #include "mlir/IR/SymbolTable.h"
15 using namespace mlir;
17 namespace {
18 LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
19 Operation *module, Type indexType) {
20 bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
21 if (useGenericFn)
22 return LLVM::lookupOrCreateGenericAllocFn(module, indexType);
24 return LLVM::lookupOrCreateMallocFn(module, indexType);
27 LLVM::LLVMFuncOp getAlignedAllocFn(const LLVMTypeConverter *typeConverter,
28 Operation *module, Type indexType) {
29 bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
31 if (useGenericFn)
32 return LLVM::lookupOrCreateGenericAlignedAllocFn(module, indexType);
34 return LLVM::lookupOrCreateAlignedAllocFn(module, indexType);
37 } // end namespace
39 Value AllocationOpLLVMLowering::createAligned(
40 ConversionPatternRewriter &rewriter, Location loc, Value input,
41 Value alignment) {
42 Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1);
43 Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
44 Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
45 Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
46 return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
49 static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
50 Location loc, Value allocatedPtr,
51 MemRefType memRefType, Type elementPtrType,
52 const LLVMTypeConverter &typeConverter) {
53 auto allocatedPtrTy = cast<LLVM::LLVMPointerType>(allocatedPtr.getType());
54 FailureOr<unsigned> maybeMemrefAddrSpace =
55 typeConverter.getMemRefAddressSpace(memRefType);
56 if (failed(maybeMemrefAddrSpace))
57 return Value();
58 unsigned memrefAddrSpace = *maybeMemrefAddrSpace;
59 if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
60 allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
61 loc, LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace),
62 allocatedPtr);
63 return allocatedPtr;
66 std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
67 ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
68 Operation *op, Value alignment) const {
69 if (alignment) {
70 // Adjust the allocation size to consider alignment.
71 sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
74 MemRefType memRefType = getMemRefResultType(op);
75 // Allocate the underlying buffer.
76 Type elementPtrType = this->getElementPtrType(memRefType);
77 if (!elementPtrType) {
78 emitError(loc, "conversion of memref memory space ")
79 << memRefType.getMemorySpace()
80 << " to integer address space "
81 "failed. Consider adding memory space conversions.";
83 LLVM::LLVMFuncOp allocFuncOp = getNotalignedAllocFn(
84 getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
85 getIndexType());
86 auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes);
88 Value allocatedPtr =
89 castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
90 elementPtrType, *getTypeConverter());
91 if (!allocatedPtr)
92 return std::make_tuple(Value(), Value());
93 Value alignedPtr = allocatedPtr;
94 if (alignment) {
95 // Compute the aligned pointer.
96 Value allocatedInt =
97 rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
98 Value alignmentInt = createAligned(rewriter, loc, allocatedInt, alignment);
99 alignedPtr =
100 rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
103 return std::make_tuple(allocatedPtr, alignedPtr);
106 unsigned AllocationOpLLVMLowering::getMemRefEltSizeInBytes(
107 MemRefType memRefType, Operation *op,
108 const DataLayout *defaultLayout) const {
109 const DataLayout *layout = defaultLayout;
110 if (const DataLayoutAnalysis *analysis =
111 getTypeConverter()->getDataLayoutAnalysis()) {
112 layout = &analysis->getAbove(op);
114 Type elementType = memRefType.getElementType();
115 if (auto memRefElementType = dyn_cast<MemRefType>(elementType))
116 return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
117 *layout);
118 if (auto memRefElementType = dyn_cast<UnrankedMemRefType>(elementType))
119 return getTypeConverter()->getUnrankedMemRefDescriptorSize(
120 memRefElementType, *layout);
121 return layout->getTypeSize(elementType);
124 bool AllocationOpLLVMLowering::isMemRefSizeMultipleOf(
125 MemRefType type, uint64_t factor, Operation *op,
126 const DataLayout *defaultLayout) const {
127 uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op, defaultLayout);
128 for (unsigned i = 0, e = type.getRank(); i < e; i++) {
129 if (type.isDynamicDim(i))
130 continue;
131 sizeDivisor = sizeDivisor * type.getDimSize(i);
133 return sizeDivisor % factor == 0;
136 Value AllocationOpLLVMLowering::allocateBufferAutoAlign(
137 ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
138 Operation *op, const DataLayout *defaultLayout, int64_t alignment) const {
139 Value allocAlignment =
140 createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
142 MemRefType memRefType = getMemRefResultType(op);
143 // Function aligned_alloc requires size to be a multiple of alignment; we pad
144 // the size to the next multiple if necessary.
145 if (!isMemRefSizeMultipleOf(memRefType, alignment, op, defaultLayout))
146 sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
148 Type elementPtrType = this->getElementPtrType(memRefType);
149 LLVM::LLVMFuncOp allocFuncOp = getAlignedAllocFn(
150 getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
151 getIndexType());
152 auto results = rewriter.create<LLVM::CallOp>(
153 loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes}));
155 return castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
156 elementPtrType, *getTypeConverter());
159 void AllocLikeOpLLVMLowering::setRequiresNumElements() {
160 requiresNumElements = true;
163 LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite(
164 Operation *op, ArrayRef<Value> operands,
165 ConversionPatternRewriter &rewriter) const {
166 MemRefType memRefType = getMemRefResultType(op);
167 if (!isConvertibleAndHasIdentityMaps(memRefType))
168 return rewriter.notifyMatchFailure(op, "incompatible memref type");
169 auto loc = op->getLoc();
171 // Get actual sizes of the memref as values: static sizes are constant
172 // values and dynamic sizes are passed to 'alloc' as operands. In case of
173 // zero-dimensional memref, assume a scalar (size 1).
174 SmallVector<Value, 4> sizes;
175 SmallVector<Value, 4> strides;
176 Value size;
178 this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes,
179 strides, size, !requiresNumElements);
181 // Allocate the underlying buffer.
182 auto [allocatedPtr, alignedPtr] =
183 this->allocateBuffer(rewriter, loc, size, op);
185 if (!allocatedPtr || !alignedPtr)
186 return rewriter.notifyMatchFailure(loc,
187 "underlying buffer allocation failed");
189 // Create the MemRef descriptor.
190 auto memRefDescriptor = this->createMemRefDescriptor(
191 loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);
193 // Return the final value of the descriptor.
194 rewriter.replaceOp(op, {memRefDescriptor});
195 return success();