//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"

#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/Support/MathExtras.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

bool isStaticStrideOrOffset(int64_t strideOrOffset) {
  return !ShapedType::isDynamic(strideOrOffset);
}

LLVM::LLVMFuncOp getFreeFn(const LLVMTypeConverter *typeConverter,
                           ModuleOp module) {
  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;

  if (useGenericFn)
    return LLVM::lookupOrCreateGenericFreeFn(module);

  return LLVM::lookupOrCreateFreeFn(module);
}

struct AllocOpLowering : public AllocLikeOpLLVMLowering {
  AllocOpLowering(const LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    return allocateBufferManuallyAlign(
        rewriter, loc, sizeBytes, op,
        getAlignment(rewriter, loc, cast<memref::AllocOp>(op)));
  }
};

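// Illustrative sketch only (not part of the pattern): with the default
// lowering options, `%0 = memref.alloc(%n) : memref<?xf32>` becomes a call to
// `malloc` (or to a generic allocation function when `useGenericFunctions` is
// set) with the byte size `%n * sizeof(f32)`, followed by `llvm.insertvalue`
// ops that assemble the memref descriptor around the returned pointer.
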
struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
  AlignedAllocOpLowering(const LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    Value ptr = allocateBufferAutoAlign(
        rewriter, loc, sizeBytes, op, &defaultLayout,
        alignedAllocationGetAlignment(rewriter, loc, cast<memref::AllocOp>(op),
                                      &defaultLayout));
    if (!ptr)
      return std::make_tuple(Value(), Value());
    return std::make_tuple(ptr, ptr);
  }

private:
  /// Default layout to use in absence of the corresponding analysis.
  DataLayout defaultLayout;
};

struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
  AllocaOpLowering(const LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
                                converter) {
    setRequiresNumElements();
  }

  /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
  /// is set to null for stack allocations. `accessAlignment` is set if
  /// alignment is needed post allocation (e.g., in conjunction with malloc).
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value size,
                                          Operation *op) const override {

    // With alloca, one gets a pointer to the element type right away.
    // For stack allocations.
    auto allocaOp = cast<memref::AllocaOp>(op);
    auto elementType =
        typeConverter->convertType(allocaOp.getType().getElementType());
    unsigned addrSpace =
        *getTypeConverter()->getMemRefAddressSpace(allocaOp.getType());
    auto elementPtrType =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace);

    auto allocatedElementPtr =
        rewriter.create<LLVM::AllocaOp>(loc, elementPtrType, elementType, size,
                                        allocaOp.getAlignment().value_or(0));

    return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
  }
};

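// Illustrative sketch only: `%0 = memref.alloca() : memref<4xf32>` becomes an
// `llvm.alloca` of 4 f32 elements; the resulting pointer serves as both the
// allocated and the aligned pointer in the descriptor built by the base class.
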
struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(
          remainingOpsBlock, allocaScopeOp.getResultTypes(),
          SmallVector<Location>(allocaScopeOp->getNumResults(),
                                allocaScopeOp.getLoc()));
      rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
    }

    // Inline body region.
    Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
    Block *afterBody = &allocaScopeOp.getBodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);

    // Save stack and then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
    rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the
    // body; the stack is restored before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.getResults(), continueBlock);

    // Insert stack restore before jumping out of the body of the region.
    rewriter.setInsertionPoint(branchOp);
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    // Replace the op with the values returned from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};

struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
  explicit AssumeAlignmentOpLowering(const LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::AssumeAlignmentOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value memref = adaptor.getMemref();
    unsigned alignment = op.getAlignment();
    auto loc = op.getLoc();

    auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
    Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, /*indices=*/{},
                                     rewriter);

    // Emit llvm.assume(true) ["align"(memref, alignment)].
    // This is more direct than ptrtoint-based checks, is explicitly supported,
    // and works with non-integral address spaces.
    Value trueCond =
        rewriter.create<LLVM::ConstantOp>(loc, rewriter.getBoolAttr(true));
    Value alignmentConst =
        createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
    rewriter.create<LLVM::AssumeOp>(loc, trueCond, LLVM::AssumeAlignTag(), ptr,
                                    alignmentConst);

    rewriter.eraseOp(op);
    return success();
  }
};

// A `dealloc` is converted into a call to `free` on the underlying data buffer.
// Since the memref descriptor is an SSA value, there is no need to clean it up
// in any way.
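// Illustrative sketch only (assuming the default, non-generic free function):
//
//   memref.dealloc %m : memref<?xf32>
//
// becomes roughly:
//
//   %ptr = llvm.extractvalue %desc[0] : !llvm.struct<(ptr, ptr, i64, ...)>
//   llvm.call @free(%ptr) : (!llvm.ptr) -> ()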
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;

  explicit DeallocOpLowering(const LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Insert the `free` declaration if it is not already present.
    LLVM::LLVMFuncOp freeFunc =
        getFreeFn(getTypeConverter(), op->getParentOfType<ModuleOp>());
    Value allocatedPtr;
    if (auto unrankedTy =
            llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
      auto elementPtrTy = LLVM::LLVMPointerType::get(
          rewriter.getContext(), unrankedTy.getMemorySpaceAsInt());
      allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
          rewriter, op.getLoc(),
          UnrankedMemRefDescriptor(adaptor.getMemref())
              .memRefDescPtr(rewriter, op.getLoc()),
          elementPtrTy);
    } else {
      allocatedPtr = MemRefDescriptor(adaptor.getMemref())
                         .allocatedPtr(rewriter, op.getLoc());
    }
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc, allocatedPtr);
    return success();
  }
};

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
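// Illustrative example: for `%d = memref.dim %m, %i : memref<4x?xf32>`, a
// constant index 0 folds to `llvm.mlir.constant(4 : index)`, while index 1
// becomes an `llvm.extractvalue` reading the corresponding entry of the size
// array in the descriptor.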
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.getSource().getType();
    if (isa<UnrankedMemRefType>(operandType)) {
      FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
          operandType, dimOp, adaptor.getOperands(), rewriter);
      if (failed(extractedSize))
        return failure();
      rewriter.replaceOp(dimOp, {*extractedSize});
      return success();
    }
    if (isa<MemRefType>(operandType)) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
                                            adaptor.getOperands(), rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  FailureOr<Value>
  extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                              OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    auto unrankedMemRefType = cast<UnrankedMemRefType>(operandType);
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
    if (failed(maybeAddressSpace)) {
      dimOp.emitOpError("memref memory space must be convertible to an integer "
                        "address space");
      return failure();
    }
    unsigned addressSpace = *maybeAddressSpace;

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // memref<element_type> descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);

    Type elementType = typeConverter->convertType(scalarMemRefType);

    // Get pointer to offset field of memref<element_type> descriptor.
    auto indexPtrTy =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
    Value offsetPtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, elementType, underlyingRankedDesc,
        ArrayRef<LLVM::GEPArg>{0, 2});

    // The size value that we have to extract can be obtained using GEPOp with
    // `dimOp.index() + 1` index argument.
    Value idxPlusOne = rewriter.create<LLVM::AddOp>(
        loc, createIndexAttrConstant(rewriter, loc, getIndexType(), 1),
        adaptor.getIndex());
    Value sizePtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr,
        idxPlusOne);
    return rewriter
        .create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr)
        .getResult();
  }

  std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (auto idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
      return cast<IntegerAttr>(constantOp.getValue()).getValue().getSExtValue();

    return std::nullopt;
  }

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    // Take advantage of a constant index when available.
    MemRefType memRefType = cast<MemRefType>(operandType);
    Type indexType = getIndexType();
    if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = *index;
      if (i >= 0 && i < memRefType.getRank()) {
        if (memRefType.isDynamicDim(i)) {
          // Extract the dynamic size from the memref descriptor.
          MemRefDescriptor descriptor(adaptor.getSource());
          return descriptor.size(rewriter, loc, i);
        }
        // Use a constant for the static size.
        int64_t dimSize = memRefType.getDimSize(i);
        return createIndexAttrConstant(rewriter, loc, indexType, dimSize);
      }
    }
    Value index = adaptor.getIndex();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(adaptor.getSource());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};

/// Common base for load and store operations on MemRefs. Restricts the match
/// to supported MemRef types. Provides functionality to emit code accessing a
/// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;

  LogicalResult match(Derived op) const override {
    MemRefType type = op.getMemRefType();
    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
  }
};

/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   cf.br loop(%loaded)           |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +--------------------------------+
///  |   | loop(%loaded):                 |
///  |   |   <body contents>              |
///  |   |   %pair = cmpxchg              |
///  |   |   %ok = %pair[0]               |
///  |   |   %new = %pair[1]              |
///  |   |   cf.cond_br %ok, end, loop(%new) |
///  |   +--------------------------------+
///  |          |        |
///  |-----------        |
///                      v
///      +--------------------------------+
///      | end:                           |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
///
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType = typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock = rewriter.splitBlock(initBlock, Block::iterator(atomicOp));
    loopBlock->addArgument(valueType, loc);

    auto *endBlock =
        rewriter.splitBlock(loopBlock, Block::iterator(atomicOp)++);

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
                                        adaptor.getIndices(), rewriter);
    Value init = rewriter.create<LLVM::LoadOp>(
        loc, typeConverter->convertType(memRefType.getElementType()), dataPtr);
    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    auto loopArgument = loopBlock->getArgument(0);
    IRMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, dataPtr, loopArgument, result, successOrdering, failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 0);
    Value ok = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 1);

    // Conditionally branch to the end or back to the loop depending on %ok.
    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                    loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type
convertGlobalMemrefTypeToLLVM(MemRefType type,
                              const LLVMTypeConverter &typeConverter) {
  // The LLVM type for a global memref will be a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we can potentially flatten
  // this to a 1D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // Shape has the outermost dim at index 0, so need to walk it backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}

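// For example, `memref<4x8xf32>` converts to `!llvm.array<4 x array<8 x f32>>`,
// and a rank-0 `memref<f32>` converts to plain `f32`.
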
/// GlobalMemrefOp is lowered to an LLVM Global Variable.
struct GlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.getType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    Attribute initialValue = nullptr;
    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element
      // type, so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getSplatValue<Attribute>();
    }

    uint64_t alignment = global.getAlignment().value_or(0);
    FailureOr<unsigned> addressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    if (failed(addressSpace))
      return global.emitOpError(
          "memory space cannot be converted to an integer address space");
    auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.getConstant(), linkage, global.getSymName(),
        initialValue, alignment, *addressSpace);
    if (!global.isExternal() && global.isUninitialized()) {
      rewriter.createBlock(&newGlobal.getInitializerRegion());
      Value undef[] = {
          rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
      rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
    }
    return success();
  }
};

/// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to
/// the first element stashed into the descriptor. This leverages
/// `AllocLikeOpLowering` to reuse the Memref descriptor construction.
struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
  GetGlobalMemrefOpLowering(const LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
                                converter) {}

  /// Buffer "allocation" for memref.get_global op is getting the address of
  /// the referenced global variable.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    auto getGlobalOp = cast<memref::GetGlobalOp>(op);
    MemRefType type = cast<MemRefType>(getGlobalOp.getResult().getType());

    // This is called after a type conversion, which would have failed if this
    // call were to fail.
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    if (failed(maybeAddressSpace))
      return std::make_tuple(Value(), Value());
    unsigned memSpace = *maybeAddressSpace;

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace);
    auto addressOf =
        rewriter.create<LLVM::AddressOfOp>(loc, ptrTy, getGlobalOp.getName());

    // Get the address of the first element in the array by creating a GEP with
    // the address of the GV as the base, and (rank + 1) zero indices.
    auto gep = rewriter.create<LLVM::GEPOp>(
        loc, ptrTy, arrayTy, addressOf,
        SmallVector<LLVM::GEPArg>(type.getRank() + 1, 0));

    // We do not expect the memref obtained using `memref.get_global` to be
    // ever deallocated. Set the allocated pointer to a known bad value to
    // help debug if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
    auto deadBeefPtr =
        rewriter.create<LLVM::IntToPtrOp>(loc, ptrTy, deadBeefConst);

    // Both allocated and aligned pointers are the same. We could potentially
    // stash a nullptr for the allocated pointer since we do not expect any
    // dealloc.
    return std::make_tuple(deadBeefPtr, gep);
  }
};

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();

    Value dataPtr =
        getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
                             adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
        loadOp, typeConverter->convertType(type.getElementType()), dataPtr, 0,
        false, loadOp.getNontemporal());
    return success();
  }
};

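// Illustrative sketch only: for `%v = memref.load %m[%i, %j] : memref<?x?xf32>`
// with an identity layout, the pattern computes the linearized index
// offset + %i * stride0 + %j * stride1 from the descriptor fields and emits
// roughly (%aligned and %linear are placeholder names):
//
//   %p = llvm.getelementptr %aligned[%linear]
//       : (!llvm.ptr, i64) -> !llvm.ptr, f32
//   %v = llvm.load %p : !llvm.ptr -> f32
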
// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();

    Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
                                         adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
                                               0, false, op.getNontemporal());
    return success();
  }
};

// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
                                         adaptor.getIndices(), rewriter);

    // Replace with llvm.prefetch.
    IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite());
    IntegerAttr localityHint = prefetchOp.getLocalityHintAttr();
    IntegerAttr isData =
        rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache());
    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};

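// Illustrative sketch only: `memref.prefetch %m[%i], read, locality<3>, data`
// computes the element pointer exactly as a load would and then emits the
// LLVM prefetch intrinsic with isWrite = 0, hint = 3 and isData = 1.
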
struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
  using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type operandType = op.getMemref().getType();
    if (dyn_cast<UnrankedMemRefType>(operandType)) {
      UnrankedMemRefDescriptor desc(adaptor.getMemref());
      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
      return success();
    }
    if (auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) {
      Type indexType = getIndexType();
      rewriter.replaceOp(op,
                         {createIndexAttrConstant(rewriter, loc, indexType,
                                                  rankedMemRefType.getRank())});
      return success();
    }
    return failure();
  }
};

struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult match(memref::CastOp memRefCastOp) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduces to bitcast in the ranked MemRef case and can be
    // used for type erasure. For now it must preserve the underlying element
    // type and requires the source and result types to have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once op semantics are relaxed we can revisit.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
      return success(typeConverter->convertType(srcType) ==
                     typeConverter->convertType(dstType));

    // At least one of the operands is an unranked memref type.
    assert(isa<UnrankedMemRefType>(srcType) ||
           isa<UnrankedMemRefType>(dstType));

    // Unranked to unranked cast is disallowed.
    return !(isa<UnrankedMemRefType>(srcType) &&
             isa<UnrankedMemRefType>(dstType))
               ? success()
               : failure();
  }

  void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
               ConversionPatternRewriter &rewriter) const override {
    auto srcType = memRefCastOp.getOperand().getType();
    auto dstType = memRefCastOp.getType();
    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For the ranked/ranked case, just keep the original descriptor.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
      return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});

    if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
      // Casting ranked to unranked memref type:
      // set the rank in the destination from the memref type,
      // allocate space on the stack and copy the src memref descriptor,
      // and set the ptr in the destination to the stack space.
      auto srcMemRefType = cast<MemRefType>(srcType);
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, adaptor.getSource(), rewriter);

      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(rank));
      // undef = UndefOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
      // d1 = InsertValueOp undef, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, ptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
      // Casting from unranked type to ranked.
      // The operation is assumed to be doing a correct cast. If the
      // destination type mismatches the unranked type, it is undefined
      // behavior.
      UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);

      // struct = LoadOp ptr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, targetStructType, ptr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }
  }
};

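// Illustrative sketch only, for the ranked-to-unranked direction:
//
//   %u = memref.cast %m : memref<4xf32> to memref<*xf32>
//
// stores the source descriptor into a stack slot (llvm.alloca + llvm.store)
// and builds the two-field unranked descriptor {rank = 1, ptr = slot} with
// llvm.insertvalue ops.
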
/// Pattern to lower a `memref.copy` to llvm.
///
/// For memrefs with identity layouts, the copy is lowered to the llvm
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
/// to the generic `MemrefCopyFn`.
struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = dyn_cast<MemRefType>(op.getSource().getType());

    MemRefDescriptor srcDesc(adaptor.getSource());

    // Compute number of elements.
    Value numElements = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(1));
    for (int pos = 0; pos < srcType.getRank(); ++pos) {
      auto size = srcDesc.size(rewriter, loc, pos);
      numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
    }

    // Get element size.
    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    // Compute total.
    Value totalSize =
        rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);

    Type elementType = typeConverter->convertType(srcType.getElementType());

    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
    Value srcOffset = srcDesc.offset(rewriter, loc);
    Value srcPtr = rewriter.create<LLVM::GEPOp>(
        loc, srcBasePtr.getType(), elementType, srcBasePtr, srcOffset);
    MemRefDescriptor targetDesc(adaptor.getTarget());
    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
    Value targetOffset = targetDesc.offset(rewriter, loc);
    Value targetPtr = rewriter.create<LLVM::GEPOp>(
        loc, targetBasePtr.getType(), elementType, targetBasePtr, targetOffset);
    rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
                                    /*isVolatile=*/false);
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, MemRefType type) {
      auto rank = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                    type.getRank());
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);

      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(
          rewriter, loc, *typeConverter, unrankedType, ValueRange{rank, ptr});
    };

    // Save stack position before promoting descriptors.
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());

    auto srcMemRefType = dyn_cast<MemRefType>(srcType);
    Value unrankedSource =
        srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
                      : adaptor.getSource();
    auto targetMemRefType = dyn_cast<MemRefType>(targetType);
    Value unrankedTarget =
        targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
                         : adaptor.getTarget();

    // Now promote the unranked descriptors to the stack.
    auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
      auto allocated =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, desc.getType(), one);
      rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    // Derive size from llvm.getelementptr which will account for any
    // potential alignment.
    auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
    rewriter.create<LLVM::CallOp>(loc, copyFn,
                                  ValueRange{elemSize, sourcePtr, targetPtr});

    // Restore stack used for descriptors.
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    auto isContiguousMemrefType = [&](BaseMemRefType type) {
      auto memrefType = dyn_cast<mlir::MemRefType>(type);
      // We can use memcpy for memrefs if they have an identity layout or are
      // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
      // special case handled by memrefCopy.
      return memrefType &&
             (memrefType.getLayout().isIdentity() ||
              (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
               memref::isStaticShapeAndContiguousRowMajor(memrefType)));
    };

    if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);

    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
  }
};

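// Design note: the intrinsic path keeps the copy inline and lets LLVM fold
// small constant-size memcpys, while the function-call path defers to the
// runtime `memrefCopy`, which can walk arbitrary (e.g. strided or permuted)
// layouts at the cost of a library call.
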
struct MemorySpaceCastOpLowering
    : public ConvertOpToLLVMPattern<memref::MemorySpaceCastOp> {
  using ConvertOpToLLVMPattern<
      memref::MemorySpaceCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    Type resultType = op.getDest().getType();
    if (auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
      auto resultDescType =
          cast<LLVM::LLVMStructType>(typeConverter->convertType(resultTypeR));
      Type newPtrType = resultDescType.getBody()[0];

      SmallVector<Value> descVals;
      MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR,
                               descVals);
      descVals[0] =
          rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[0]);
      descVals[1] =
          rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[1]);
      Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
                                            resultTypeR, descVals);
      rewriter.replaceOp(op, result);
      return success();
    }
    if (auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
      // Since the type converter won't be doing this for us, get the address
      // space.
      auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType());
      FailureOr<unsigned> maybeSourceAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(sourceType);
      if (failed(maybeSourceAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer source address space");
      unsigned sourceAddrSpace = *maybeSourceAddrSpace;
      FailureOr<unsigned> maybeResultAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(resultTypeU);
      if (failed(maybeResultAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer result address space");
      unsigned resultAddrSpace = *maybeResultAddrSpace;

      UnrankedMemRefDescriptor sourceDesc(adaptor.getSource());
      Value rank = sourceDesc.rank(rewriter, loc);
      Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);

      // Create and allocate storage for the new memref descriptor.
      auto result = UnrankedMemRefDescriptor::undef(
          rewriter, loc, typeConverter->convertType(resultTypeU));
      result.setRank(rewriter, loc, rank);
      SmallVector<Value, 1> sizes;
      UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                             result, resultAddrSpace, sizes);
      Value resultUnderlyingSize = sizes.front();
      Value resultUnderlyingDesc = rewriter.create<LLVM::AllocaOp>(
          loc, getVoidPtrType(), rewriter.getI8Type(), resultUnderlyingSize);
      result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);

      // Copy pointers, performing address space casts.
      auto sourceElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), sourceAddrSpace);
      auto resultElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), resultAddrSpace);

      Value allocatedPtr = sourceDesc.allocatedPtr(
          rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
      Value alignedPtr =
          sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
                                sourceUnderlyingDesc, sourceElemPtrType);
      allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, resultElemPtrType, allocatedPtr);
      alignedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, resultElemPtrType, alignedPtr);

      result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
                             resultElemPtrType, allocatedPtr);
      result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
                           resultUnderlyingDesc, resultElemPtrType, alignedPtr);

      // Copy all the index-valued operands.
      Value sourceIndexVals =
          sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                                   sourceUnderlyingDesc, sourceElemPtrType);
      Value resultIndexVals =
          result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                               resultUnderlyingDesc, resultElemPtrType);

      int64_t bytesToSkip =
          2 * llvm::divideCeil(
                  getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
      Value bytesToSkipConst = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
      Value copySize = rewriter.create<LLVM::SubOp>(
          loc, getIndexType(), resultUnderlyingSize, bytesToSkipConst);
      rewriter.create<LLVM::MemcpyOp>(loc, resultIndexVals, sourceIndexVals,
                                      copySize, /*isVolatile=*/false);

      rewriter.replaceOp(op, ValueRange{result});
      return success();
    }
    return rewriter.notifyMatchFailure(loc, "unexpected memref type");
  }
};

/// Extracts allocated, aligned pointers and offset from a ranked or unranked
/// memref type. In the unranked case, the fields are extracted from the
/// underlying ranked descriptor.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     const LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (isa<MemRefType>(operandType)) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  // These will all cause assert()s on unconvertible types.
  unsigned memorySpace = *typeConverter.getMemRefAddressSpace(
      cast<UnrankedMemRefType>(operandType));
  auto elementPtrType =
      LLVM::LLVMPointerType::get(rewriter.getContext(), memorySpace);

  // Extract pointer to the underlying ranked memref descriptor and cast it to
  // ElemType**.
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
  }
}

struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
  using ConvertOpToLLVMPattern<
      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = castOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  LogicalResult convertSourceMemRefToDescriptor(
      ConversionPatternRewriter &rewriter, Type srcType,
      memref::ReinterpretCastOp castOp,
      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
    MemRefType targetMemRefType =
        cast<MemRefType>(castOp.getResult().getType());
    auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
        typeConverter->convertType(targetMemRefType));
    if (!llvmTargetDescriptorTy)
      return failure();

    // Create descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

    // Set allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set sizes and strides.
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};

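// Illustrative example (syntax sketch, assuming a 1-D result type):
//
//   %r = memref.reinterpret_cast %u to offset: [0], sizes: [%n], strides: [1]
//       : memref<*xf32> to memref<?xf32, strided<[1]>>
//
// extracts the pointers from the source descriptor, sets the offset and the
// stride as constants, and sets the size from the SSA value %n.
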
struct MemRefReshapeOpLowering
    : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = reshapeOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(reshapeOp, {descriptor});
    return success();
  }

private:
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, memref::ReshapeOp reshapeOp,
                                  memref::ReshapeOp::Adaptor adaptor,
                                  Value *descriptor) const {
    auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
    if (shapeMemRefType.hasStaticShape()) {
      MemRefType targetMemRefType =
          cast<MemRefType>(reshapeOp.getResult().getType());
      auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
          typeConverter->convertType(targetMemRefType));
      if (!llvmTargetDescriptorTy)
        return failure();

      // Create descriptor.
      Location loc = reshapeOp.getLoc();
      auto desc =
          MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

      // Set allocated and aligned pointers.
      Value allocatedPtr, alignedPtr;
      extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                               reshapeOp.getSource(), adaptor.getSource(),
                               &allocatedPtr, &alignedPtr);
      desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
      desc.setAlignedPtr(rewriter, loc, alignedPtr);

      // Extract the offset and strides from the type.
      int64_t offset;
      SmallVector<int64_t> strides;
      if (failed(getStridesAndOffset(targetMemRefType, strides, offset)))
        return rewriter.notifyMatchFailure(
            reshapeOp, "failed to get stride and offset exprs");

      if (!isStaticStrideOrOffset(offset))
        return rewriter.notifyMatchFailure(reshapeOp,
                                           "dynamic offset is unsupported");

      desc.setConstantOffset(rewriter, loc, offset);

      assert(targetMemRefType.getLayout().isIdentity() &&
             "Identity layout map is a precondition of a valid reshape op");

      Type indexType = getIndexType();
      Value stride = nullptr;
      int64_t targetRank = targetMemRefType.getRank();
      for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
        if (!ShapedType::isDynamic(strides[i])) {
          // If the stride for this dimension is static, materialize it as a
          // constant.
          stride =
              createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
        } else if (!stride) {
          // `stride` is null only in the first iteration of the loop. However,
          // since the target memref has an identity layout, we can safely set
          // the innermost stride to 1.
          stride = createIndexAttrConstant(rewriter, loc, indexType, 1);
        }

        Value dimSize;
        // If the size of this dimension is static, materialize it as a
        // constant; otherwise, load it at runtime from the shape operand.
        if (!targetMemRefType.isDynamicDim(i)) {
          dimSize = createIndexAttrConstant(rewriter, loc, indexType,
                                            targetMemRefType.getDimSize(i));
        } else {
          Value shapeOp = reshapeOp.getShape();
          Value index = createIndexAttrConstant(rewriter, loc, indexType, i);
          dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
          Type indexType = getIndexType();
          if (dimSize.getType() != indexType)
            dimSize = typeConverter->materializeTargetConversion(
                rewriter, loc, indexType, dimSize);
          assert(dimSize && "Invalid memref element type");
        }

        desc.setSize(rewriter, loc, i, dimSize);
        desc.setStride(rewriter, loc, i, stride);

        // Prepare the stride value for the next dimension.
        stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
      }

      *descriptor = desc;
      return success();
    }

    // The shape is a rank-1 memref with unknown length.
    Location loc = reshapeOp.getLoc();
    MemRefDescriptor shapeDesc(adaptor.getShape());
    Value resultRank = shapeDesc.size(rewriter, loc, 0);

    // Extract address space and element type.
    auto targetType = cast<UnrankedMemRefType>(reshapeOp.getResult().getType());
    unsigned addressSpace =
        *getTypeConverter()->getMemRefAddressSpace(targetType);

    // Create the unranked memref descriptor that holds the ranked one. The
    // inner descriptor is allocated on stack.
    auto targetDesc = UnrankedMemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(targetType));
    targetDesc.setRank(rewriter, loc, resultRank);
    SmallVector<Value, 4> sizes;
    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                           targetDesc, addressSpace, sizes);
    Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
        loc, getVoidPtrType(), IntegerType::get(getContext(), 8),
        sizes.front());
    targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);

    // Extract pointers and offset from the source memref.
    Value allocatedPtr, alignedPtr, offset;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             reshapeOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr, &offset);

    // Set pointers and offset.
    auto elementPtrType =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);

    UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
                                              elementPtrType, allocatedPtr);
    UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
                                            underlyingDescPtr, elementPtrType,
                                            alignedPtr);
    UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
                                        underlyingDescPtr, elementPtrType,
                                        offset);

    // Use the offset pointer as base for further addressing. Copy over the new
    // shape and compute strides. For this, we create a loop from rank-1 to 0.
    Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
        rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrType);
    Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
        rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
    Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
    Value oneIndex = createIndexAttrConstant(rewriter, loc, getIndexType(), 1);
    Value resultRankMinusOne =
        rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);

    Block *initBlock = rewriter.getInsertionBlock();
    Type indexType = getTypeConverter()->getIndexType();
    Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());

    Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
                                            {indexType, indexType}, {loc, loc});

    // Move the remaining initBlock ops to condBlock.
    Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
    rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());

    rewriter.setInsertionPointToEnd(initBlock);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
                                condBlock);
    rewriter.setInsertionPointToStart(condBlock);
    Value indexArg = condBlock->getArgument(0);
    Value strideArg = condBlock->getArgument(1);

    Value zeroIndex = createIndexAttrConstant(rewriter, loc, indexType, 0);
    Value pred = rewriter.create<LLVM::ICmpOp>(
        loc, IntegerType::get(rewriter.getContext(), 1),
        LLVM::ICmpPredicate::sge, indexArg, zeroIndex);

    Block *bodyBlock =
        rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
    rewriter.setInsertionPointToStart(bodyBlock);

    // Copy size from shape to descriptor.
    auto llvmIndexPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
    Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
        loc, llvmIndexPtrType,
        typeConverter->convertType(shapeMemRefType.getElementType()),
        shapeOperandPtr, indexArg);
    Value size = rewriter.create<LLVM::LoadOp>(loc, indexType, sizeLoadGep);
    UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
                                      targetSizesBase, indexArg, size);

    // Write stride value and compute next one.
    UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
                                        targetStridesBase, indexArg, strideArg);
    Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);

    // Decrement loop counter and branch back.
    Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
                                condBlock);

    Block *remainder =
        rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());

    // Hook up the cond exit to the remainder.
    rewriter.setInsertionPointToEnd(condBlock);
    rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, std::nullopt,
                                    remainder, std::nullopt);

    // Reset position to beginning of new remainder block.
    rewriter.setInsertionPointToStart(remainder);

    *descriptor = targetDesc;
    return success();
  }
};

/// ReassociatingReshapeOps must be expanded before we reach this stage.
/// Report that information.
template <typename ReshapeOp>
class ReassociatingReshapeOpConversion
    : public ConvertOpToLLVMPattern<ReshapeOp> {
public:
  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;

  LogicalResult
  matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    return rewriter.notifyMatchFailure(
        reshapeOp,
        "reassociation operations should have been expanded beforehand");
  }
};

/// Subviews must be expanded before we reach this stage.
/// Report that information.
struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
  using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    return rewriter.notifyMatchFailure(
        subViewOp, "subview operations should have been expanded beforehand");
  }
};

/// Conversion pattern that transforms a transpose op into:
///   1. An `llvm.mlir.undef` operation to create a new memref descriptor.
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride. Size and stride are permutations of the original values.
/// The transpose op is replaced by the new descriptor.
class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
public:
  using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = transposeOp.getLoc();
    MemRefDescriptor viewMemRef(adaptor.getIn());

    // No permutation, early exit.
    if (transposeOp.getPermutation().isIdentity())
      return rewriter.replaceOp(transposeOp, {viewMemRef}), success();

    auto targetMemRef = MemRefDescriptor::undef(
        rewriter, loc,
        typeConverter->convertType(transposeOp.getIn().getType()));

    // Copy the base and aligned pointers from the old descriptor to the new
    // one.
    targetMemRef.setAllocatedPtr(rewriter, loc,
                                 viewMemRef.allocatedPtr(rewriter, loc));
    targetMemRef.setAlignedPtr(rewriter, loc,
                               viewMemRef.alignedPtr(rewriter, loc));

    // Copy the offset pointer from the old descriptor to the new one.
    targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));

    // Iterate over the dimensions and apply size/stride permutation:
    // When enumerating the results of the permutation map, the enumeration
    // index is the index into the target dimensions and the DimExpr points to
    // the dimension of the source memref.
    for (const auto &en :
         llvm::enumerate(transposeOp.getPermutation().getResults())) {
      int targetPos = en.index();
      int sourcePos = cast<AffineDimExpr>(en.value()).getPosition();
      targetMemRef.setSize(rewriter, loc, targetPos,
                           viewMemRef.size(rewriter, loc, sourcePos));
      targetMemRef.setStride(rewriter, loc, targetPos,
                             viewMemRef.stride(rewriter, loc, sourcePos));
    }

    rewriter.replaceOp(transposeOp, {targetMemRef});
    return success();
  }
};

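// Illustrative example: with permutation affine_map<(d0, d1) -> (d1, d0)>,
// the target descriptor receives size0' = size1 and size1' = size0 (and the
// same swap for strides), while the pointers and offset are copied unchanged.
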
1397 /// Conversion pattern that transforms an op into:
1398 /// 1. An `llvm.mlir.undef` operation to create a memref descriptor
1399 /// 2. Updates to the descriptor to introduce the data ptr, offset, size
1400 /// and stride.
1401 /// The view op is replaced by the descriptor.
struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;
  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or by picking the right dynamic
  // size from `dynamicSizes`.
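  // For example, with shape [4, ?, ?, 8] and idx == 2, shape[2] is dynamic
  // and shape.take_front(2) == [4, ?] contains one dynamic dimension, so the
  // result is dynamicSizes[1].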
  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<int64_t> shape, ValueRange dynamicSizes, unsigned idx,
                Type indexType) const {
    assert(idx < shape.size());
    if (!ShapedType::isDynamic(shape[idx]))
      return createIndexAttrConstant(rewriter, loc, indexType, shape[idx]);
    // Count the number of dynamic dims in range [0, idx).
    unsigned nDynamic =
        llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
    return dynamicSizes[nDynamic];
  }
  // Build and return the idx^th stride, either by returning the constant
  // stride or by computing the dynamic stride from the current `runningStride`
  // and `nextSize`. The caller should keep a running stride and update it with
  // the result returned by this function.
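  // For example, for a result type memref<?x?xf32> the static strides are
  // [dynamic, 1]: the innermost call returns the constant 1, and the next
  // call returns `runningStride * nextSize`, i.e. 1 * size[1].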
  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
                  ArrayRef<int64_t> strides, Value nextSize,
                  Value runningStride, unsigned idx, Type indexType) const {
    assert(idx < strides.size());
    if (!ShapedType::isDynamic(strides[idx]))
      return createIndexAttrConstant(rewriter, loc, indexType, strides[idx]);
    if (nextSize)
      return runningStride
                 ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
                 : nextSize;
    assert(!runningStride);
    return createIndexAttrConstant(rewriter, loc, indexType, 1);
  }
  LogicalResult
  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = viewOp.getLoc();

    auto viewMemRefType = viewOp.getType();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("target descriptor type not converted to LLVM"),
             failure();
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
    if (failed(successStrides))
      return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
    assert(offset == 0 && "expected offset to be 0");

    // Target memref must be contiguous in memory (innermost stride is 1), or
    // empty (special case when at least one of the memref dimensions is 0).
    if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
      return viewOp.emitWarning("cannot cast to non-contiguous shape"),
             failure();

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.getSource());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Field 1: Copy the allocated pointer, used for malloc/free.
    Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
    auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType());
    targetMemRef.setAllocatedPtr(rewriter, loc, allocatedPtr);

    // Field 2: Copy the actual aligned pointer to payload.
    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
    alignedPtr = rewriter.create<LLVM::GEPOp>(
        loc, alignedPtr.getType(),
        typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr,
        adaptor.getByteShift());

    targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr);

    Type indexType = getIndexType();
    // Field 3: The offset in the resulting type must be 0. This is because of
    // the type change: an offset on srcType* may not be expressible as an
    // offset on dstType*.
    targetMemRef.setOffset(
        rewriter, loc,
        createIndexAttrConstant(rewriter, loc, indexType, offset));

    // Early exit for 0-D corner case.
    if (viewMemRefType.getRank() == 0)
      return rewriter.replaceOp(viewOp, {targetMemRef}), success();

    // Fields 4 and 5: Update sizes and strides.
    Value stride = nullptr, nextSize = nullptr;
    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
      // Update size.
      Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
                           adaptor.getSizes(), i, indexType);
      targetMemRef.setSize(rewriter, loc, i, size);
      // Update stride.
      stride =
          getStride(rewriter, loc, strides, nextSize, stride, i, indexType);
      targetMemRef.setStride(rewriter, loc, i, stride);
      nextSize = size;
    }

    rewriter.replaceOp(viewOp, {targetMemRef});
    return success();
  }
};
//===----------------------------------------------------------------------===//
// AtomicRMWOpLowering
//===----------------------------------------------------------------------===//

/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
/// lowering to llvm.atomicrmw or fall back to llvm.cmpxchg.
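///
/// For example, `memref.atomic_rmw addf %val, %mem[%i]` maps to
/// LLVM::AtomicBinOp::fadd and lowers to `llvm.atomicrmw fadd` with acq_rel
/// ordering below; kinds without a direct atomicrmw equivalent yield
/// std::nullopt and are left to the cmpxchg-based lowering.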
static std::optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
  switch (atomicOp.getKind()) {
  case arith::AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case arith::AtomicRMWKind::addi:
    return LLVM::AtomicBinOp::add;
  case arith::AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case arith::AtomicRMWKind::maximumf:
    return LLVM::AtomicBinOp::fmax;
  case arith::AtomicRMWKind::maxs:
    return LLVM::AtomicBinOp::max;
  case arith::AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case arith::AtomicRMWKind::minimumf:
    return LLVM::AtomicBinOp::fmin;
  case arith::AtomicRMWKind::mins:
    return LLVM::AtomicBinOp::min;
  case arith::AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  case arith::AtomicRMWKind::ori:
    return LLVM::AtomicBinOp::_or;
  case arith::AtomicRMWKind::andi:
    return LLVM::AtomicBinOp::_and;
  default:
    return std::nullopt;
  }
  llvm_unreachable("Invalid AtomicRMWKind");
}
struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto maybeKind = matchSimpleAtomicOp(atomicOp);
    if (!maybeKind)
      return failure();
    auto memRefType = atomicOp.getMemRefType();
    SmallVector<int64_t> strides;
    int64_t offset;
    if (failed(getStridesAndOffset(memRefType, strides, offset)))
      return failure();
    auto dataPtr =
        getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
                             adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
        atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
        LLVM::AtomicOrdering::acq_rel);
    return success();
  }
};
/// Unpack the pointer returned by a memref.extract_aligned_pointer_as_index.
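///
/// For a ranked source the aligned pointer is read directly from the
/// descriptor; for an unranked source it is loaded through the type-erased
/// descriptor pointer. Either way the result is an `llvm.ptrtoint` to the
/// converted index type.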
class ConvertExtractAlignedPointerAsIndex
    : public ConvertOpToLLVMPattern<memref::ExtractAlignedPointerAsIndexOp> {
public:
  using ConvertOpToLLVMPattern<
      memref::ExtractAlignedPointerAsIndexOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    BaseMemRefType sourceTy = extractOp.getSource().getType();

    Value alignedPtr;
    if (sourceTy.hasRank()) {
      MemRefDescriptor desc(adaptor.getSource());
      alignedPtr = desc.alignedPtr(rewriter, extractOp->getLoc());
    } else {
      auto elementPtrTy = LLVM::LLVMPointerType::get(
          rewriter.getContext(), sourceTy.getMemorySpaceAsInt());

      UnrankedMemRefDescriptor desc(adaptor.getSource());
      Value descPtr = desc.memRefDescPtr(rewriter, extractOp->getLoc());

      alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
          rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr,
          elementPtrTy);
    }

    rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(
        extractOp, getTypeConverter()->getIndexType(), alignedPtr);
    return success();
  }
};
/// Materialize the MemRef descriptor represented by the results of
/// ExtractStridedMetadataOp.
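///
/// For example (result types omitted), for
///   %base, %off, %sizes:2, %strides:2 = memref.extract_strided_metadata %m
/// the replacement values are, in order: a 0-D descriptor wrapping the
/// source's allocated and aligned pointers, the offset, both sizes, then
/// both strides, all read from the source descriptor.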
class ExtractStridedMetadataOpLowering
    : public ConvertOpToLLVMPattern<memref::ExtractStridedMetadataOp> {
public:
  using ConvertOpToLLVMPattern<
      memref::ExtractStridedMetadataOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
      return failure();

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.getSource());
    Location loc = extractStridedMetadataOp.getLoc();
    Value source = extractStridedMetadataOp.getSource();

    auto sourceMemRefType = cast<MemRefType>(source.getType());
    int64_t rank = sourceMemRefType.getRank();
    SmallVector<Value> results;
    results.reserve(2 + rank * 2);

    // Base buffer.
    Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc);
    Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc);
    MemRefDescriptor dstMemRef = MemRefDescriptor::fromStaticShape(
        rewriter, loc, *getTypeConverter(),
        cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()),
        baseBuffer, alignedBuffer);
    results.push_back((Value)dstMemRef);

    // Offset.
    results.push_back(sourceMemRef.offset(rewriter, loc));

    // Sizes.
    for (unsigned i = 0; i < rank; ++i)
      results.push_back(sourceMemRef.size(rewriter, loc, i));
    // Strides.
    for (unsigned i = 0; i < rank; ++i)
      results.push_back(sourceMemRef.stride(rewriter, loc, i));

    rewriter.replaceOp(extractStridedMetadataOp, results);
    return success();
  }
};
} // namespace
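// A minimal usage sketch for the function below (assuming a configured
// LLVMTypeConverter `converter` and an MLIRContext `ctx`):
//   RewritePatternSet patterns(ctx);
//   populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
//   // ... then run applyPartialConversion with an LLVMConversionTarget.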
void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  // clang-format off
  patterns.add<
      AllocaOpLowering,
      AllocaScopeOpLowering,
      AtomicRMWOpLowering,
      AssumeAlignmentOpLowering,
      ConvertExtractAlignedPointerAsIndex,
      DimOpLowering,
      ExtractStridedMetadataOpLowering,
      GenericAtomicRMWOpLowering,
      GlobalMemrefOpLowering,
      GetGlobalMemrefOpLowering,
      LoadOpLowering,
      MemRefCastOpLowering,
      MemRefCopyOpLowering,
      MemorySpaceCastOpLowering,
      MemRefReinterpretCastOpLowering,
      MemRefReshapeOpLowering,
      PrefetchOpLowering,
      RankOpLowering,
      ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
      ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
      StoreOpLowering,
      SubViewOpLowering,
      TransposeOpLowering,
      ViewOpLowering>(converter);
  // clang-format on
  auto allocLowering = converter.getOptions().allocLowering;
  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
    patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
    patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
}
namespace {
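// Illustrative invocation of the pass below, assuming the option names
// declared for it in Passes.td:
//   mlir-opt --finalize-memref-to-llvm="use-aligned-alloc=1 index-bitwidth=32" \
//       input.mlir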
struct FinalizeMemRefToLLVMConversionPass
    : public impl::FinalizeMemRefToLLVMConversionPassBase<
          FinalizeMemRefToLLVMConversionPass> {
  using FinalizeMemRefToLLVMConversionPassBase::
      FinalizeMemRefToLLVMConversionPassBase;

  void runOnOperation() override {
    Operation *op = getOperation();
    const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
    LowerToLLVMOptions options(&getContext(),
                               dataLayoutAnalysis.getAtOrAbove(op));
    options.allocLowering =
        (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
                         : LowerToLLVMOptions::AllocLowering::Malloc);

    options.useGenericFunctions = useGenericFunctions;

    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    LLVMTypeConverter typeConverter(&getContext(), options,
                                    &dataLayoutAnalysis);
    RewritePatternSet patterns(&getContext());
    populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalOp<func::FuncOp>();
    if (failed(applyPartialConversion(op, target, std::move(patterns))))
      signalPassFailure();
  }
};
/// Implement the interface to convert MemRef to LLVM.
struct MemRefToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
  void loadDependentDialects(MLIRContext *context) const final {
    context->loadDialect<LLVM::LLVMDialect>();
  }

  /// Hook for derived dialect interfaces to provide conversion patterns
  /// and mark the dialect legal for the conversion target.
  void populateConvertToLLVMConversionPatterns(
      ConversionTarget &target, LLVMTypeConverter &typeConverter,
      RewritePatternSet &patterns) const final {
    populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
  }
};
} // namespace
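// A minimal registration sketch for the function below (assuming a tool-side
// DialectRegistry):
//   DialectRegistry registry;
//   registerConvertMemRefToLLVMInterface(registry);
//   MLIRContext context(registry);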
void mlir::registerConvertMemRefToLLVMInterface(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
    dialect->addInterfaces<MemRefToLLVMDialectInterface>();
  });
}