//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"

#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/Support/MathExtras.h"
#include <optional>
namespace mlir {
#define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {
bool isStaticStrideOrOffset(int64_t strideOrOffset) {
  return !ShapedType::isDynamic(strideOrOffset);
}
LLVM::LLVMFuncOp getFreeFn(const LLVMTypeConverter *typeConverter,
                           ModuleOp module) {
  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;

  if (useGenericFn)
    return LLVM::lookupOrCreateGenericFreeFn(module);

  return LLVM::lookupOrCreateFreeFn(module);
}
struct AllocOpLowering : public AllocLikeOpLLVMLowering {
  AllocOpLowering(const LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    return allocateBufferManuallyAlign(
        rewriter, loc, sizeBytes, op,
        getAlignment(rewriter, loc, cast<memref::AllocOp>(op)));
  }
};
struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
  AlignedAllocOpLowering(const LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    Value ptr = allocateBufferAutoAlign(
        rewriter, loc, sizeBytes, op, &defaultLayout,
        alignedAllocationGetAlignment(rewriter, loc, cast<memref::AllocOp>(op),
                                      getDataLayout()));
    if (!ptr)
      return std::make_tuple(Value(), Value());
    return std::make_tuple(ptr, ptr);
  }

private:
  /// Default layout to use in absence of the corresponding analysis.
  DataLayout defaultLayout;
};
struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
  AllocaOpLowering(const LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
                                converter) {
    setRequiresNumElements();
  }

  /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
  /// is set to null for stack allocations. `accessAlignment` is set if
  /// alignment is needed post allocation (e.g. in conjunction with malloc).
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value size,
                                          Operation *op) const override {

    // With alloca, one gets a pointer to the element type right away.
    // For stack allocations.
    auto allocaOp = cast<memref::AllocaOp>(op);
    auto elementType =
        typeConverter->convertType(allocaOp.getType().getElementType());
    unsigned addrSpace =
        *getTypeConverter()->getMemRefAddressSpace(allocaOp.getType());
    auto elementPtrType =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace);

    auto allocatedElementPtr =
        rewriter.create<LLVM::AllocaOp>(loc, elementPtrType, elementType, size,
                                        allocaOp.getAlignment().value_or(0));

    return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
  }
};
struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(
          remainingOpsBlock, allocaScopeOp.getResultTypes(),
          SmallVector<Location>(allocaScopeOp->getNumResults(),
                                allocaScopeOp.getLoc()));
      rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
    }

    // Inline body region.
    Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
    Block *afterBody = &allocaScopeOp.getBodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);

    // Save stack and then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
    rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the body.
    // Stack restore before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.getResults(), continueBlock);

    // Insert stack restore before jumping out of the body of the region.
    rewriter.setInsertionPoint(branchOp);
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    // Replace the op with the values returned from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};
struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
  explicit AssumeAlignmentOpLowering(const LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::AssumeAlignmentOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value memref = adaptor.getMemref();
    unsigned alignment = op.getAlignment();
    auto loc = op.getLoc();

    auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
    Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, /*indices=*/{},
                                     rewriter);

    // Emit llvm.assume(memref & (alignment - 1) == 0).
    //
    // This relies on LLVM's CSE optimization (potentially after SROA), since
    // after CSE all memref instances should get de-duplicated into the same
    // pointer SSA value.
    MemRefDescriptor memRefDescriptor(memref);
    auto intPtrType =
        getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
    Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
    Value mask =
        createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
    Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
    rewriter.create<LLVM::AssumeOp>(
        loc, rewriter.create<LLVM::ICmpOp>(
                 loc, LLVM::ICmpPredicate::eq,
                 rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));

    rewriter.eraseOp(op);
    return success();
  }
};
// A `dealloc` is converted into a call to `free` on the underlying data buffer.
// The memref descriptor being an SSA value, there is no need to clean it up
// in any way.
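//
// For illustration (a schematic sketch; the exact descriptor struct layout is
// produced by the type converter):
//   memref.dealloc %buf : memref<16xf32>
// roughly becomes
//   %ptr = llvm.extractvalue %desc[0] : !llvm.struct<(ptr, ptr, i64, ...)>
//   llvm.call @free(%ptr) : (!llvm.ptr) -> ()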
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;

  explicit DeallocOpLowering(const LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Insert the `free` declaration if it is not already present.
    LLVM::LLVMFuncOp freeFunc =
        getFreeFn(getTypeConverter(), op->getParentOfType<ModuleOp>());
    Value allocatedPtr;
    if (auto unrankedTy =
            llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
      auto elementPtrTy = LLVM::LLVMPointerType::get(
          rewriter.getContext(), unrankedTy.getMemorySpaceAsInt());
      allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
          rewriter, op.getLoc(),
          UnrankedMemRefDescriptor(adaptor.getMemref())
              .memRefDescPtr(rewriter, op.getLoc()),
          elementPtrTy);
    } else {
      allocatedPtr = MemRefDescriptor(adaptor.getMemref())
                         .allocatedPtr(rewriter, op.getLoc());
    }
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc, allocatedPtr);
    return success();
  }
};
// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
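//
// For example (sketch): with %m : memref<4x?xf32>, `memref.dim %m, %c0` folds
// to `llvm.mlir.constant(4 : index)`, while `memref.dim %m, %c1` reads the
// dynamic size stored in the descriptor's size array instead.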
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.getSource().getType();
    if (isa<UnrankedMemRefType>(operandType)) {
      FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
          operandType, dimOp, adaptor.getOperands(), rewriter);
      if (failed(extractedSize))
        return failure();
      rewriter.replaceOp(dimOp, {*extractedSize});
      return success();
    }
    if (isa<MemRefType>(operandType)) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
                                            adaptor.getOperands(), rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  FailureOr<Value>
  extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                              OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    auto unrankedMemRefType = cast<UnrankedMemRefType>(operandType);
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
    if (failed(maybeAddressSpace)) {
      dimOp.emitOpError("memref memory space must be convertible to an integer "
                        "address space");
      return failure();
    }
    unsigned addressSpace = *maybeAddressSpace;

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // memref<element_type> descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);

    Type elementType = typeConverter->convertType(scalarMemRefType);

    // Get pointer to offset field of memref<element_type> descriptor.
    auto indexPtrTy =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
    Value offsetPtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, elementType, underlyingRankedDesc,
        ArrayRef<LLVM::GEPArg>{0, 2});

    // The size value that we have to extract can be obtained using GEPOp with
    // `dimOp.index() + 1` index argument.
    Value idxPlusOne = rewriter.create<LLVM::AddOp>(
        loc, createIndexAttrConstant(rewriter, loc, getIndexType(), 1),
        adaptor.getIndex());
    Value sizePtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr,
        idxPlusOne);
    return rewriter
        .create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr)
        .getResult();
  }

  std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (auto idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
      return cast<IntegerAttr>(constantOp.getValue()).getValue().getSExtValue();

    return std::nullopt;
  }

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    // Take advantage of the index being constant when possible.
    MemRefType memRefType = cast<MemRefType>(operandType);
    Type indexType = getIndexType();
    if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = *index;
      if (i >= 0 && i < memRefType.getRank()) {
        if (memRefType.isDynamicDim(i)) {
          // Extract dynamic size from the memref descriptor.
          MemRefDescriptor descriptor(adaptor.getSource());
          return descriptor.size(rewriter, loc, i);
        }
        // Use constant for static size.
        int64_t dimSize = memRefType.getDimSize(i);
        return createIndexAttrConstant(rewriter, loc, indexType, dimSize);
      }
    }
    Value index = adaptor.getIndex();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(adaptor.getSource());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};
/// Common base for load and store operations on MemRefs. Restricts the match
/// to supported MemRef types. Provides functionality to emit code accessing a
/// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;

  LogicalResult match(Derived op) const override {
    MemRefType type = op.getMemRefType();
    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
  }
};
/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   cf.br loop(%loaded)           |
///      +---------------------------------+
///             |
///             v
///   |  +--------------------------------+
///   |  | loop(%loaded):                 |
///   |  |   <body contents>              |
///   |  |   %pair = cmpxchg              |
///   |  |   %ok = %pair[0]               |
///   |  |   %new = %pair[1]              |
///   |  |   cf.cond_br %ok, end, loop(%new) |
///   |  +--------------------------------+
///   |        |         |
///   +--------+         |
///                      v
///      +--------------------------------+
///      | end:                           |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
///
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType = typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock = rewriter.splitBlock(initBlock, Block::iterator(atomicOp));
    loopBlock->addArgument(valueType, loc);
    auto *endBlock =
        rewriter.splitBlock(loopBlock, Block::iterator(atomicOp)++);

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
                                        adaptor.getIndices(), rewriter);
    Value init = rewriter.create<LLVM::LoadOp>(
        loc, typeConverter->convertType(memRefType.getElementType()), dataPtr);
    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    auto loopArgument = loopBlock->getArgument(0);
    IRMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, dataPtr, loopArgument, result, successOrdering, failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 0);
    Value ok = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 1);

    // Conditionally branch to the end or back to the loop depending on %ok.
    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                    loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }
};
/// Returns the LLVM type of the global variable given the memref type `type`.
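/// For example (sketch): a `memref<2x3xf32>` global maps to
/// `!llvm.array<2 x array<3 x f32>>`, with the element type itself converted
/// by the type converter.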
static Type
convertGlobalMemrefTypeToLLVM(MemRefType type,
                              const LLVMTypeConverter &typeConverter) {
  // LLVM type for a global memref will be a multi-dimension array. For
  // declarations or uninitialized global memrefs, we can potentially flatten
  // this to a 1D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimension array.
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // Shape has the outermost dim at index 0, so we need to walk it backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}
/// GlobalMemrefOp is lowered to a LLVM Global Variable.
struct GlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.getType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    Attribute initialValue = nullptr;
    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element type,
      // so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getSplatValue<Attribute>();
    }

    uint64_t alignment = global.getAlignment().value_or(0);
    FailureOr<unsigned> addressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    if (failed(addressSpace))
      return global.emitOpError(
          "memory space cannot be converted to an integer address space");
    auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.getConstant(), linkage, global.getSymName(),
        initialValue, alignment, *addressSpace);
    if (!global.isExternal() && global.isUninitialized()) {
      rewriter.createBlock(&newGlobal.getInitializerRegion());
      Value undef[] = {
          rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
      rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
    }
    return success();
  }
};
/// GetGlobalMemrefOp is lowered into a MemRef descriptor with the pointer to
/// the first element stashed into the descriptor. This builds on
/// `AllocLikeOpLowering` to reuse the MemRef descriptor construction.
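///
/// For illustration (sketch): `%0 = memref.get_global @gv : memref<2xf32>`
/// lowers to an `llvm.mlir.addressof @gv`, a GEP to the first element, and a
/// descriptor built around that pointer.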
struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
  GetGlobalMemrefOpLowering(const LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
                                converter) {}

  /// Buffer "allocation" for memref.get_global op is getting the address of
  /// the global variable referenced.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    auto getGlobalOp = cast<memref::GetGlobalOp>(op);
    MemRefType type = cast<MemRefType>(getGlobalOp.getResult().getType());

    // This is called after a type conversion, which would have failed if this
    // call fails.
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    if (failed(maybeAddressSpace))
      return std::make_tuple(Value(), Value());
    unsigned memSpace = *maybeAddressSpace;

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace);
    auto addressOf =
        rewriter.create<LLVM::AddressOfOp>(loc, ptrTy, getGlobalOp.getName());

    // Get the address of the first element in the array by creating a GEP with
    // the address of the GV as the base, and (rank + 1) number of 0 indices.
    auto gep = rewriter.create<LLVM::GEPOp>(
        loc, ptrTy, arrayTy, addressOf,
        SmallVector<LLVM::GEPArg>(type.getRank() + 1, 0));

    // We do not expect the memref obtained using `memref.get_global` to be
    // ever deallocated. Set the allocated pointer to a known bad value to
    // help debug if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
    auto deadBeefPtr =
        rewriter.create<LLVM::IntToPtrOp>(loc, ptrTy, deadBeefConst);

    // Both allocated and aligned pointers are the same. We could potentially
    // stash a nullptr for the allocated pointer since we do not expect any
    // dealloc.
    return std::make_tuple(deadBeefPtr, gep);
  }
};
// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
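//
// Schematically (sketch): `memref.load %m[%i, %j]` becomes an address
// computation off the descriptor's aligned pointer using the offset and
// strides, followed by an `llvm.load` of the converted element type.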
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();

    Value dataPtr =
        getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
                             adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
        loadOp, typeConverter->convertType(type.getElementType()), dataPtr, 0,
        false, loadOp.getNontemporal());
    return success();
  }
};
// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();

    Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
                                         adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
                                               0, false, op.getNontemporal());
    return success();
  }
};
// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
                                         adaptor.getIndices(), rewriter);

    // Replace with llvm.prefetch.
    IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite());
    IntegerAttr localityHint = prefetchOp.getLocalityHintAttr();
    IntegerAttr isData =
        rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache());
    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};
struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
  using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type operandType = op.getMemref().getType();
    if (dyn_cast<UnrankedMemRefType>(operandType)) {
      UnrankedMemRefDescriptor desc(adaptor.getMemref());
      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
      return success();
    }
    if (auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) {
      Type indexType = getIndexType();
      rewriter.replaceOp(op,
                         {createIndexAttrConstant(rewriter, loc, indexType,
                                                  rankedMemRefType.getRank())});
      return success();
    }
    return failure();
  }
};
struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult match(memref::CastOp memRefCastOp) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduces to bitcast in the ranked MemRef case and can be
    // used for type erasure. For now they must preserve underlying element type
    // and require source and result type to have the same rank. Therefore,
    // perform a sanity check that the underlying structs are the same. Once op
    // semantics are relaxed we can revisit.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
      return success(typeConverter->convertType(srcType) ==
                     typeConverter->convertType(dstType));

    // At least one of the operands is of unranked type.
    assert(isa<UnrankedMemRefType>(srcType) ||
           isa<UnrankedMemRefType>(dstType));

    // Unranked to unranked cast is disallowed.
    return !(isa<UnrankedMemRefType>(srcType) &&
             isa<UnrankedMemRefType>(dstType))
               ? success()
               : failure();
  }

  void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
               ConversionPatternRewriter &rewriter) const override {
    auto srcType = memRefCastOp.getOperand().getType();
    auto dstType = memRefCastOp.getType();
    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For the ranked/ranked case, just keep the original descriptor.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
      return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});

    if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
      // Casting ranked to unranked memref type:
      //   Set the rank in the destination from the memref type.
      //   Allocate space on the stack and copy the src memref descriptor.
      //   Set the ptr in the destination to the stack space.
      auto srcMemRefType = cast<MemRefType>(srcType);
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, adaptor.getSource(), rewriter);

      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(rank));
      // undef = UndefOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
      // d1 = InsertValueOp undef, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, ptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
      // Casting from unranked type to ranked.
      // The operation is assumed to be doing a correct cast. If the destination
      // type mismatches the unranked type, it is undefined behavior.
      UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);

      // struct = LoadOp ptr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, targetStructType, ptr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }
  }
};
/// Pattern to lower a `memref.copy` to llvm.
///
/// For memrefs with identity layouts, the copy is lowered to the llvm
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
/// to the generic `MemrefCopyFn`.
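///
/// For illustration (sketch): copying between two `memref<4x4xf32>` values
/// with identity layout emits a single `llvm.intr.memcpy` of 4 * 4 * 4 bytes,
/// whereas a strided (non-contiguous) source or target goes through the
/// `memrefCopy` runtime function instead.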
struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = dyn_cast<MemRefType>(op.getSource().getType());

    MemRefDescriptor srcDesc(adaptor.getSource());

    // Compute number of elements.
    Value numElements = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(1));
    for (int pos = 0; pos < srcType.getRank(); ++pos) {
      auto size = srcDesc.size(rewriter, loc, pos);
      numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
    }

    // Get element size.
    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    // Compute total.
    Value totalSize =
        rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);

    Type elementType = typeConverter->convertType(srcType.getElementType());

    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
    Value srcOffset = srcDesc.offset(rewriter, loc);
    Value srcPtr = rewriter.create<LLVM::GEPOp>(
        loc, srcBasePtr.getType(), elementType, srcBasePtr, srcOffset);
    MemRefDescriptor targetDesc(adaptor.getTarget());
    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
    Value targetOffset = targetDesc.offset(rewriter, loc);
    Value targetPtr = rewriter.create<LLVM::GEPOp>(
        loc, targetBasePtr.getType(), elementType, targetBasePtr, targetOffset);
    rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
                                    /*isVolatile=*/false);
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, MemRefType type) {
      auto rank = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                    type.getRank());
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);

      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(
          rewriter, loc, *typeConverter, unrankedType, ValueRange{rank, ptr});
    };

    // Save stack position before promoting descriptors.
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());

    auto srcMemRefType = dyn_cast<MemRefType>(srcType);
    Value unrankedSource =
        srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
                      : adaptor.getSource();
    auto targetMemRefType = dyn_cast<MemRefType>(targetType);
    Value unrankedTarget =
        targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
                         : adaptor.getTarget();

    // Now promote the unranked descriptors to the stack.
    auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
      auto allocated =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, desc.getType(), one);
      rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    // Derive size from llvm.getelementptr which will account for any
    // potential alignment.
    auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
    rewriter.create<LLVM::CallOp>(loc, copyFn,
                                  ValueRange{elemSize, sourcePtr, targetPtr});

    // Restore stack used for descriptors.
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    auto isContiguousMemrefType = [&](BaseMemRefType type) {
      auto memrefType = dyn_cast<mlir::MemRefType>(type);
      // We can use memcpy for memrefs if they have an identity layout or are
      // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
      // special case handled by memrefCopy.
      return memrefType &&
             (memrefType.getLayout().isIdentity() ||
              (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
               memref::isStaticShapeAndContiguousRowMajor(memrefType)));
    };

    if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);

    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
  }
};
struct MemorySpaceCastOpLowering
    : public ConvertOpToLLVMPattern<memref::MemorySpaceCastOp> {
  using ConvertOpToLLVMPattern<
      memref::MemorySpaceCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    Type resultType = op.getDest().getType();
    if (auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
      auto resultDescType =
          cast<LLVM::LLVMStructType>(typeConverter->convertType(resultTypeR));
      Type newPtrType = resultDescType.getBody()[0];

      SmallVector<Value> descVals;
      MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR,
                               descVals);
      descVals[0] =
          rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[0]);
      descVals[1] =
          rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[1]);
      Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
                                            resultTypeR, descVals);
      rewriter.replaceOp(op, result);
      return success();
    }
    if (auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
      // Since the type converter won't be doing this for us, get the address
      // space.
      auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType());
      FailureOr<unsigned> maybeSourceAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(sourceType);
      if (failed(maybeSourceAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer source address space");
      unsigned sourceAddrSpace = *maybeSourceAddrSpace;
      FailureOr<unsigned> maybeResultAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(resultTypeU);
      if (failed(maybeResultAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer result address space");
      unsigned resultAddrSpace = *maybeResultAddrSpace;

      UnrankedMemRefDescriptor sourceDesc(adaptor.getSource());
      Value rank = sourceDesc.rank(rewriter, loc);
      Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);

      // Create and allocate storage for the new memref descriptor.
      auto result = UnrankedMemRefDescriptor::undef(
          rewriter, loc, typeConverter->convertType(resultTypeU));
      result.setRank(rewriter, loc, rank);
      SmallVector<Value, 1> sizes;
      UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                             result, resultAddrSpace, sizes);
      Value resultUnderlyingSize = sizes.front();
      Value resultUnderlyingDesc = rewriter.create<LLVM::AllocaOp>(
          loc, getVoidPtrType(), rewriter.getI8Type(), resultUnderlyingSize);
      result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);

      // Copy pointers, performing address space casts.
      auto sourceElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), sourceAddrSpace);
      auto resultElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), resultAddrSpace);

      Value allocatedPtr = sourceDesc.allocatedPtr(
          rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
      Value alignedPtr =
          sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
                                sourceUnderlyingDesc, sourceElemPtrType);
      allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, resultElemPtrType, allocatedPtr);
      alignedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, resultElemPtrType, alignedPtr);

      result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
                             resultElemPtrType, allocatedPtr);
      result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
                           resultUnderlyingDesc, resultElemPtrType, alignedPtr);

      // Copy all the index-valued operands.
      Value sourceIndexVals =
          sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                                   sourceUnderlyingDesc, sourceElemPtrType);
      Value resultIndexVals =
          result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                               resultUnderlyingDesc, resultElemPtrType);

      int64_t bytesToSkip =
          2 * llvm::divideCeil(
                  getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
      Value bytesToSkipConst = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
      Value copySize = rewriter.create<LLVM::SubOp>(
          loc, getIndexType(), resultUnderlyingSize, bytesToSkipConst);
      rewriter.create<LLVM::MemcpyOp>(loc, resultIndexVals, sourceIndexVals,
                                      copySize, /*isVolatile=*/false);

      rewriter.replaceOp(op, ValueRange{result});
      return success();
    }
    return rewriter.notifyMatchFailure(loc, "unexpected memref type");
  }
};
/// Extracts allocated, aligned pointers and offset from a ranked or unranked
/// memref type. In the unranked case, the fields are extracted from the
/// underlying ranked descriptor.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     const LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (isa<MemRefType>(operandType)) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  // These will all cause assert()s on unconvertible types.
  unsigned memorySpace = *typeConverter.getMemRefAddressSpace(
      cast<UnrankedMemRefType>(operandType));
  auto elementPtrType =
      LLVM::LLVMPointerType::get(rewriter.getContext(), memorySpace);

  // Extract pointer to the underlying ranked memref descriptor and cast it to
  // the element pointer type.
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
  }
}
struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
  using ConvertOpToLLVMPattern<
      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = castOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  LogicalResult convertSourceMemRefToDescriptor(
      ConversionPatternRewriter &rewriter, Type srcType,
      memref::ReinterpretCastOp castOp,
      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
    MemRefType targetMemRefType =
        cast<MemRefType>(castOp.getResult().getType());
    auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
        typeConverter->convertType(targetMemRefType));
    if (!llvmTargetDescriptorTy)
      return failure();

    // Create descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

    // Set allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set sizes and strides.
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};
struct MemRefReshapeOpLowering
    : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = reshapeOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(reshapeOp, {descriptor});
    return success();
  }

private:
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, memref::ReshapeOp reshapeOp,
                                  memref::ReshapeOp::Adaptor adaptor,
                                  Value *descriptor) const {
    auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
    if (shapeMemRefType.hasStaticShape()) {
      MemRefType targetMemRefType =
          cast<MemRefType>(reshapeOp.getResult().getType());
      auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
          typeConverter->convertType(targetMemRefType));
      if (!llvmTargetDescriptorTy)
        return failure();

      // Create descriptor.
      Location loc = reshapeOp.getLoc();
      auto desc =
          MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

      // Set allocated and aligned pointers.
      Value allocatedPtr, alignedPtr;
      extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                               reshapeOp.getSource(), adaptor.getSource(),
                               &allocatedPtr, &alignedPtr);
      desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
      desc.setAlignedPtr(rewriter, loc, alignedPtr);

      // Extract the offset and strides from the type.
      int64_t offset;
      SmallVector<int64_t> strides;
      if (failed(getStridesAndOffset(targetMemRefType, strides, offset)))
        return rewriter.notifyMatchFailure(
            reshapeOp, "failed to get stride and offset exprs");

      if (!isStaticStrideOrOffset(offset))
        return rewriter.notifyMatchFailure(reshapeOp,
                                           "dynamic offset is unsupported");

      desc.setConstantOffset(rewriter, loc, offset);

      assert(targetMemRefType.getLayout().isIdentity() &&
             "Identity layout map is a precondition of a valid reshape op");

      Type indexType = getIndexType();
      Value stride = nullptr;
      int64_t targetRank = targetMemRefType.getRank();
      for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
        if (!ShapedType::isDynamic(strides[i])) {
          // If the stride for this dimension is static, materialize the
          // constant; dynamic strides are accumulated below as the product of
          // the inner sizes.
          stride =
              createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
        } else if (!stride) {
          // `stride` is null only in the first iteration of the loop. However,
          // since the target memref has an identity layout, we can safely set
          // the innermost stride to 1.
          stride = createIndexAttrConstant(rewriter, loc, indexType, 1);
        }

        Value dimSize;
        // If the size of this dimension is dynamic, then load it at runtime
        // from the shape operand.
        if (!targetMemRefType.isDynamicDim(i)) {
          dimSize = createIndexAttrConstant(rewriter, loc, indexType,
                                            targetMemRefType.getDimSize(i));
        } else {
          Value shapeOp = reshapeOp.getShape();
          Value index = createIndexAttrConstant(rewriter, loc, indexType, i);
          dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
          Type indexType = getIndexType();
          if (dimSize.getType() != indexType)
            dimSize = typeConverter->materializeTargetConversion(
                rewriter, loc, indexType, dimSize);
          assert(dimSize && "Invalid memref element type");
        }

        desc.setSize(rewriter, loc, i, dimSize);
        desc.setStride(rewriter, loc, i, stride);

        // Prepare the stride value for the next dimension.
        stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
      }

      *descriptor = desc;
      return success();
    }

    // The shape is a rank-1 tensor with unknown length.
    Location loc = reshapeOp.getLoc();
    MemRefDescriptor shapeDesc(adaptor.getShape());
    Value resultRank = shapeDesc.size(rewriter, loc, 0);

    // Extract address space and element type.
    auto targetType = cast<UnrankedMemRefType>(reshapeOp.getResult().getType());
    unsigned addressSpace =
        *getTypeConverter()->getMemRefAddressSpace(targetType);

    // Create the unranked memref descriptor that holds the ranked one. The
    // inner descriptor is allocated on stack.
    auto targetDesc = UnrankedMemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(targetType));
    targetDesc.setRank(rewriter, loc, resultRank);
    SmallVector<Value, 4> sizes;
    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                           targetDesc, addressSpace, sizes);
    Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
        loc, getVoidPtrType(), IntegerType::get(getContext(), 8),
        sizes.front());
    targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);

    // Extract pointers and offset from the source memref.
    Value allocatedPtr, alignedPtr, offset;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             reshapeOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr, &offset);

    // Set pointers and offset.
    auto elementPtrType =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);

    UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
                                              elementPtrType, allocatedPtr);
    UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
                                            underlyingDescPtr, elementPtrType,
                                            alignedPtr);
    UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
                                        underlyingDescPtr, elementPtrType,
                                        offset);

    // Use the offset pointer as base for further addressing. Copy over the new
    // shape and compute strides. For this, we create a loop from rank-1 to 0.
    Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
        rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrType);
    Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
        rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
    Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
    Value oneIndex = createIndexAttrConstant(rewriter, loc, getIndexType(), 1);
    Value resultRankMinusOne =
        rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);

    Block *initBlock = rewriter.getInsertionBlock();
    Type indexType = getTypeConverter()->getIndexType();
    Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());

    Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
                                            {indexType, indexType}, {loc, loc});

    // Move the remaining initBlock ops to condBlock.
    Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
    rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());

    rewriter.setInsertionPointToEnd(initBlock);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
                                condBlock);
    rewriter.setInsertionPointToStart(condBlock);
    Value indexArg = condBlock->getArgument(0);
    Value strideArg = condBlock->getArgument(1);

    Value zeroIndex = createIndexAttrConstant(rewriter, loc, indexType, 0);
    Value pred = rewriter.create<LLVM::ICmpOp>(
        loc, IntegerType::get(rewriter.getContext(), 1),
        LLVM::ICmpPredicate::sge, indexArg, zeroIndex);

    Block *bodyBlock =
        rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
    rewriter.setInsertionPointToStart(bodyBlock);

    // Copy size from shape to descriptor.
    auto llvmIndexPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
    Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
        loc, llvmIndexPtrType,
        typeConverter->convertType(shapeMemRefType.getElementType()),
        shapeOperandPtr, indexArg);
    Value size = rewriter.create<LLVM::LoadOp>(loc, indexType, sizeLoadGep);
    UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
                                      targetSizesBase, indexArg, size);

    // Write stride value and compute next one.
    UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
                                        targetStridesBase, indexArg, strideArg);
    Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);

    // Decrement loop counter and branch back.
    Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
                                condBlock);

    Block *remainder =
        rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());

    // Hook up the cond exit to the remainder.
    rewriter.setInsertionPointToEnd(condBlock);
    rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, std::nullopt,
                                    remainder, std::nullopt);

    // Reset position to beginning of new remainder block.
    rewriter.setInsertionPointToStart(remainder);

    *descriptor = targetDesc;
    return success();
  }
};
/// ReassociatingReshapeOp must be expanded before we reach this stage.
/// Report that information.
template <typename ReshapeOp>
class ReassociatingReshapeOpConversion
    : public ConvertOpToLLVMPattern<ReshapeOp> {
public:
  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;

  LogicalResult
  matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    return rewriter.notifyMatchFailure(
        reshapeOp,
        "reassociation operations should have been expanded beforehand");
  }
};
/// Subviews must be expanded before we reach this stage.
/// Report that information.
struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
  using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    return rewriter.notifyMatchFailure(
        subViewOp, "subview operations should have been expanded beforehand");
  }
};
/// Conversion pattern that transforms a transpose op into:
///   1. A function entry `alloca` operation to allocate a ViewDescriptor.
///   2. A load of the ViewDescriptor from the pointer allocated in 1.
///   3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
///      and stride. Size and stride are permutations of the original values.
///   4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
/// The transpose op is replaced by the alloca'ed pointer.
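///
/// For example (sketch): transposing with permutation `(d0, d1) -> (d1, d0)`
/// swaps sizes[0]/sizes[1] and strides[0]/strides[1] in the result descriptor,
/// while the allocated/aligned pointers and the offset are copied unchanged.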
class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
public:
  using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = transposeOp.getLoc();
    MemRefDescriptor viewMemRef(adaptor.getIn());

    // No permutation, early exit.
    if (transposeOp.getPermutation().isIdentity())
      return rewriter.replaceOp(transposeOp, {viewMemRef}), success();

    auto targetMemRef = MemRefDescriptor::undef(
        rewriter, loc,
        typeConverter->convertType(transposeOp.getIn().getType()));

    // Copy the base and aligned pointers from the old descriptor to the new
    // one.
    targetMemRef.setAllocatedPtr(rewriter, loc,
                                 viewMemRef.allocatedPtr(rewriter, loc));
    targetMemRef.setAlignedPtr(rewriter, loc,
                               viewMemRef.alignedPtr(rewriter, loc));

    // Copy the offset pointer from the old descriptor to the new one.
    targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));

    // Iterate over the dimensions and apply size/stride permutation:
    // When enumerating the results of the permutation map, the enumeration
    // index is the index into the target dimensions and the DimExpr points to
    // the dimension of the source memref.
    for (const auto &en :
         llvm::enumerate(transposeOp.getPermutation().getResults())) {
      int targetPos = en.index();
      int sourcePos = cast<AffineDimExpr>(en.value()).getPosition();
      targetMemRef.setSize(rewriter, loc, targetPos,
                           viewMemRef.size(rewriter, loc, sourcePos));
      targetMemRef.setStride(rewriter, loc, targetPos,
                             viewMemRef.stride(rewriter, loc, sourcePos));
    }

    rewriter.replaceOp(transposeOp, {targetMemRef});
    return success();
  }
};
/// Conversion pattern that transforms an op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor.
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The view op is replaced by the descriptor.
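///
/// For illustration (sketch): `memref.view %src[%off][] : memref<2048xi8> to
/// memref<64x4xf32>` advances the aligned pointer by `%off` bytes, zeroes the
/// offset, and fills in the static sizes/strides of the result type.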
1409 struct ViewOpLowering
: public ConvertOpToLLVMPattern
<memref::ViewOp
> {
1410 using ConvertOpToLLVMPattern
<memref::ViewOp
>::ConvertOpToLLVMPattern
;
1412 // Build and return the value for the idx^th shape dimension, either by
1413 // returning the constant shape dimension or counting the proper dynamic size.
1414 Value
getSize(ConversionPatternRewriter
&rewriter
, Location loc
,
1415 ArrayRef
<int64_t> shape
, ValueRange dynamicSizes
, unsigned idx
,
1416 Type indexType
) const {
1417 assert(idx
< shape
.size());
1418 if (!ShapedType::isDynamic(shape
[idx
]))
1419 return createIndexAttrConstant(rewriter
, loc
, indexType
, shape
[idx
]);
1420 // Count the number of dynamic dims in range [0, idx]
1422 llvm::count_if(shape
.take_front(idx
), ShapedType::isDynamic
);
1423 return dynamicSizes
[nDynamic
];
  // Build and return the idx^th stride, either by returning the constant
  // stride or by computing the dynamic stride from the current `runningStride`
  // and `nextSize`. The caller should keep a running stride and update it with
  // the result returned by this function.
  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
                  ArrayRef<int64_t> strides, Value nextSize,
                  Value runningStride, unsigned idx, Type indexType) const {
    assert(idx < strides.size());
    if (!ShapedType::isDynamic(strides[idx]))
      return createIndexAttrConstant(rewriter, loc, indexType, strides[idx]);
    if (nextSize)
      return runningStride
                 ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
                 : nextSize;
    assert(!runningStride);
    return createIndexAttrConstant(rewriter, loc, indexType, 1);
  }
  LogicalResult
  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = viewOp.getLoc();

    auto viewMemRefType = viewOp.getType();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
             failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
    if (failed(successStrides))
      return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
    assert(offset == 0 && "expected offset to be 0");

    // Target memref must be contiguous in memory (innermost stride is 1), or
    // empty (special case when at least one of the memref dimensions is 0).
    if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
      return viewOp.emitWarning("cannot cast to non-contiguous shape"),
             failure();

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.getSource());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Field 1: Copy the allocated pointer, used for malloc/free.
    Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
    auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType());
    targetMemRef.setAllocatedPtr(rewriter, loc, allocatedPtr);

    // Field 2: Copy the actual aligned pointer to payload.
    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
    alignedPtr = rewriter.create<LLVM::GEPOp>(
        loc, alignedPtr.getType(),
        typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr,
        adaptor.getByteShift());

    targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr);

    Type indexType = getIndexType();
    // Field 3: The offset in the resulting type must be 0. This is because of
    // the type change: an offset on srcType* may not be expressible as an
    // offset on dstType*.
    targetMemRef.setOffset(
        rewriter, loc,
        createIndexAttrConstant(rewriter, loc, indexType, offset));

    // Early exit for 0-D corner case.
    if (viewMemRefType.getRank() == 0)
      return rewriter.replaceOp(viewOp, {targetMemRef}), success();

    // Fields 4 and 5: Update sizes and strides.
    Value stride = nullptr, nextSize = nullptr;
    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
      // Update size.
      Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
                           adaptor.getSizes(), i, indexType);
      targetMemRef.setSize(rewriter, loc, i, size);
      // Update stride.
      stride =
          getStride(rewriter, loc, strides, nextSize, stride, i, indexType);
      targetMemRef.setStride(rewriter, loc, i, stride);
      nextSize = size;
    }

    rewriter.replaceOp(viewOp, {targetMemRef});
    return success();
  }
};
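// Illustrative sketch of the effect of ViewOpLowering (schematic IR, not
// taken from a test): for
//   %v = memref.view %src[%byteShift][%sz0, %sz1]
//       : memref<?xi8> to memref<?x?xf32>
// the pattern builds a fresh descriptor whose aligned pointer is the source
// aligned pointer advanced by %byteShift bytes (the GEP above), whose offset
// is the constant 0, and whose sizes/strides are filled right-to-left with a
// running stride: the innermost stride is 1 and each outer stride is the
// product of the inner stride and the inner size.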
//===----------------------------------------------------------------------===//
// AtomicRMWOpLowering
//===----------------------------------------------------------------------===//

/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
/// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
static std::optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
  switch (atomicOp.getKind()) {
  case arith::AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case arith::AtomicRMWKind::addi:
    return LLVM::AtomicBinOp::add;
  case arith::AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case arith::AtomicRMWKind::maximumf:
    return LLVM::AtomicBinOp::fmax;
  case arith::AtomicRMWKind::maxs:
    return LLVM::AtomicBinOp::max;
  case arith::AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case arith::AtomicRMWKind::minimumf:
    return LLVM::AtomicBinOp::fmin;
  case arith::AtomicRMWKind::mins:
    return LLVM::AtomicBinOp::min;
  case arith::AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  case arith::AtomicRMWKind::ori:
    return LLVM::AtomicBinOp::_or;
  case arith::AtomicRMWKind::andi:
    return LLVM::AtomicBinOp::_and;
  default:
    return std::nullopt;
  }
  llvm_unreachable("Invalid AtomicRMWKind");
}
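// Illustrative sketch (schematic IR, not taken from a test): with the mapping
// above, a supported op such as
//   %old = memref.atomic_rmw addf %val, %m[%i] : (f32, memref<8xf32>) -> f32
// lowers to an llvm.atomicrmw fadd on the strided element pointer with
// acq_rel ordering (see AtomicRMWOpLowering below); kinds that map to
// std::nullopt make that pattern bail out with failure().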
struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto maybeKind = matchSimpleAtomicOp(atomicOp);
    if (!maybeKind)
      return failure();
    auto memRefType = atomicOp.getMemRefType();
    SmallVector<int64_t> strides;
    int64_t offset;
    if (failed(getStridesAndOffset(memRefType, strides, offset)))
      return failure();
    Value dataPtr =
        getStridedElementPtr(atomicOp.getLoc(), memRefType,
                             adaptor.getMemref(), adaptor.getIndices(),
                             rewriter);
    rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
        atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
        LLVM::AtomicOrdering::acq_rel);
    return success();
  }
};
/// Unpack the pointer returned by a memref.extract_aligned_pointer_as_index.
class ConvertExtractAlignedPointerAsIndex
    : public ConvertOpToLLVMPattern<memref::ExtractAlignedPointerAsIndexOp> {
public:
  using ConvertOpToLLVMPattern<
      memref::ExtractAlignedPointerAsIndexOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    BaseMemRefType sourceTy = extractOp.getSource().getType();

    Value alignedPtr;
    if (sourceTy.hasRank()) {
      MemRefDescriptor desc(adaptor.getSource());
      alignedPtr = desc.alignedPtr(rewriter, extractOp->getLoc());
    } else {
      auto elementPtrTy = LLVM::LLVMPointerType::get(
          rewriter.getContext(), sourceTy.getMemorySpaceAsInt());

      UnrankedMemRefDescriptor desc(adaptor.getSource());
      Value descPtr = desc.memRefDescPtr(rewriter, extractOp->getLoc());

      alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
          rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr,
          elementPtrTy);
    }

    rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(
        extractOp, getTypeConverter()->getIndexType(), alignedPtr);
    return success();
  }
};
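// Illustrative sketch (schematic, not taken from a test): the net effect of
// the pattern above is
//   %p = memref.extract_aligned_pointer_as_index %m : memref<?xf32> -> index
//   ==> %p = llvm.ptrtoint <aligned pointer of %m's descriptor> to <index type>
// where the aligned pointer is read either from the ranked descriptor or,
// for unranked sources, through the descriptor pointer.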
/// Materialize the MemRef descriptor represented by the results of
/// ExtractStridedMetadataOp.
class ExtractStridedMetadataOpLowering
    : public ConvertOpToLLVMPattern<memref::ExtractStridedMetadataOp> {
public:
  using ConvertOpToLLVMPattern<
      memref::ExtractStridedMetadataOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
      return failure();

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.getSource());
    Location loc = extractStridedMetadataOp.getLoc();
    Value source = extractStridedMetadataOp.getSource();

    auto sourceMemRefType = cast<MemRefType>(source.getType());
    int64_t rank = sourceMemRefType.getRank();
    SmallVector<Value> results;
    results.reserve(2 + rank * 2);

    // Base buffer.
    Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc);
    Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc);
    MemRefDescriptor dstMemRef = MemRefDescriptor::fromStaticShape(
        rewriter, loc, *getTypeConverter(),
        cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()),
        baseBuffer, alignedBuffer);
    results.push_back((Value)dstMemRef);

    // Offset.
    results.push_back(sourceMemRef.offset(rewriter, loc));

    // Sizes.
    for (unsigned i = 0; i < rank; ++i)
      results.push_back(sourceMemRef.size(rewriter, loc, i));
    // Strides.
    for (unsigned i = 0; i < rank; ++i)
      results.push_back(sourceMemRef.stride(rewriter, loc, i));

    rewriter.replaceOp(extractStridedMetadataOp, results);
    return success();
  }
};
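// Illustrative note on the result ordering produced above: for a rank-2
// source the replacement values are
//   [base-buffer descriptor, offset, size0, size1, stride0, stride1]
// i.e. 2 + 2 * rank values, which is exactly what results.reserve() accounts
// for.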
void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<
      AllocaScopeOpLowering,
      AtomicRMWOpLowering,
      AssumeAlignmentOpLowering,
      ConvertExtractAlignedPointerAsIndex,
      ExtractStridedMetadataOpLowering,
      GenericAtomicRMWOpLowering,
      GlobalMemrefOpLowering,
      GetGlobalMemrefOpLowering,
      MemRefCastOpLowering,
      MemRefCopyOpLowering,
      MemorySpaceCastOpLowering,
      MemRefReinterpretCastOpLowering,
      MemRefReshapeOpLowering,
      ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
      ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
      TransposeOpLowering,
      ViewOpLowering>(converter);

  auto allocLowering = converter.getOptions().allocLowering;
  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
    patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
    patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
}
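// Illustrative usage sketch (assumes an existing MLIRContext `ctx` and a root
// operation `op`; this mirrors what the pass below does):
//
//   LowerToLLVMOptions options(&ctx);
//   LLVMTypeConverter typeConverter(&ctx, options);
//   RewritePatternSet patterns(&ctx);
//   populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
//   LLVMConversionTarget target(ctx);
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     /* handle conversion failure */;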
struct FinalizeMemRefToLLVMConversionPass
    : public impl::FinalizeMemRefToLLVMConversionPassBase<
          FinalizeMemRefToLLVMConversionPass> {
  using FinalizeMemRefToLLVMConversionPassBase::
      FinalizeMemRefToLLVMConversionPassBase;

  void runOnOperation() override {
    Operation *op = getOperation();
    const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
    LowerToLLVMOptions options(&getContext(),
                               dataLayoutAnalysis.getAtOrAbove(op));
    options.allocLowering =
        (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
                         : LowerToLLVMOptions::AllocLowering::Malloc);

    options.useGenericFunctions = useGenericFunctions;

    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    LLVMTypeConverter typeConverter(&getContext(), options,
                                    &dataLayoutAnalysis);
    RewritePatternSet patterns(&getContext());
    populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalOp<func::FuncOp>();
    if (failed(applyPartialConversion(op, target, std::move(patterns))))
      signalPassFailure();
  }
};
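// Note: the pass above is generated from Passes.td via
// GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS and is typically driven
// from mlir-opt through the corresponding command-line option (assumed here
// to be --finalize-memref-to-llvm, matching the upstream pass name).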
/// Implement the interface to convert MemRef to LLVM.
struct MemRefToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
  void loadDependentDialects(MLIRContext *context) const final {
    context->loadDialect<LLVM::LLVMDialect>();
  }

  /// Hook for derived dialect interface to provide conversion patterns
  /// and mark dialect legal for the conversion target.
  void populateConvertToLLVMConversionPatterns(
      ConversionTarget &target, LLVMTypeConverter &typeConverter,
      RewritePatternSet &patterns) const final {
    populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
  }
};
void mlir::registerConvertMemRefToLLVMInterface(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
    dialect->addInterfaces<MemRefToLLVMDialectInterface>();
  });
}