[Hexagon] Better detection of impossible completions to perfect shuffles
[llvm-project.git] / mlir / lib / Conversion / LLVMCommon / Pattern.cpp
blob14799f865544f44dd857946cea169cfbcb4d183a
1 //===- Pattern.cpp - Conversion pattern to the LLVM dialect ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Conversion/LLVMCommon/Pattern.h"
10 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
11 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
12 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
13 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/BuiltinAttributes.h"
16 using namespace mlir;
18 //===----------------------------------------------------------------------===//
19 // ConvertToLLVMPattern
20 //===----------------------------------------------------------------------===//
22 ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName,
23 MLIRContext *context,
24 LLVMTypeConverter &typeConverter,
25 PatternBenefit benefit)
26 : ConversionPattern(typeConverter, rootOpName, benefit, context) {}
28 LLVMTypeConverter *ConvertToLLVMPattern::getTypeConverter() const {
29 return static_cast<LLVMTypeConverter *>(
30 ConversionPattern::getTypeConverter());
33 LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const {
34 return *getTypeConverter()->getDialect();
37 Type ConvertToLLVMPattern::getIndexType() const {
38 return getTypeConverter()->getIndexType();
41 Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const {
42 return IntegerType::get(&getTypeConverter()->getContext(),
43 getTypeConverter()->getPointerBitwidth(addressSpace));
46 Type ConvertToLLVMPattern::getVoidType() const {
47 return LLVM::LLVMVoidType::get(&getTypeConverter()->getContext());
50 Type ConvertToLLVMPattern::getVoidPtrType() const {
51 return LLVM::LLVMPointerType::get(
52 IntegerType::get(&getTypeConverter()->getContext(), 8));
55 Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder,
56 Location loc,
57 Type resultType,
58 int64_t value) {
59 return builder.create<LLVM::ConstantOp>(loc, resultType,
60 builder.getIndexAttr(value));
63 Value ConvertToLLVMPattern::createIndexConstant(
64 ConversionPatternRewriter &builder, Location loc, uint64_t value) const {
65 return createIndexAttrConstant(builder, loc, getIndexType(), value);
68 Value ConvertToLLVMPattern::getStridedElementPtr(
69 Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
70 ConversionPatternRewriter &rewriter) const {
72 int64_t offset;
73 SmallVector<int64_t, 4> strides;
74 auto successStrides = getStridesAndOffset(type, strides, offset);
75 assert(succeeded(successStrides) && "unexpected non-strided memref");
76 (void)successStrides;
78 MemRefDescriptor memRefDescriptor(memRefDesc);
79 Value base = memRefDescriptor.alignedPtr(rewriter, loc);
81 Value index;
82 if (offset != 0) // Skip if offset is zero.
83 index = ShapedType::isDynamic(offset)
84 ? memRefDescriptor.offset(rewriter, loc)
85 : createIndexConstant(rewriter, loc, offset);
87 for (int i = 0, e = indices.size(); i < e; ++i) {
88 Value increment = indices[i];
89 if (strides[i] != 1) { // Skip if stride is 1.
90 Value stride = ShapedType::isDynamic(strides[i])
91 ? memRefDescriptor.stride(rewriter, loc, i)
92 : createIndexConstant(rewriter, loc, strides[i]);
93 increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
95 index =
96 index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
99 Type elementPtrType = memRefDescriptor.getElementPtrType();
100 return index ? rewriter.create<LLVM::GEPOp>(loc, elementPtrType, base, index)
101 : base;
104 // Check if the MemRefType `type` is supported by the lowering. We currently
105 // only support memrefs with identity maps.
106 bool ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps(
107 MemRefType type) const {
108 if (!typeConverter->convertType(type.getElementType()))
109 return false;
110 return type.getLayout().isIdentity();
113 Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {
114 auto elementType = type.getElementType();
115 auto structElementType = typeConverter->convertType(elementType);
116 return LLVM::LLVMPointerType::get(structElementType,
117 type.getMemorySpaceAsInt());
120 void ConvertToLLVMPattern::getMemRefDescriptorSizes(
121 Location loc, MemRefType memRefType, ValueRange dynamicSizes,
122 ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes,
123 SmallVectorImpl<Value> &strides, Value &sizeBytes) const {
124 assert(isConvertibleAndHasIdentityMaps(memRefType) &&
125 "layout maps must have been normalized away");
126 assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
127 static_cast<ssize_t>(dynamicSizes.size()) &&
128 "dynamicSizes size doesn't match dynamic sizes count in memref shape");
130 sizes.reserve(memRefType.getRank());
131 unsigned dynamicIndex = 0;
132 for (int64_t size : memRefType.getShape()) {
133 sizes.push_back(size == ShapedType::kDynamic
134 ? dynamicSizes[dynamicIndex++]
135 : createIndexConstant(rewriter, loc, size));
138 // Strides: iterate sizes in reverse order and multiply.
139 int64_t stride = 1;
140 Value runningStride = createIndexConstant(rewriter, loc, 1);
141 strides.resize(memRefType.getRank());
142 for (auto i = memRefType.getRank(); i-- > 0;) {
143 strides[i] = runningStride;
145 int64_t size = memRefType.getShape()[i];
146 if (size == 0)
147 continue;
148 bool useSizeAsStride = stride == 1;
149 if (size == ShapedType::kDynamic)
150 stride = ShapedType::kDynamic;
151 if (stride != ShapedType::kDynamic)
152 stride *= size;
154 if (useSizeAsStride)
155 runningStride = sizes[i];
156 else if (stride == ShapedType::kDynamic)
157 runningStride =
158 rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]);
159 else
160 runningStride = createIndexConstant(rewriter, loc, stride);
163 // Buffer size in bytes.
164 Type elementPtrType = getElementPtrType(memRefType);
165 Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
166 Value gepPtr =
167 rewriter.create<LLVM::GEPOp>(loc, elementPtrType, nullPtr, runningStride);
168 sizeBytes = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
171 Value ConvertToLLVMPattern::getSizeInBytes(
172 Location loc, Type type, ConversionPatternRewriter &rewriter) const {
173 // Compute the size of an individual element. This emits the MLIR equivalent
174 // of the following sizeof(...) implementation in LLVM IR:
175 // %0 = getelementptr %elementType* null, %indexType 1
176 // %1 = ptrtoint %elementType* %0 to %indexType
177 // which is a common pattern of getting the size of a type in bytes.
178 auto convertedPtrType =
179 LLVM::LLVMPointerType::get(typeConverter->convertType(type));
180 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
181 auto gep = rewriter.create<LLVM::GEPOp>(loc, convertedPtrType, nullPtr,
182 ArrayRef<LLVM::GEPArg>{1});
183 return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
186 Value ConvertToLLVMPattern::getNumElements(
187 Location loc, ArrayRef<Value> shape,
188 ConversionPatternRewriter &rewriter) const {
189 // Compute the total number of memref elements.
190 Value numElements =
191 shape.empty() ? createIndexConstant(rewriter, loc, 1) : shape.front();
192 for (unsigned i = 1, e = shape.size(); i < e; ++i)
193 numElements = rewriter.create<LLVM::MulOp>(loc, numElements, shape[i]);
194 return numElements;
197 /// Creates and populates the memref descriptor struct given all its fields.
198 MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor(
199 Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr,
200 ArrayRef<Value> sizes, ArrayRef<Value> strides,
201 ConversionPatternRewriter &rewriter) const {
202 auto structType = typeConverter->convertType(memRefType);
203 auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType);
205 // Field 1: Allocated pointer, used for malloc/free.
206 memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr);
208 // Field 2: Actual aligned pointer to payload.
209 memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr);
211 // Field 3: Offset in aligned pointer.
212 memRefDescriptor.setOffset(rewriter, loc,
213 createIndexConstant(rewriter, loc, 0));
215 // Fields 4: Sizes.
216 for (const auto &en : llvm::enumerate(sizes))
217 memRefDescriptor.setSize(rewriter, loc, en.index(), en.value());
219 // Field 5: Strides.
220 for (const auto &en : llvm::enumerate(strides))
221 memRefDescriptor.setStride(rewriter, loc, en.index(), en.value());
223 return memRefDescriptor;
226 LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
227 OpBuilder &builder, Location loc, TypeRange origTypes,
228 SmallVectorImpl<Value> &operands, bool toDynamic) const {
229 assert(origTypes.size() == operands.size() &&
230 "expected as may original types as operands");
232 // Find operands of unranked memref type and store them.
233 SmallVector<UnrankedMemRefDescriptor, 4> unrankedMemrefs;
234 for (unsigned i = 0, e = operands.size(); i < e; ++i)
235 if (origTypes[i].isa<UnrankedMemRefType>())
236 unrankedMemrefs.emplace_back(operands[i]);
238 if (unrankedMemrefs.empty())
239 return success();
241 // Compute allocation sizes.
242 SmallVector<Value, 4> sizes;
243 UnrankedMemRefDescriptor::computeSizes(builder, loc, *getTypeConverter(),
244 unrankedMemrefs, sizes);
246 // Get frequently used types.
247 MLIRContext *context = builder.getContext();
248 Type voidPtrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
249 auto i1Type = IntegerType::get(context, 1);
250 Type indexType = getTypeConverter()->getIndexType();
252 // Find the malloc and free, or declare them if necessary.
253 auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
254 LLVM::LLVMFuncOp freeFunc, mallocFunc;
255 if (toDynamic)
256 mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType);
257 if (!toDynamic)
258 freeFunc = LLVM::lookupOrCreateFreeFn(module);
260 // Initialize shared constants.
261 Value zero =
262 builder.create<LLVM::ConstantOp>(loc, i1Type, builder.getBoolAttr(false));
264 unsigned unrankedMemrefPos = 0;
265 for (unsigned i = 0, e = operands.size(); i < e; ++i) {
266 Type type = origTypes[i];
267 if (!type.isa<UnrankedMemRefType>())
268 continue;
269 Value allocationSize = sizes[unrankedMemrefPos++];
270 UnrankedMemRefDescriptor desc(operands[i]);
272 // Allocate memory, copy, and free the source if necessary.
273 Value memory =
274 toDynamic
275 ? builder.create<LLVM::CallOp>(loc, mallocFunc, allocationSize)
276 .getResult()
277 : builder.create<LLVM::AllocaOp>(loc, voidPtrType, allocationSize,
278 /*alignment=*/0);
279 Value source = desc.memRefDescPtr(builder, loc);
280 builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, zero);
281 if (!toDynamic)
282 builder.create<LLVM::CallOp>(loc, freeFunc, source);
284 // Create a new descriptor. The same descriptor can be returned multiple
285 // times, attempting to modify its pointer can lead to memory leaks
286 // (allocated twice and overwritten) or double frees (the caller does not
287 // know if the descriptor points to the same memory).
288 Type descriptorType = getTypeConverter()->convertType(type);
289 if (!descriptorType)
290 return failure();
291 auto updatedDesc =
292 UnrankedMemRefDescriptor::undef(builder, loc, descriptorType);
293 Value rank = desc.rank(builder, loc);
294 updatedDesc.setRank(builder, loc, rank);
295 updatedDesc.setMemRefDescPtr(builder, loc, memory);
297 operands[i] = updatedDesc;
300 return success();
303 //===----------------------------------------------------------------------===//
304 // Detail methods
305 //===----------------------------------------------------------------------===//
307 /// Replaces the given operation "op" with a new operation of type "targetOp"
308 /// and given operands.
309 LogicalResult LLVM::detail::oneToOneRewrite(
310 Operation *op, StringRef targetOp, ValueRange operands,
311 ArrayRef<NamedAttribute> targetAttrs, LLVMTypeConverter &typeConverter,
312 ConversionPatternRewriter &rewriter) {
313 unsigned numResults = op->getNumResults();
315 SmallVector<Type> resultTypes;
316 if (numResults != 0) {
317 resultTypes.push_back(
318 typeConverter.packFunctionResults(op->getResultTypes()));
319 if (!resultTypes.back())
320 return failure();
323 // Create the operation through state since we don't know its C++ type.
324 Operation *newOp =
325 rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
326 resultTypes, targetAttrs);
328 // If the operation produced 0 or 1 result, return them immediately.
329 if (numResults == 0)
330 return rewriter.eraseOp(op), success();
331 if (numResults == 1)
332 return rewriter.replaceOp(op, newOp->getResult(0)), success();
334 // Otherwise, it had been converted to an operation producing a structure.
335 // Extract individual results from the structure and return them as list.
336 SmallVector<Value, 4> results;
337 results.reserve(numResults);
338 for (unsigned i = 0; i < numResults; ++i) {
339 results.push_back(rewriter.create<LLVM::ExtractValueOp>(
340 op->getLoc(), newOp->getResult(0), i));
342 rewriter.replaceOp(op, results);
343 return success();