Fix GCC build problem with 288f05f related to SmallVector. (#116958)
[llvm-project.git] / mlir / lib / Conversion / MemRefToLLVM / MemRefToLLVM.cpp
blob 4bfa536cc8a44a4f11468674af7be47b92a79d9d
1 //===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
11 #include "mlir/Analysis/DataLayoutAnalysis.h"
12 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
13 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
14 #include "mlir/Conversion/LLVMCommon/Pattern.h"
15 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
16 #include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
17 #include "mlir/Dialect/Arith/IR/Arith.h"
18 #include "mlir/Dialect/Func/IR/FuncOps.h"
19 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
20 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
21 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
22 #include "mlir/Dialect/MemRef/IR/MemRef.h"
23 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
24 #include "mlir/IR/AffineMap.h"
25 #include "mlir/IR/BuiltinTypes.h"
26 #include "mlir/IR/IRMapping.h"
27 #include "mlir/Pass/Pass.h"
28 #include "llvm/ADT/SmallBitVector.h"
29 #include "llvm/Support/MathExtras.h"
30 #include <optional>
32 namespace mlir {
33 #define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
34 #include "mlir/Conversion/Passes.h.inc"
35 } // namespace mlir
37 using namespace mlir;
39 namespace {
41 bool isStaticStrideOrOffset(int64_t strideOrOffset) {
42 return !ShapedType::isDynamic(strideOrOffset);
45 LLVM::LLVMFuncOp getFreeFn(const LLVMTypeConverter *typeConverter,
46 ModuleOp module) {
47 bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
49 if (useGenericFn)
50 return LLVM::lookupOrCreateGenericFreeFn(module);
52 return LLVM::lookupOrCreateFreeFn(module);
55 struct AllocOpLowering : public AllocLikeOpLLVMLowering {
56 AllocOpLowering(const LLVMTypeConverter &converter)
57 : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
58 converter) {}
59 std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
60 Location loc, Value sizeBytes,
61 Operation *op) const override {
62 return allocateBufferManuallyAlign(
63 rewriter, loc, sizeBytes, op,
64 getAlignment(rewriter, loc, cast<memref::AllocOp>(op)));
68 struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
69 AlignedAllocOpLowering(const LLVMTypeConverter &converter)
70 : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
71 converter) {}
72 std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
73 Location loc, Value sizeBytes,
74 Operation *op) const override {
75 Value ptr = allocateBufferAutoAlign(
76 rewriter, loc, sizeBytes, op, &defaultLayout,
77 alignedAllocationGetAlignment(rewriter, loc, cast<memref::AllocOp>(op),
78 &defaultLayout));
79 if (!ptr)
80 return std::make_tuple(Value(), Value());
81 return std::make_tuple(ptr, ptr);
84 private:
85 /// Default layout to use in absence of the corresponding analysis.
86 DataLayout defaultLayout;
89 struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
90 AllocaOpLowering(const LLVMTypeConverter &converter)
91 : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
92 converter) {
93 setRequiresNumElements();
96 /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
97 /// is set to null for stack allocations. `accessAlignment` is set if
98 /// alignment is needed post allocation (for eg. in conjunction with malloc).
99 std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
100 Location loc, Value size,
101 Operation *op) const override {
103 // With alloca, one gets a pointer to the element type right away.
104 // For stack allocations.
105 auto allocaOp = cast<memref::AllocaOp>(op);
106 auto elementType =
107 typeConverter->convertType(allocaOp.getType().getElementType());
108 unsigned addrSpace =
109 *getTypeConverter()->getMemRefAddressSpace(allocaOp.getType());
110 auto elementPtrType =
111 LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace);
113 auto allocatedElementPtr =
114 rewriter.create<LLVM::AllocaOp>(loc, elementPtrType, elementType, size,
115 allocaOp.getAlignment().value_or(0));
117 return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
121 struct AllocaScopeOpLowering
122 : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
123 using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;
125 LogicalResult
126 matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
127 ConversionPatternRewriter &rewriter) const override {
128 OpBuilder::InsertionGuard guard(rewriter);
129 Location loc = allocaScopeOp.getLoc();
131 // Split the current block before the AllocaScopeOp to create the inlining
132 // point.
133 auto *currentBlock = rewriter.getInsertionBlock();
134 auto *remainingOpsBlock =
135 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
136 Block *continueBlock;
137 if (allocaScopeOp.getNumResults() == 0) {
138 continueBlock = remainingOpsBlock;
139 } else {
140 continueBlock = rewriter.createBlock(
141 remainingOpsBlock, allocaScopeOp.getResultTypes(),
142 SmallVector<Location>(allocaScopeOp->getNumResults(),
143 allocaScopeOp.getLoc()));
144 rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
147 // Inline body region.
148 Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
149 Block *afterBody = &allocaScopeOp.getBodyRegion().back();
150 rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);
152 // Save stack and then branch into the body of the region.
153 rewriter.setInsertionPointToEnd(currentBlock);
154 auto stackSaveOp =
155 rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
156 rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);
158 // Replace the alloca_scope return with a branch that jumps out of the body.
159 // Stack restore before leaving the body region.
160 rewriter.setInsertionPointToEnd(afterBody);
161 auto returnOp =
162 cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
163 auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
164 returnOp, returnOp.getResults(), continueBlock);
166 // Insert stack restore before jumping out the body of the region.
167 rewriter.setInsertionPoint(branchOp);
168 rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
170 // Replace the op with values return from the body region.
171 rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());
173 return success();
177 struct AssumeAlignmentOpLowering
178 : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
179 using ConvertOpToLLVMPattern<
180 memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
181 explicit AssumeAlignmentOpLowering(const LLVMTypeConverter &converter)
182 : ConvertOpToLLVMPattern<memref::AssumeAlignmentOp>(converter) {}
184 LogicalResult
185 matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
186 ConversionPatternRewriter &rewriter) const override {
187 Value memref = adaptor.getMemref();
188 unsigned alignment = op.getAlignment();
189 auto loc = op.getLoc();
191 auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
192 Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, /*indices=*/{},
193 rewriter);
195 // Emit llvm.assume(memref & (alignment - 1) == 0).
197 // This relies on LLVM's CSE optimization (potentially after SROA), since
198 // after CSE all memref instances should get de-duplicated into the same
199 // pointer SSA value.
200 MemRefDescriptor memRefDescriptor(memref);
201 auto intPtrType =
202 getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
203 Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
204 Value mask =
205 createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
206 Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
207 rewriter.create<LLVM::AssumeOp>(
208 loc, rewriter.create<LLVM::ICmpOp>(
209 loc, LLVM::ICmpPredicate::eq,
210 rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));
212 rewriter.eraseOp(op);
213 return success();
217 // A `dealloc` is converted into a call to `free` on the underlying data buffer.
218 // The memref descriptor being an SSA value, there is no need to clean it up
219 // in any way.
220 struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
221 using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;
223 explicit DeallocOpLowering(const LLVMTypeConverter &converter)
224 : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}
226 LogicalResult
227 matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
228 ConversionPatternRewriter &rewriter) const override {
229 // Insert the `free` declaration if it is not already present.
230 LLVM::LLVMFuncOp freeFunc =
231 getFreeFn(getTypeConverter(), op->getParentOfType<ModuleOp>());
232 Value allocatedPtr;
233 if (auto unrankedTy =
234 llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
235 auto elementPtrTy = LLVM::LLVMPointerType::get(
236 rewriter.getContext(), unrankedTy.getMemorySpaceAsInt());
237 allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
238 rewriter, op.getLoc(),
239 UnrankedMemRefDescriptor(adaptor.getMemref())
240 .memRefDescPtr(rewriter, op.getLoc()),
241 elementPtrTy);
242 } else {
243 allocatedPtr = MemRefDescriptor(adaptor.getMemref())
244 .allocatedPtr(rewriter, op.getLoc());
246 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc, allocatedPtr);
247 return success();
251 // A `dim` is converted to a constant for static sizes and to an access to the
252 // size stored in the memref descriptor for dynamic sizes.
253 struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
254 using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;
256 LogicalResult
257 matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
258 ConversionPatternRewriter &rewriter) const override {
259 Type operandType = dimOp.getSource().getType();
260 if (isa<UnrankedMemRefType>(operandType)) {
261 FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
262 operandType, dimOp, adaptor.getOperands(), rewriter);
263 if (failed(extractedSize))
264 return failure();
265 rewriter.replaceOp(dimOp, {*extractedSize});
266 return success();
268 if (isa<MemRefType>(operandType)) {
269 rewriter.replaceOp(
270 dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
271 adaptor.getOperands(), rewriter)});
272 return success();
274 llvm_unreachable("expected MemRefType or UnrankedMemRefType");
277 private:
278 FailureOr<Value>
279 extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
280 OpAdaptor adaptor,
281 ConversionPatternRewriter &rewriter) const {
282 Location loc = dimOp.getLoc();
284 auto unrankedMemRefType = cast<UnrankedMemRefType>(operandType);
285 auto scalarMemRefType =
286 MemRefType::get({}, unrankedMemRefType.getElementType());
287 FailureOr<unsigned> maybeAddressSpace =
288 getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
289 if (failed(maybeAddressSpace)) {
290 dimOp.emitOpError("memref memory space must be convertible to an integer "
291 "address space");
292 return failure();
294 unsigned addressSpace = *maybeAddressSpace;
296 // Extract pointer to the underlying ranked descriptor and bitcast it to a
297 // memref<element_type> descriptor pointer to minimize the number of GEP
298 // operations.
299 UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
300 Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
302 Type elementType = typeConverter->convertType(scalarMemRefType);
304 // Get pointer to offset field of memref<element_type> descriptor.
305 auto indexPtrTy =
306 LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
307 Value offsetPtr = rewriter.create<LLVM::GEPOp>(
308 loc, indexPtrTy, elementType, underlyingRankedDesc,
309 ArrayRef<LLVM::GEPArg>{0, 2});
311 // The size value that we have to extract can be obtained using GEPop with
312 // `dimOp.index() + 1` index argument.
313 Value idxPlusOne = rewriter.create<LLVM::AddOp>(
314 loc, createIndexAttrConstant(rewriter, loc, getIndexType(), 1),
315 adaptor.getIndex());
316 Value sizePtr = rewriter.create<LLVM::GEPOp>(
317 loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr,
318 idxPlusOne);
319 return rewriter
320 .create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr)
321 .getResult();
324 std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
325 if (auto idx = dimOp.getConstantIndex())
326 return idx;
328 if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
329 return cast<IntegerAttr>(constantOp.getValue()).getValue().getSExtValue();
331 return std::nullopt;
334 Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
335 OpAdaptor adaptor,
336 ConversionPatternRewriter &rewriter) const {
337 Location loc = dimOp.getLoc();
339 // Take advantage if index is constant.
340 MemRefType memRefType = cast<MemRefType>(operandType);
341 Type indexType = getIndexType();
342 if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
343 int64_t i = *index;
344 if (i >= 0 && i < memRefType.getRank()) {
345 if (memRefType.isDynamicDim(i)) {
346 // extract dynamic size from the memref descriptor.
347 MemRefDescriptor descriptor(adaptor.getSource());
348 return descriptor.size(rewriter, loc, i);
350 // Use constant for static size.
351 int64_t dimSize = memRefType.getDimSize(i);
352 return createIndexAttrConstant(rewriter, loc, indexType, dimSize);
355 Value index = adaptor.getIndex();
356 int64_t rank = memRefType.getRank();
357 MemRefDescriptor memrefDescriptor(adaptor.getSource());
358 return memrefDescriptor.size(rewriter, loc, index, rank);
362 /// Common base for load and store operations on MemRefs. Restricts the match
363 /// to supported MemRef types. Provides functionality to emit code accessing a
364 /// specific element of the underlying data buffer.
365 template <typename Derived>
366 struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
367 using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
368 using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
369 using Base = LoadStoreOpLowering<Derived>;
371 LogicalResult match(Derived op) const override {
372 MemRefType type = op.getMemRefType();
373 return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
377 /// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
378 /// retried until it succeeds in atomically storing a new value into memory.
380 /// +---------------------------------+
381 /// | <code before the AtomicRMWOp> |
382 /// | <compute initial %loaded> |
383 /// | cf.br loop(%loaded) |
384 /// +---------------------------------+
385 /// |
386 /// -------| |
387 /// | v v
388 /// | +--------------------------------+
389 /// | | loop(%loaded): |
390 /// | | <body contents> |
391 /// | | %pair = cmpxchg |
392 /// | | %ok = %pair[0] |
393 /// | | %new = %pair[1] |
394 /// | | cf.cond_br %ok, end, loop(%new) |
395 /// | +--------------------------------+
396 /// | | |
397 /// |----------- |
398 /// v
399 /// +--------------------------------+
400 /// | end: |
401 /// | <code after the AtomicRMWOp> |
402 /// +--------------------------------+
404 struct GenericAtomicRMWOpLowering
405 : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
406 using Base::Base;
408 LogicalResult
409 matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
410 ConversionPatternRewriter &rewriter) const override {
411 auto loc = atomicOp.getLoc();
412 Type valueType = typeConverter->convertType(atomicOp.getResult().getType());
414 // Split the block into initial, loop, and ending parts.
415 auto *initBlock = rewriter.getInsertionBlock();
416 auto *loopBlock = rewriter.splitBlock(initBlock, Block::iterator(atomicOp));
417 loopBlock->addArgument(valueType, loc);
419 auto *endBlock =
420 rewriter.splitBlock(loopBlock, Block::iterator(atomicOp)++);
422 // Compute the loaded value and branch to the loop block.
423 rewriter.setInsertionPointToEnd(initBlock);
424 auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
425 auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
426 adaptor.getIndices(), rewriter);
427 Value init = rewriter.create<LLVM::LoadOp>(
428 loc, typeConverter->convertType(memRefType.getElementType()), dataPtr);
429 rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
431 // Prepare the body of the loop block.
432 rewriter.setInsertionPointToStart(loopBlock);
434 // Clone the GenericAtomicRMWOp region and extract the result.
435 auto loopArgument = loopBlock->getArgument(0);
436 IRMapping mapping;
437 mapping.map(atomicOp.getCurrentValue(), loopArgument);
438 Block &entryBlock = atomicOp.body().front();
439 for (auto &nestedOp : entryBlock.without_terminator()) {
440 Operation *clone = rewriter.clone(nestedOp, mapping);
441 mapping.map(nestedOp.getResults(), clone->getResults());
443 Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));
445 // Prepare the epilog of the loop block.
446 // Append the cmpxchg op to the end of the loop block.
447 auto successOrdering = LLVM::AtomicOrdering::acq_rel;
448 auto failureOrdering = LLVM::AtomicOrdering::monotonic;
449 auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
450 loc, dataPtr, loopArgument, result, successOrdering, failureOrdering);
451 // Extract the %new_loaded and %ok values from the pair.
452 Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 0);
453 Value ok = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 1);
455 // Conditionally branch to the end or back to the loop depending on %ok.
456 rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
457 loopBlock, newLoaded);
459 rewriter.setInsertionPointToEnd(endBlock);
461 // The 'result' of the atomic_rmw op is the newly loaded value.
462 rewriter.replaceOp(atomicOp, {newLoaded});
464 return success();
468 /// Returns the LLVM type of the global variable given the memref type `type`.
469 static Type
470 convertGlobalMemrefTypeToLLVM(MemRefType type,
471 const LLVMTypeConverter &typeConverter) {
472 // LLVM type for a global memref will be a multi-dimension array. For
473 // declarations or uninitialized global memrefs, we can potentially flatten
474 // this to a 1D array. However, for memref.global's with an initial value,
475 // we do not intend to flatten the ElementsAttribute when going from std ->
476 // LLVM dialect, so the LLVM type needs to me a multi-dimension array.
477 Type elementType = typeConverter.convertType(type.getElementType());
478 Type arrayTy = elementType;
479 // Shape has the outermost dim at index 0, so need to walk it backwards
480 for (int64_t dim : llvm::reverse(type.getShape()))
481 arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
482 return arrayTy;
485 /// GlobalMemrefOp is lowered to a LLVM Global Variable.
486 struct GlobalMemrefOpLowering
487 : public ConvertOpToLLVMPattern<memref::GlobalOp> {
488 using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;
490 LogicalResult
491 matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
492 ConversionPatternRewriter &rewriter) const override {
493 MemRefType type = global.getType();
494 if (!isConvertibleAndHasIdentityMaps(type))
495 return failure();
497 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
499 LLVM::Linkage linkage =
500 global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
502 Attribute initialValue = nullptr;
503 if (!global.isExternal() && !global.isUninitialized()) {
504 auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
505 initialValue = elementsAttr;
507 // For scalar memrefs, the global variable created is of the element type,
508 // so unpack the elements attribute to extract the value.
509 if (type.getRank() == 0)
510 initialValue = elementsAttr.getSplatValue<Attribute>();
513 uint64_t alignment = global.getAlignment().value_or(0);
514 FailureOr<unsigned> addressSpace =
515 getTypeConverter()->getMemRefAddressSpace(type);
516 if (failed(addressSpace))
517 return global.emitOpError(
518 "memory space cannot be converted to an integer address space");
519 auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
520 global, arrayTy, global.getConstant(), linkage, global.getSymName(),
521 initialValue, alignment, *addressSpace);
522 if (!global.isExternal() && global.isUninitialized()) {
523 rewriter.createBlock(&newGlobal.getInitializerRegion());
524 Value undef[] = {
525 rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
526 rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
528 return success();
532 /// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to
533 /// the first element stashed into the descriptor. This reuses
534 /// `AllocLikeOpLowering` to reuse the Memref descriptor construction.
535 struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
536 GetGlobalMemrefOpLowering(const LLVMTypeConverter &converter)
537 : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
538 converter) {}
540 /// Buffer "allocation" for memref.get_global op is getting the address of
541 /// the global variable referenced.
542 std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
543 Location loc, Value sizeBytes,
544 Operation *op) const override {
545 auto getGlobalOp = cast<memref::GetGlobalOp>(op);
546 MemRefType type = cast<MemRefType>(getGlobalOp.getResult().getType());
548 // This is called after a type conversion, which would have failed if this
549 // call fails.
550 FailureOr<unsigned> maybeAddressSpace =
551 getTypeConverter()->getMemRefAddressSpace(type);
552 if (failed(maybeAddressSpace))
553 return std::make_tuple(Value(), Value());
554 unsigned memSpace = *maybeAddressSpace;
556 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
557 auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace);
558 auto addressOf =
559 rewriter.create<LLVM::AddressOfOp>(loc, ptrTy, getGlobalOp.getName());
561 // Get the address of the first element in the array by creating a GEP with
562 // the address of the GV as the base, and (rank + 1) number of 0 indices.
563 auto gep = rewriter.create<LLVM::GEPOp>(
564 loc, ptrTy, arrayTy, addressOf,
565 SmallVector<LLVM::GEPArg>(type.getRank() + 1, 0));
567 // We do not expect the memref obtained using `memref.get_global` to be
568 // ever deallocated. Set the allocated pointer to be known bad value to
569 // help debug if that ever happens.
570 auto intPtrType = getIntPtrType(memSpace);
571 Value deadBeefConst =
572 createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
573 auto deadBeefPtr =
574 rewriter.create<LLVM::IntToPtrOp>(loc, ptrTy, deadBeefConst);
576 // Both allocated and aligned pointers are same. We could potentially stash
577 // a nullptr for the allocated pointer since we do not expect any dealloc.
578 return std::make_tuple(deadBeefPtr, gep);
582 // Load operation is lowered to obtaining a pointer to the indexed element
583 // and loading it.
584 struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
585 using Base::Base;
587 LogicalResult
588 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
589 ConversionPatternRewriter &rewriter) const override {
590 auto type = loadOp.getMemRefType();
592 Value dataPtr =
593 getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
594 adaptor.getIndices(), rewriter);
595 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
596 loadOp, typeConverter->convertType(type.getElementType()), dataPtr, 0,
597 false, loadOp.getNontemporal());
598 return success();
602 // Store operation is lowered to obtaining a pointer to the indexed element,
603 // and storing the given value to it.
604 struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
605 using Base::Base;
607 LogicalResult
608 matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
609 ConversionPatternRewriter &rewriter) const override {
610 auto type = op.getMemRefType();
612 Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
613 adaptor.getIndices(), rewriter);
614 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
615 0, false, op.getNontemporal());
616 return success();
620 // The prefetch operation is lowered in a way similar to the load operation
621 // except that the llvm.prefetch operation is used for replacement.
622 struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
623 using Base::Base;
625 LogicalResult
626 matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
627 ConversionPatternRewriter &rewriter) const override {
628 auto type = prefetchOp.getMemRefType();
629 auto loc = prefetchOp.getLoc();
631 Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
632 adaptor.getIndices(), rewriter);
634 // Replace with llvm.prefetch.
635 IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite());
636 IntegerAttr localityHint = prefetchOp.getLocalityHintAttr();
637 IntegerAttr isData =
638 rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache());
639 rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
640 localityHint, isData);
641 return success();
645 struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
646 using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;
648 LogicalResult
649 matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
650 ConversionPatternRewriter &rewriter) const override {
651 Location loc = op.getLoc();
652 Type operandType = op.getMemref().getType();
653 if (dyn_cast<UnrankedMemRefType>(operandType)) {
654 UnrankedMemRefDescriptor desc(adaptor.getMemref());
655 rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
656 return success();
658 if (auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) {
659 Type indexType = getIndexType();
660 rewriter.replaceOp(op,
661 {createIndexAttrConstant(rewriter, loc, indexType,
662 rankedMemRefType.getRank())});
663 return success();
665 return failure();
669 struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
670 using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;
672 LogicalResult match(memref::CastOp memRefCastOp) const override {
673 Type srcType = memRefCastOp.getOperand().getType();
674 Type dstType = memRefCastOp.getType();
676 // memref::CastOp reduce to bitcast in the ranked MemRef case and can be
677 // used for type erasure. For now they must preserve underlying element type
678 // and require source and result type to have the same rank. Therefore,
679 // perform a sanity check that the underlying structs are the same. Once op
680 // semantics are relaxed we can revisit.
681 if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
682 return success(typeConverter->convertType(srcType) ==
683 typeConverter->convertType(dstType));
685 // At least one of the operands is unranked type
686 assert(isa<UnrankedMemRefType>(srcType) ||
687 isa<UnrankedMemRefType>(dstType));
689 // Unranked to unranked cast is disallowed
690 return !(isa<UnrankedMemRefType>(srcType) &&
691 isa<UnrankedMemRefType>(dstType))
692 ? success()
693 : failure();
696 void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
697 ConversionPatternRewriter &rewriter) const override {
698 auto srcType = memRefCastOp.getOperand().getType();
699 auto dstType = memRefCastOp.getType();
700 auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
701 auto loc = memRefCastOp.getLoc();
703 // For ranked/ranked case, just keep the original descriptor.
704 if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
705 return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});
707 if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
708 // Casting ranked to unranked memref type
709 // Set the rank in the destination from the memref type
710 // Allocate space on the stack and copy the src memref descriptor
711 // Set the ptr in the destination to the stack space
712 auto srcMemRefType = cast<MemRefType>(srcType);
713 int64_t rank = srcMemRefType.getRank();
714 // ptr = AllocaOp sizeof(MemRefDescriptor)
715 auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
716 loc, adaptor.getSource(), rewriter);
718 // rank = ConstantOp srcRank
719 auto rankVal = rewriter.create<LLVM::ConstantOp>(
720 loc, getIndexType(), rewriter.getIndexAttr(rank));
721 // undef = UndefOp
722 UnrankedMemRefDescriptor memRefDesc =
723 UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
724 // d1 = InsertValueOp undef, rank, 0
725 memRefDesc.setRank(rewriter, loc, rankVal);
726 // d2 = InsertValueOp d1, ptr, 1
727 memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
728 rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);
730 } else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
731 // Casting from unranked type to ranked.
732 // The operation is assumed to be doing a correct cast. If the destination
733 // type mismatches the unranked the type, it is undefined behavior.
734 UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
735 // ptr = ExtractValueOp src, 1
736 auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
738 // struct = LoadOp ptr
739 auto loadOp = rewriter.create<LLVM::LoadOp>(loc, targetStructType, ptr);
740 rewriter.replaceOp(memRefCastOp, loadOp.getResult());
741 } else {
742 llvm_unreachable("Unsupported unranked memref to unranked memref cast");
747 /// Pattern to lower a `memref.copy` to llvm.
749 /// For memrefs with identity layouts, the copy is lowered to the llvm
750 /// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
751 /// to the generic `MemrefCopyFn`.
752 struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
753 using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
755 LogicalResult
756 lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
757 ConversionPatternRewriter &rewriter) const {
758 auto loc = op.getLoc();
759 auto srcType = dyn_cast<MemRefType>(op.getSource().getType());
761 MemRefDescriptor srcDesc(adaptor.getSource());
763 // Compute number of elements.
764 Value numElements = rewriter.create<LLVM::ConstantOp>(
765 loc, getIndexType(), rewriter.getIndexAttr(1));
766 for (int pos = 0; pos < srcType.getRank(); ++pos) {
767 auto size = srcDesc.size(rewriter, loc, pos);
768 numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
771 // Get element size.
772 auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
773 // Compute total.
774 Value totalSize =
775 rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);
777 Type elementType = typeConverter->convertType(srcType.getElementType());
779 Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
780 Value srcOffset = srcDesc.offset(rewriter, loc);
781 Value srcPtr = rewriter.create<LLVM::GEPOp>(
782 loc, srcBasePtr.getType(), elementType, srcBasePtr, srcOffset);
783 MemRefDescriptor targetDesc(adaptor.getTarget());
784 Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
785 Value targetOffset = targetDesc.offset(rewriter, loc);
786 Value targetPtr = rewriter.create<LLVM::GEPOp>(
787 loc, targetBasePtr.getType(), elementType, targetBasePtr, targetOffset);
788 rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
789 /*isVolatile=*/false);
790 rewriter.eraseOp(op);
792 return success();
795 LogicalResult
796 lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
797 ConversionPatternRewriter &rewriter) const {
798 auto loc = op.getLoc();
799 auto srcType = cast<BaseMemRefType>(op.getSource().getType());
800 auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
802 // First make sure we have an unranked memref descriptor representation.
803 auto makeUnranked = [&, this](Value ranked, MemRefType type) {
804 auto rank = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
805 type.getRank());
806 auto *typeConverter = getTypeConverter();
807 auto ptr =
808 typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
810 auto unrankedType =
811 UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
812 return UnrankedMemRefDescriptor::pack(
813 rewriter, loc, *typeConverter, unrankedType, ValueRange{rank, ptr});
816 // Save stack position before promoting descriptors
817 auto stackSaveOp =
818 rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
820 auto srcMemRefType = dyn_cast<MemRefType>(srcType);
821 Value unrankedSource =
822 srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
823 : adaptor.getSource();
824 auto targetMemRefType = dyn_cast<MemRefType>(targetType);
825 Value unrankedTarget =
826 targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
827 : adaptor.getTarget();
829 // Now promote the unranked descriptors to the stack.
830 auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
831 rewriter.getIndexAttr(1));
832 auto promote = [&](Value desc) {
833 auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
834 auto allocated =
835 rewriter.create<LLVM::AllocaOp>(loc, ptrType, desc.getType(), one);
836 rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
837 return allocated;
840 auto sourcePtr = promote(unrankedSource);
841 auto targetPtr = promote(unrankedTarget);
843 // Derive size from llvm.getelementptr which will account for any
844 // potential alignment
845 auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
846 auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
847 op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
848 rewriter.create<LLVM::CallOp>(loc, copyFn,
849 ValueRange{elemSize, sourcePtr, targetPtr});
851 // Restore stack used for descriptors
852 rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
854 rewriter.eraseOp(op);
856 return success();
859 LogicalResult
860 matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
861 ConversionPatternRewriter &rewriter) const override {
862 auto srcType = cast<BaseMemRefType>(op.getSource().getType());
863 auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
865 auto isContiguousMemrefType = [&](BaseMemRefType type) {
866 auto memrefType = dyn_cast<mlir::MemRefType>(type);
867 // We can use memcpy for memrefs if they have an identity layout or are
868 // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
869 // special case handled by memrefCopy.
870 return memrefType &&
871 (memrefType.getLayout().isIdentity() ||
872 (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
873 memref::isStaticShapeAndContiguousRowMajor(memrefType)));
876 if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
877 return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
879 return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
883 struct MemorySpaceCastOpLowering
884 : public ConvertOpToLLVMPattern<memref::MemorySpaceCastOp> {
885 using ConvertOpToLLVMPattern<
886 memref::MemorySpaceCastOp>::ConvertOpToLLVMPattern;
888 LogicalResult
889 matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
890 ConversionPatternRewriter &rewriter) const override {
891 Location loc = op.getLoc();
893 Type resultType = op.getDest().getType();
894 if (auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
895 auto resultDescType =
896 cast<LLVM::LLVMStructType>(typeConverter->convertType(resultTypeR));
897 Type newPtrType = resultDescType.getBody()[0];
899 SmallVector<Value> descVals;
900 MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR,
901 descVals);
902 descVals[0] =
903 rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[0]);
904 descVals[1] =
905 rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[1]);
906 Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
907 resultTypeR, descVals);
908 rewriter.replaceOp(op, result);
909 return success();
911 if (auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
912 // Since the type converter won't be doing this for us, get the address
913 // space.
914 auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType());
915 FailureOr<unsigned> maybeSourceAddrSpace =
916 getTypeConverter()->getMemRefAddressSpace(sourceType);
917 if (failed(maybeSourceAddrSpace))
918 return rewriter.notifyMatchFailure(loc,
919 "non-integer source address space");
920 unsigned sourceAddrSpace = *maybeSourceAddrSpace;
921 FailureOr<unsigned> maybeResultAddrSpace =
922 getTypeConverter()->getMemRefAddressSpace(resultTypeU);
923 if (failed(maybeResultAddrSpace))
924 return rewriter.notifyMatchFailure(loc,
925 "non-integer result address space");
926 unsigned resultAddrSpace = *maybeResultAddrSpace;
928 UnrankedMemRefDescriptor sourceDesc(adaptor.getSource());
929 Value rank = sourceDesc.rank(rewriter, loc);
930 Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);
932 // Create and allocate storage for new memref descriptor.
933 auto result = UnrankedMemRefDescriptor::undef(
934 rewriter, loc, typeConverter->convertType(resultTypeU));
935 result.setRank(rewriter, loc, rank);
936 SmallVector<Value, 1> sizes;
937 UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
938 result, resultAddrSpace, sizes);
939 Value resultUnderlyingSize = sizes.front();
940 Value resultUnderlyingDesc = rewriter.create<LLVM::AllocaOp>(
941 loc, getVoidPtrType(), rewriter.getI8Type(), resultUnderlyingSize);
942 result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);
944 // Copy pointers, performing address space casts.
945 auto sourceElemPtrType =
946 LLVM::LLVMPointerType::get(rewriter.getContext(), sourceAddrSpace);
947 auto resultElemPtrType =
948 LLVM::LLVMPointerType::get(rewriter.getContext(), resultAddrSpace);
950 Value allocatedPtr = sourceDesc.allocatedPtr(
951 rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
952 Value alignedPtr =
953 sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
954 sourceUnderlyingDesc, sourceElemPtrType);
955 allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
956 loc, resultElemPtrType, allocatedPtr);
957 alignedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
958 loc, resultElemPtrType, alignedPtr);
960 result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
961 resultElemPtrType, allocatedPtr);
962 result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
963 resultUnderlyingDesc, resultElemPtrType, alignedPtr);
965 // Copy all the index-valued operands.
966 Value sourceIndexVals =
967 sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(),
968 sourceUnderlyingDesc, sourceElemPtrType);
969 Value resultIndexVals =
970 result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
971 resultUnderlyingDesc, resultElemPtrType);
973 int64_t bytesToSkip =
974 2 * llvm::divideCeil(
975 getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
976 Value bytesToSkipConst = rewriter.create<LLVM::ConstantOp>(
977 loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
978 Value copySize = rewriter.create<LLVM::SubOp>(
979 loc, getIndexType(), resultUnderlyingSize, bytesToSkipConst);
980 rewriter.create<LLVM::MemcpyOp>(loc, resultIndexVals, sourceIndexVals,
981 copySize, /*isVolatile=*/false);
983 rewriter.replaceOp(op, ValueRange{result});
984 return success();
986 return rewriter.notifyMatchFailure(loc, "unexpected memref type");
990 /// Extracts allocated, aligned pointers and offset from a ranked or unranked
991 /// memref type. In unranked case, the fields are extracted from the underlying
992 /// ranked descriptor.
993 static void extractPointersAndOffset(Location loc,
994 ConversionPatternRewriter &rewriter,
995 const LLVMTypeConverter &typeConverter,
996 Value originalOperand,
997 Value convertedOperand,
998 Value *allocatedPtr, Value *alignedPtr,
999 Value *offset = nullptr) {
1000 Type operandType = originalOperand.getType();
1001 if (isa<MemRefType>(operandType)) {
1002 MemRefDescriptor desc(convertedOperand);
1003 *allocatedPtr = desc.allocatedPtr(rewriter, loc);
1004 *alignedPtr = desc.alignedPtr(rewriter, loc);
1005 if (offset != nullptr)
1006 *offset = desc.offset(rewriter, loc);
1007 return;
1010 // These will all cause assert()s on unconvertible types.
1011 unsigned memorySpace = *typeConverter.getMemRefAddressSpace(
1012 cast<UnrankedMemRefType>(operandType));
1013 auto elementPtrType =
1014 LLVM::LLVMPointerType::get(rewriter.getContext(), memorySpace);
1016 // Extract pointer to the underlying ranked memref descriptor and cast it to
1017 // ElemType**.
1018 UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
1019 Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
1021 *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
1022 rewriter, loc, underlyingDescPtr, elementPtrType);
1023 *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
1024 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
1025 if (offset != nullptr) {
1026 *offset = UnrankedMemRefDescriptor::offset(
1027 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
1031 struct MemRefReinterpretCastOpLowering
1032 : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
1033 using ConvertOpToLLVMPattern<
1034 memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;
1036 LogicalResult
1037 matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
1038 ConversionPatternRewriter &rewriter) const override {
1039 Type srcType = castOp.getSource().getType();
1041 Value descriptor;
1042 if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
1043 adaptor, &descriptor)))
1044 return failure();
1045 rewriter.replaceOp(castOp, {descriptor});
1046 return success();
1049 private:
1050 LogicalResult convertSourceMemRefToDescriptor(
1051 ConversionPatternRewriter &rewriter, Type srcType,
1052 memref::ReinterpretCastOp castOp,
1053 memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
1054 MemRefType targetMemRefType =
1055 cast<MemRefType>(castOp.getResult().getType());
1056 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1057 typeConverter->convertType(targetMemRefType));
1058 if (!llvmTargetDescriptorTy)
1059 return failure();
1061 // Create descriptor.
1062 Location loc = castOp.getLoc();
1063 auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1065 // Set allocated and aligned pointers.
1066 Value allocatedPtr, alignedPtr;
1067 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1068 castOp.getSource(), adaptor.getSource(),
1069 &allocatedPtr, &alignedPtr);
1070 desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1071 desc.setAlignedPtr(rewriter, loc, alignedPtr);
1073 // Set offset.
1074 if (castOp.isDynamicOffset(0))
1075 desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
1076 else
1077 desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
1079 // Set sizes and strides.
1080 unsigned dynSizeId = 0;
1081 unsigned dynStrideId = 0;
1082 for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
1083 if (castOp.isDynamicSize(i))
1084 desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
1085 else
1086 desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
1088 if (castOp.isDynamicStride(i))
1089 desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
1090 else
1091 desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
1093 *descriptor = desc;
1094 return success();
1098 struct MemRefReshapeOpLowering
1099 : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
1100 using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;
1102 LogicalResult
1103 matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
1104 ConversionPatternRewriter &rewriter) const override {
1105 Type srcType = reshapeOp.getSource().getType();
1107 Value descriptor;
1108 if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
1109 adaptor, &descriptor)))
1110 return failure();
1111 rewriter.replaceOp(reshapeOp, {descriptor});
1112 return success();
1115 private:
1116 LogicalResult
1117 convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
1118 Type srcType, memref::ReshapeOp reshapeOp,
1119 memref::ReshapeOp::Adaptor adaptor,
1120 Value *descriptor) const {
1121 auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
1122 if (shapeMemRefType.hasStaticShape()) {
1123 MemRefType targetMemRefType =
1124 cast<MemRefType>(reshapeOp.getResult().getType());
1125 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1126 typeConverter->convertType(targetMemRefType));
1127 if (!llvmTargetDescriptorTy)
1128 return failure();
1130 // Create descriptor.
1131 Location loc = reshapeOp.getLoc();
1132 auto desc =
1133 MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1135 // Set allocated and aligned pointers.
1136 Value allocatedPtr, alignedPtr;
1137 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1138 reshapeOp.getSource(), adaptor.getSource(),
1139 &allocatedPtr, &alignedPtr);
1140 desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
1141 desc.setAlignedPtr(rewriter, loc, alignedPtr);
1143 // Extract the offset and strides from the type.
1144 int64_t offset;
1145 SmallVector<int64_t> strides;
1146 if (failed(getStridesAndOffset(targetMemRefType, strides, offset)))
1147 return rewriter.notifyMatchFailure(
1148 reshapeOp, "failed to get stride and offset exprs");
1150 if (!isStaticStrideOrOffset(offset))
1151 return rewriter.notifyMatchFailure(reshapeOp,
1152 "dynamic offset is unsupported");
1154 desc.setConstantOffset(rewriter, loc, offset);
1156 assert(targetMemRefType.getLayout().isIdentity() &&
1157 "Identity layout map is a precondition of a valid reshape op");
1159 Type indexType = getIndexType();
1160 Value stride = nullptr;
1161 int64_t targetRank = targetMemRefType.getRank();
1162 for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
1163 if (!ShapedType::isDynamic(strides[i])) {
1164 // If the stride for this dimension is dynamic, then use the product
1165 // of the sizes of the inner dimensions.
1166 stride =
1167 createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
1168 } else if (!stride) {
1169 // `stride` is null only in the first iteration of the loop. However,
1170 // since the target memref has an identity layout, we can safely set
1171 // the innermost stride to 1.
1172 stride = createIndexAttrConstant(rewriter, loc, indexType, 1);
1175 Value dimSize;
1176 // If the size of this dimension is dynamic, then load it at runtime
1177 // from the shape operand.
1178 if (!targetMemRefType.isDynamicDim(i)) {
1179 dimSize = createIndexAttrConstant(rewriter, loc, indexType,
1180 targetMemRefType.getDimSize(i));
1181 } else {
1182 Value shapeOp = reshapeOp.getShape();
1183 Value index = createIndexAttrConstant(rewriter, loc, indexType, i);
1184 dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
1185 Type indexType = getIndexType();
1186 if (dimSize.getType() != indexType)
1187 dimSize = typeConverter->materializeTargetConversion(
1188 rewriter, loc, indexType, dimSize);
1189 assert(dimSize && "Invalid memref element type");
1192 desc.setSize(rewriter, loc, i, dimSize);
1193 desc.setStride(rewriter, loc, i, stride);
1195 // Prepare the stride value for the next dimension.
1196 stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
1199 *descriptor = desc;
1200 return success();
1203 // The shape is a rank-1 tensor with unknown length.
1204 Location loc = reshapeOp.getLoc();
1205 MemRefDescriptor shapeDesc(adaptor.getShape());
1206 Value resultRank = shapeDesc.size(rewriter, loc, 0);
1208 // Extract address space and element type.
1209 auto targetType = cast<UnrankedMemRefType>(reshapeOp.getResult().getType());
1210 unsigned addressSpace =
1211 *getTypeConverter()->getMemRefAddressSpace(targetType);
1213 // Create the unranked memref descriptor that holds the ranked one. The
1214 // inner descriptor is allocated on stack.
1215 auto targetDesc = UnrankedMemRefDescriptor::undef(
1216 rewriter, loc, typeConverter->convertType(targetType));
1217 targetDesc.setRank(rewriter, loc, resultRank);
1218 SmallVector<Value, 4> sizes;
1219 UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
1220 targetDesc, addressSpace, sizes);
1221 Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
1222 loc, getVoidPtrType(), IntegerType::get(getContext(), 8),
1223 sizes.front());
1224 targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
1226 // Extract pointers and offset from the source memref.
1227 Value allocatedPtr, alignedPtr, offset;
1228 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
1229 reshapeOp.getSource(), adaptor.getSource(),
1230 &allocatedPtr, &alignedPtr, &offset);
1232 // Set pointers and offset.
1233 auto elementPtrType =
1234 LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
1236 UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
1237 elementPtrType, allocatedPtr);
1238 UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
1239 underlyingDescPtr, elementPtrType,
1240 alignedPtr);
1241 UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
1242 underlyingDescPtr, elementPtrType,
1243 offset);
1245 // Use the offset pointer as base for further addressing. Copy over the new
1246 // shape and compute strides. For this, we create a loop from rank-1 to 0.
1247 Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
1248 rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrType);
1249 Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
1250 rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
1251 Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
1252 Value oneIndex = createIndexAttrConstant(rewriter, loc, getIndexType(), 1);
1253 Value resultRankMinusOne =
1254 rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
1256 Block *initBlock = rewriter.getInsertionBlock();
1257 Type indexType = getTypeConverter()->getIndexType();
1258 Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
1260 Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
1261 {indexType, indexType}, {loc, loc});
1263 // Move the remaining initBlock ops to condBlock.
1264 Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
1265 rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());
1267 rewriter.setInsertionPointToEnd(initBlock);
1268 rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
1269 condBlock);
1270 rewriter.setInsertionPointToStart(condBlock);
1271 Value indexArg = condBlock->getArgument(0);
1272 Value strideArg = condBlock->getArgument(1);
1274 Value zeroIndex = createIndexAttrConstant(rewriter, loc, indexType, 0);
1275 Value pred = rewriter.create<LLVM::ICmpOp>(
1276 loc, IntegerType::get(rewriter.getContext(), 1),
1277 LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
1279 Block *bodyBlock =
1280 rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
1281 rewriter.setInsertionPointToStart(bodyBlock);
1283 // Copy size from shape to descriptor.
1284 auto llvmIndexPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
1285 Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
1286 loc, llvmIndexPtrType,
1287 typeConverter->convertType(shapeMemRefType.getElementType()),
1288 shapeOperandPtr, indexArg);
1289 Value size = rewriter.create<LLVM::LoadOp>(loc, indexType, sizeLoadGep);
1290 UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
1291 targetSizesBase, indexArg, size);
1293 // Write stride value and compute next one.
1294 UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
1295 targetStridesBase, indexArg, strideArg);
1296 Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);
1298 // Decrement loop counter and branch back.
1299 Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
1300 rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
1301 condBlock);
1303 Block *remainder =
1304 rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
1306 // Hook up the cond exit to the remainder.
1307 rewriter.setInsertionPointToEnd(condBlock);
1308 rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, std::nullopt,
1309 remainder, std::nullopt);
1311 // Reset position to beginning of new remainder block.
1312 rewriter.setInsertionPointToStart(remainder);
1314 *descriptor = targetDesc;
1315 return success();
1319 /// RessociatingReshapeOp must be expanded before we reach this stage.
1320 /// Report that information.
1321 template <typename ReshapeOp>
1322 class ReassociatingReshapeOpConversion
1323 : public ConvertOpToLLVMPattern<ReshapeOp> {
1324 public:
1325 using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
1326 using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
1328 LogicalResult
1329 matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
1330 ConversionPatternRewriter &rewriter) const override {
1331 return rewriter.notifyMatchFailure(
1332 reshapeOp,
1333 "reassociation operations should have been expanded beforehand");
1337 /// Subviews must be expanded before we reach this stage.
1338 /// Report that information.
1339 struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
1340 using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;
1342 LogicalResult
1343 matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1344 ConversionPatternRewriter &rewriter) const override {
1345 return rewriter.notifyMatchFailure(
1346 subViewOp, "subview operations should have been expanded beforehand");
1350 /// Conversion pattern that transforms a transpose op into:
1351 /// 1. A function entry `alloca` operation to allocate a ViewDescriptor.
1352 /// 2. A load of the ViewDescriptor from the pointer allocated in 1.
1353 /// 3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
1354 /// and stride. Size and stride are permutations of the original values.
1355 /// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
1356 /// The transpose op is replaced by the alloca'ed pointer.
1357 class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
1358 public:
1359 using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;
1361 LogicalResult
1362 matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1363 ConversionPatternRewriter &rewriter) const override {
1364 auto loc = transposeOp.getLoc();
1365 MemRefDescriptor viewMemRef(adaptor.getIn());
1367 // No permutation, early exit.
1368 if (transposeOp.getPermutation().isIdentity())
1369 return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
1371 auto targetMemRef = MemRefDescriptor::undef(
1372 rewriter, loc,
1373 typeConverter->convertType(transposeOp.getIn().getType()));
1375 // Copy the base and aligned pointers from the old descriptor to the new
1376 // one.
1377 targetMemRef.setAllocatedPtr(rewriter, loc,
1378 viewMemRef.allocatedPtr(rewriter, loc));
1379 targetMemRef.setAlignedPtr(rewriter, loc,
1380 viewMemRef.alignedPtr(rewriter, loc));
1382 // Copy the offset pointer from the old descriptor to the new one.
1383 targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1385 // Iterate over the dimensions and apply size/stride permutation:
1386 // When enumerating the results of the permutation map, the enumeration
1387 // index is the index into the target dimensions and the DimExpr points to
1388 // the dimension of the source memref.
1389 for (const auto &en :
1390 llvm::enumerate(transposeOp.getPermutation().getResults())) {
1391 int targetPos = en.index();
1392 int sourcePos = cast<AffineDimExpr>(en.value()).getPosition();
1393 targetMemRef.setSize(rewriter, loc, targetPos,
1394 viewMemRef.size(rewriter, loc, sourcePos));
1395 targetMemRef.setStride(rewriter, loc, targetPos,
1396 viewMemRef.stride(rewriter, loc, sourcePos));
1399 rewriter.replaceOp(transposeOp, {targetMemRef});
1400 return success();
1404 /// Conversion pattern that transforms an op into:
1405 /// 1. An `llvm.mlir.undef` operation to create a memref descriptor
1406 /// 2. Updates to the descriptor to introduce the data ptr, offset, size
1407 /// and stride.
1408 /// The view op is replaced by the descriptor.
1409 struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
1410 using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;
1412 // Build and return the value for the idx^th shape dimension, either by
1413 // returning the constant shape dimension or counting the proper dynamic size.
1414 Value getSize(ConversionPatternRewriter &rewriter, Location loc,
1415 ArrayRef<int64_t> shape, ValueRange dynamicSizes, unsigned idx,
1416 Type indexType) const {
1417 assert(idx < shape.size());
1418 if (!ShapedType::isDynamic(shape[idx]))
1419 return createIndexAttrConstant(rewriter, loc, indexType, shape[idx]);
1420 // Count the number of dynamic dims in range [0, idx]
1421 unsigned nDynamic =
1422 llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
1423 return dynamicSizes[nDynamic];
1426 // Build and return the idx^th stride, either by returning the constant stride
1427 // or by computing the dynamic stride from the current `runningStride` and
1428 // `nextSize`. The caller should keep a running stride and update it with the
1429 // result returned by this function.
1430 Value getStride(ConversionPatternRewriter &rewriter, Location loc,
1431 ArrayRef<int64_t> strides, Value nextSize,
1432 Value runningStride, unsigned idx, Type indexType) const {
1433 assert(idx < strides.size());
1434 if (!ShapedType::isDynamic(strides[idx]))
1435 return createIndexAttrConstant(rewriter, loc, indexType, strides[idx]);
1436 if (nextSize)
1437 return runningStride
1438 ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
1439 : nextSize;
1440 assert(!runningStride);
1441 return createIndexAttrConstant(rewriter, loc, indexType, 1);
1444 LogicalResult
1445 matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
1446 ConversionPatternRewriter &rewriter) const override {
1447 auto loc = viewOp.getLoc();
1449 auto viewMemRefType = viewOp.getType();
1450 auto targetElementTy =
1451 typeConverter->convertType(viewMemRefType.getElementType());
1452 auto targetDescTy = typeConverter->convertType(viewMemRefType);
1453 if (!targetDescTy || !targetElementTy ||
1454 !LLVM::isCompatibleType(targetElementTy) ||
1455 !LLVM::isCompatibleType(targetDescTy))
1456 return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
1457 failure();
1459 int64_t offset;
1460 SmallVector<int64_t, 4> strides;
1461 auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
1462 if (failed(successStrides))
1463 return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
1464 assert(offset == 0 && "expected offset to be 0");
1466 // Target memref must be contiguous in memory (innermost stride is 1), or
1467 // empty (special case when at least one of the memref dimensions is 0).
1468 if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
1469 return viewOp.emitWarning("cannot cast to non-contiguous shape"),
1470 failure();
1472 // Create the descriptor.
1473 MemRefDescriptor sourceMemRef(adaptor.getSource());
1474 auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
1476 // Field 1: Copy the allocated pointer, used for malloc/free.
1477 Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
1478 auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType());
1479 targetMemRef.setAllocatedPtr(rewriter, loc, allocatedPtr);
1481 // Field 2: Copy the actual aligned pointer to payload.
1482 Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
1483 alignedPtr = rewriter.create<LLVM::GEPOp>(
1484 loc, alignedPtr.getType(),
1485 typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr,
1486 adaptor.getByteShift());
1488 targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr);
1490 Type indexType = getIndexType();
1491 // Field 3: The offset in the resulting type must be 0. This is
1492 // because of the type change: an offset on srcType* may not be
1493 // expressible as an offset on dstType*.
1494 targetMemRef.setOffset(
1495 rewriter, loc,
1496 createIndexAttrConstant(rewriter, loc, indexType, offset));
1498 // Early exit for 0-D corner case.
1499 if (viewMemRefType.getRank() == 0)
1500 return rewriter.replaceOp(viewOp, {targetMemRef}), success();
1502 // Fields 4 and 5: Update sizes and strides.
1503 Value stride = nullptr, nextSize = nullptr;
1504 for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
1505 // Update size.
1506 Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
1507 adaptor.getSizes(), i, indexType);
1508 targetMemRef.setSize(rewriter, loc, i, size);
1509 // Update stride.
1510 stride =
1511 getStride(rewriter, loc, strides, nextSize, stride, i, indexType);
1512 targetMemRef.setStride(rewriter, loc, i, stride);
1513 nextSize = size;
1516 rewriter.replaceOp(viewOp, {targetMemRef});
1517 return success();
1521 //===----------------------------------------------------------------------===//
1522 // AtomicRMWOpLowering
1523 //===----------------------------------------------------------------------===//
1525 /// Try to match the kind of a memref.atomic_rmw to determine whether to use a
1526 /// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
1527 static std::optional<LLVM::AtomicBinOp>
1528 matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1529 switch (atomicOp.getKind()) {
1530 case arith::AtomicRMWKind::addf:
1531 return LLVM::AtomicBinOp::fadd;
1532 case arith::AtomicRMWKind::addi:
1533 return LLVM::AtomicBinOp::add;
1534 case arith::AtomicRMWKind::assign:
1535 return LLVM::AtomicBinOp::xchg;
1536 case arith::AtomicRMWKind::maximumf:
1537 return LLVM::AtomicBinOp::fmax;
1538 case arith::AtomicRMWKind::maxs:
1539 return LLVM::AtomicBinOp::max;
1540 case arith::AtomicRMWKind::maxu:
1541 return LLVM::AtomicBinOp::umax;
1542 case arith::AtomicRMWKind::minimumf:
1543 return LLVM::AtomicBinOp::fmin;
1544 case arith::AtomicRMWKind::mins:
1545 return LLVM::AtomicBinOp::min;
1546 case arith::AtomicRMWKind::minu:
1547 return LLVM::AtomicBinOp::umin;
1548 case arith::AtomicRMWKind::ori:
1549 return LLVM::AtomicBinOp::_or;
1550 case arith::AtomicRMWKind::andi:
1551 return LLVM::AtomicBinOp::_and;
1552 default:
1553 return std::nullopt;
1555 llvm_unreachable("Invalid AtomicRMWKind");
1558 struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
1559 using Base::Base;
1561 LogicalResult
1562 matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1563 ConversionPatternRewriter &rewriter) const override {
1564 auto maybeKind = matchSimpleAtomicOp(atomicOp);
1565 if (!maybeKind)
1566 return failure();
1567 auto memRefType = atomicOp.getMemRefType();
1568 SmallVector<int64_t> strides;
1569 int64_t offset;
1570 if (failed(getStridesAndOffset(memRefType, strides, offset)))
1571 return failure();
1572 auto dataPtr =
1573 getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
1574 adaptor.getIndices(), rewriter);
1575 rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
1576 atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
1577 LLVM::AtomicOrdering::acq_rel);
1578 return success();
1582 /// Unpack the pointer returned by a memref.extract_aligned_pointer_as_index.
1583 class ConvertExtractAlignedPointerAsIndex
1584 : public ConvertOpToLLVMPattern<memref::ExtractAlignedPointerAsIndexOp> {
1585 public:
1586 using ConvertOpToLLVMPattern<
1587 memref::ExtractAlignedPointerAsIndexOp>::ConvertOpToLLVMPattern;
1589 LogicalResult
1590 matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
1591 OpAdaptor adaptor,
1592 ConversionPatternRewriter &rewriter) const override {
1593 BaseMemRefType sourceTy = extractOp.getSource().getType();
1595 Value alignedPtr;
1596 if (sourceTy.hasRank()) {
1597 MemRefDescriptor desc(adaptor.getSource());
1598 alignedPtr = desc.alignedPtr(rewriter, extractOp->getLoc());
1599 } else {
1600 auto elementPtrTy = LLVM::LLVMPointerType::get(
1601 rewriter.getContext(), sourceTy.getMemorySpaceAsInt());
1603 UnrankedMemRefDescriptor desc(adaptor.getSource());
1604 Value descPtr = desc.memRefDescPtr(rewriter, extractOp->getLoc());
1606 alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
1607 rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr,
1608 elementPtrTy);
1611 rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(
1612 extractOp, getTypeConverter()->getIndexType(), alignedPtr);
1613 return success();
1617 /// Materialize the MemRef descriptor represented by the results of
1618 /// ExtractStridedMetadataOp.
1619 class ExtractStridedMetadataOpLowering
1620 : public ConvertOpToLLVMPattern<memref::ExtractStridedMetadataOp> {
1621 public:
1622 using ConvertOpToLLVMPattern<
1623 memref::ExtractStridedMetadataOp>::ConvertOpToLLVMPattern;
1625 LogicalResult
1626 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1627 OpAdaptor adaptor,
1628 ConversionPatternRewriter &rewriter) const override {
1630 if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
1631 return failure();
1633 // Create the descriptor.
1634 MemRefDescriptor sourceMemRef(adaptor.getSource());
1635 Location loc = extractStridedMetadataOp.getLoc();
1636 Value source = extractStridedMetadataOp.getSource();
1638 auto sourceMemRefType = cast<MemRefType>(source.getType());
1639 int64_t rank = sourceMemRefType.getRank();
1640 SmallVector<Value> results;
1641 results.reserve(2 + rank * 2);
1643 // Base buffer.
1644 Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc);
1645 Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc);
1646 MemRefDescriptor dstMemRef = MemRefDescriptor::fromStaticShape(
1647 rewriter, loc, *getTypeConverter(),
1648 cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()),
1649 baseBuffer, alignedBuffer);
1650 results.push_back((Value)dstMemRef);
1652 // Offset.
1653 results.push_back(sourceMemRef.offset(rewriter, loc));
1655 // Sizes.
1656 for (unsigned i = 0; i < rank; ++i)
1657 results.push_back(sourceMemRef.size(rewriter, loc, i));
1658 // Strides.
1659 for (unsigned i = 0; i < rank; ++i)
1660 results.push_back(sourceMemRef.stride(rewriter, loc, i));
1662 rewriter.replaceOp(extractStridedMetadataOp, results);
1663 return success();
1667 } // namespace
1669 void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
1670 const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1671 // clang-format off
1672 patterns.add<
1673 AllocaOpLowering,
1674 AllocaScopeOpLowering,
1675 AtomicRMWOpLowering,
1676 AssumeAlignmentOpLowering,
1677 ConvertExtractAlignedPointerAsIndex,
1678 DimOpLowering,
1679 ExtractStridedMetadataOpLowering,
1680 GenericAtomicRMWOpLowering,
1681 GlobalMemrefOpLowering,
1682 GetGlobalMemrefOpLowering,
1683 LoadOpLowering,
1684 MemRefCastOpLowering,
1685 MemRefCopyOpLowering,
1686 MemorySpaceCastOpLowering,
1687 MemRefReinterpretCastOpLowering,
1688 MemRefReshapeOpLowering,
1689 PrefetchOpLowering,
1690 RankOpLowering,
1691 ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
1692 ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
1693 StoreOpLowering,
1694 SubViewOpLowering,
1695 TransposeOpLowering,
1696 ViewOpLowering>(converter);
1697 // clang-format on
1698 auto allocLowering = converter.getOptions().allocLowering;
1699 if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
1700 patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
1701 else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
1702 patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
1705 namespace {
1706 struct FinalizeMemRefToLLVMConversionPass
1707 : public impl::FinalizeMemRefToLLVMConversionPassBase<
1708 FinalizeMemRefToLLVMConversionPass> {
1709 using FinalizeMemRefToLLVMConversionPassBase::
1710 FinalizeMemRefToLLVMConversionPassBase;
1712 void runOnOperation() override {
1713 Operation *op = getOperation();
1714 const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
1715 LowerToLLVMOptions options(&getContext(),
1716 dataLayoutAnalysis.getAtOrAbove(op));
1717 options.allocLowering =
1718 (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
1719 : LowerToLLVMOptions::AllocLowering::Malloc);
1721 options.useGenericFunctions = useGenericFunctions;
1723 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
1724 options.overrideIndexBitwidth(indexBitwidth);
1726 LLVMTypeConverter typeConverter(&getContext(), options,
1727 &dataLayoutAnalysis);
1728 RewritePatternSet patterns(&getContext());
1729 populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
1730 LLVMConversionTarget target(getContext());
1731 target.addLegalOp<func::FuncOp>();
1732 if (failed(applyPartialConversion(op, target, std::move(patterns))))
1733 signalPassFailure();
1737 /// Implement the interface to convert MemRef to LLVM.
1738 struct MemRefToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
1739 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
1740 void loadDependentDialects(MLIRContext *context) const final {
1741 context->loadDialect<LLVM::LLVMDialect>();
1744 /// Hook for derived dialect interface to provide conversion patterns
1745 /// and mark dialect legal for the conversion target.
1746 void populateConvertToLLVMConversionPatterns(
1747 ConversionTarget &target, LLVMTypeConverter &typeConverter,
1748 RewritePatternSet &patterns) const final {
1749 populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
1753 } // namespace
1755 void mlir::registerConvertMemRefToLLVMInterface(DialectRegistry &registry) {
1756 registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
1757 dialect->addInterfaces<MemRefToLLVMDialectInterface>();