//===- AllocLikeConversion.cpp - LLVM conversion for alloc operations -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/SymbolTable.h"

using namespace mlir;
18 LLVM::LLVMFuncOp
getNotalignedAllocFn(const LLVMTypeConverter
*typeConverter
,
19 Operation
*module
, Type indexType
) {
20 bool useGenericFn
= typeConverter
->getOptions().useGenericFunctions
;
22 return LLVM::lookupOrCreateGenericAllocFn(module
, indexType
);
24 return LLVM::lookupOrCreateMallocFn(module
, indexType
);
27 LLVM::LLVMFuncOp
getAlignedAllocFn(const LLVMTypeConverter
*typeConverter
,
28 Operation
*module
, Type indexType
) {
29 bool useGenericFn
= typeConverter
->getOptions().useGenericFunctions
;
32 return LLVM::lookupOrCreateGenericAlignedAllocFn(module
, indexType
);
34 return LLVM::lookupOrCreateAlignedAllocFn(module
, indexType
);
39 Value
AllocationOpLLVMLowering::createAligned(
40 ConversionPatternRewriter
&rewriter
, Location loc
, Value input
,
42 Value one
= createIndexAttrConstant(rewriter
, loc
, alignment
.getType(), 1);
43 Value bump
= rewriter
.create
<LLVM::SubOp
>(loc
, alignment
, one
);
44 Value bumped
= rewriter
.create
<LLVM::AddOp
>(loc
, input
, bump
);
45 Value mod
= rewriter
.create
<LLVM::URemOp
>(loc
, bumped
, alignment
);
46 return rewriter
.create
<LLVM::SubOp
>(loc
, bumped
, mod
);
49 static Value
castAllocFuncResult(ConversionPatternRewriter
&rewriter
,
50 Location loc
, Value allocatedPtr
,
51 MemRefType memRefType
, Type elementPtrType
,
52 const LLVMTypeConverter
&typeConverter
) {
53 auto allocatedPtrTy
= cast
<LLVM::LLVMPointerType
>(allocatedPtr
.getType());
54 FailureOr
<unsigned> maybeMemrefAddrSpace
=
55 typeConverter
.getMemRefAddressSpace(memRefType
);
56 if (failed(maybeMemrefAddrSpace
))
58 unsigned memrefAddrSpace
= *maybeMemrefAddrSpace
;
59 if (allocatedPtrTy
.getAddressSpace() != memrefAddrSpace
)
60 allocatedPtr
= rewriter
.create
<LLVM::AddrSpaceCastOp
>(
61 loc
, LLVM::LLVMPointerType::get(rewriter
.getContext(), memrefAddrSpace
),
66 std::tuple
<Value
, Value
> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
67 ConversionPatternRewriter
&rewriter
, Location loc
, Value sizeBytes
,
68 Operation
*op
, Value alignment
) const {
70 // Adjust the allocation size to consider alignment.
71 sizeBytes
= rewriter
.create
<LLVM::AddOp
>(loc
, sizeBytes
, alignment
);
74 MemRefType memRefType
= getMemRefResultType(op
);
75 // Allocate the underlying buffer.
76 Type elementPtrType
= this->getElementPtrType(memRefType
);
77 if (!elementPtrType
) {
78 emitError(loc
, "conversion of memref memory space ")
79 << memRefType
.getMemorySpace()
80 << " to integer address space "
81 "failed. Consider adding memory space conversions.";
83 LLVM::LLVMFuncOp allocFuncOp
= getNotalignedAllocFn(
84 getTypeConverter(), op
->getParentWithTrait
<OpTrait::SymbolTable
>(),
86 auto results
= rewriter
.create
<LLVM::CallOp
>(loc
, allocFuncOp
, sizeBytes
);
89 castAllocFuncResult(rewriter
, loc
, results
.getResult(), memRefType
,
90 elementPtrType
, *getTypeConverter());
92 return std::make_tuple(Value(), Value());
93 Value alignedPtr
= allocatedPtr
;
95 // Compute the aligned pointer.
97 rewriter
.create
<LLVM::PtrToIntOp
>(loc
, getIndexType(), allocatedPtr
);
98 Value alignmentInt
= createAligned(rewriter
, loc
, allocatedInt
, alignment
);
100 rewriter
.create
<LLVM::IntToPtrOp
>(loc
, elementPtrType
, alignmentInt
);
103 return std::make_tuple(allocatedPtr
, alignedPtr
);
106 unsigned AllocationOpLLVMLowering::getMemRefEltSizeInBytes(
107 MemRefType memRefType
, Operation
*op
,
108 const DataLayout
*defaultLayout
) const {
109 const DataLayout
*layout
= defaultLayout
;
110 if (const DataLayoutAnalysis
*analysis
=
111 getTypeConverter()->getDataLayoutAnalysis()) {
112 layout
= &analysis
->getAbove(op
);
114 Type elementType
= memRefType
.getElementType();
115 if (auto memRefElementType
= dyn_cast
<MemRefType
>(elementType
))
116 return getTypeConverter()->getMemRefDescriptorSize(memRefElementType
,
118 if (auto memRefElementType
= dyn_cast
<UnrankedMemRefType
>(elementType
))
119 return getTypeConverter()->getUnrankedMemRefDescriptorSize(
120 memRefElementType
, *layout
);
121 return layout
->getTypeSize(elementType
);
124 bool AllocationOpLLVMLowering::isMemRefSizeMultipleOf(
125 MemRefType type
, uint64_t factor
, Operation
*op
,
126 const DataLayout
*defaultLayout
) const {
127 uint64_t sizeDivisor
= getMemRefEltSizeInBytes(type
, op
, defaultLayout
);
128 for (unsigned i
= 0, e
= type
.getRank(); i
< e
; i
++) {
129 if (type
.isDynamicDim(i
))
131 sizeDivisor
= sizeDivisor
* type
.getDimSize(i
);
133 return sizeDivisor
% factor
== 0;
136 Value
AllocationOpLLVMLowering::allocateBufferAutoAlign(
137 ConversionPatternRewriter
&rewriter
, Location loc
, Value sizeBytes
,
138 Operation
*op
, const DataLayout
*defaultLayout
, int64_t alignment
) const {
139 Value allocAlignment
=
140 createIndexAttrConstant(rewriter
, loc
, getIndexType(), alignment
);
142 MemRefType memRefType
= getMemRefResultType(op
);
143 // Function aligned_alloc requires size to be a multiple of alignment; we pad
144 // the size to the next multiple if necessary.
145 if (!isMemRefSizeMultipleOf(memRefType
, alignment
, op
, defaultLayout
))
146 sizeBytes
= createAligned(rewriter
, loc
, sizeBytes
, allocAlignment
);
148 Type elementPtrType
= this->getElementPtrType(memRefType
);
149 LLVM::LLVMFuncOp allocFuncOp
= getAlignedAllocFn(
150 getTypeConverter(), op
->getParentWithTrait
<OpTrait::SymbolTable
>(),
152 auto results
= rewriter
.create
<LLVM::CallOp
>(
153 loc
, allocFuncOp
, ValueRange({allocAlignment
, sizeBytes
}));
155 return castAllocFuncResult(rewriter
, loc
, results
.getResult(), memRefType
,
156 elementPtrType
, *getTypeConverter());
159 void AllocLikeOpLLVMLowering::setRequiresNumElements() {
160 requiresNumElements
= true;
163 LogicalResult
AllocLikeOpLLVMLowering::matchAndRewrite(
164 Operation
*op
, ArrayRef
<Value
> operands
,
165 ConversionPatternRewriter
&rewriter
) const {
166 MemRefType memRefType
= getMemRefResultType(op
);
167 if (!isConvertibleAndHasIdentityMaps(memRefType
))
168 return rewriter
.notifyMatchFailure(op
, "incompatible memref type");
169 auto loc
= op
->getLoc();
171 // Get actual sizes of the memref as values: static sizes are constant
172 // values and dynamic sizes are passed to 'alloc' as operands. In case of
173 // zero-dimensional memref, assume a scalar (size 1).
174 SmallVector
<Value
, 4> sizes
;
175 SmallVector
<Value
, 4> strides
;
178 this->getMemRefDescriptorSizes(loc
, memRefType
, operands
, rewriter
, sizes
,
179 strides
, size
, !requiresNumElements
);
181 // Allocate the underlying buffer.
182 auto [allocatedPtr
, alignedPtr
] =
183 this->allocateBuffer(rewriter
, loc
, size
, op
);
185 if (!allocatedPtr
|| !alignedPtr
)
186 return rewriter
.notifyMatchFailure(loc
,
187 "underlying buffer allocation failed");
189 // Create the MemRef descriptor.
190 auto memRefDescriptor
= this->createMemRefDescriptor(
191 loc
, memRefType
, allocatedPtr
, alignedPtr
, sizes
, strides
, rewriter
);
193 // Return the final value of the descriptor.
194 rewriter
.replaceOp(op
, {memRefDescriptor
});