//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"

#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/Support/MathExtras.h"

#include <optional>

namespace mlir {
#define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
  return !ShapedType::isDynamic(strideOrOffset);
}

static LLVM::LLVMFuncOp getFreeFn(const LLVMTypeConverter *typeConverter,
                                  ModuleOp module) {
  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;

  if (useGenericFn)
    return LLVM::lookupOrCreateGenericFreeFn(module);

  return LLVM::lookupOrCreateFreeFn(module);
}

struct AllocOpLowering : public AllocLikeOpLLVMLowering {
  AllocOpLowering(const LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    return allocateBufferManuallyAlign(
        rewriter, loc, sizeBytes, op,
        getAlignment(rewriter, loc, cast<memref::AllocOp>(op)));
  }
};

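// Illustrative sketch for AllocOpLowering above (the exact IR depends on the
// index bitwidth and on whether generic allocation functions are requested):
// `%0 = memref.alloc() : memref<4xf32>` becomes, roughly,
//
//   %sz  = llvm.mlir.constant(16 : index) : i64
//   %ptr = llvm.call @malloc(%sz) : (i64) -> !llvm.ptr
//
// followed by inserting %ptr and the static sizes/strides into the memref
// descriptor struct.
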
struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
  AlignedAllocOpLowering(const LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    Value ptr = allocateBufferAutoAlign(
        rewriter, loc, sizeBytes, op, &defaultLayout,
        alignedAllocationGetAlignment(rewriter, loc, cast<memref::AllocOp>(op),
                                      &defaultLayout));
    if (!ptr)
      return std::make_tuple(Value(), Value());
    return std::make_tuple(ptr, ptr);
  }

private:
  /// Default layout to use in absence of the corresponding analysis.
  DataLayout defaultLayout;
};

struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
  AllocaOpLowering(const LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
                                converter) {
    setRequiresNumElements();
  }

  /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
  /// is set to null for stack allocations. `accessAlignment` is set if
  /// alignment is needed post allocation (for eg. in conjunction with malloc).
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value size,
                                          Operation *op) const override {
    // With alloca, one gets a pointer to the element type right away.
    // For stack allocations.
    auto allocaOp = cast<memref::AllocaOp>(op);
    auto elementType =
        typeConverter->convertType(allocaOp.getType().getElementType());
    unsigned addrSpace =
        *getTypeConverter()->getMemRefAddressSpace(allocaOp.getType());
    auto elementPtrType =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace);

    auto allocatedElementPtr =
        rewriter.create<LLVM::AllocaOp>(loc, elementPtrType, elementType, size,
                                        allocaOp.getAlignment().value_or(0));

    return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
  }
};

struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(
          remainingOpsBlock, allocaScopeOp.getResultTypes(),
          SmallVector<Location>(allocaScopeOp->getNumResults(),
                                allocaScopeOp.getLoc()));
      rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
    }

    // Inline body region.
    Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
    Block *afterBody = &allocaScopeOp.getBodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);

    // Save stack and then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
    rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the body.
    // Stack restore before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.getResults(), continueBlock);

    // Insert stack restore before jumping out the body of the region.
    rewriter.setInsertionPoint(branchOp);
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    // Replace the op with values return from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};

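// Sketch of the control flow produced by AllocaScopeOpLowering above for a
// result-less scope (block names are illustrative):
//
//   ^current:
//     %ss = llvm.intr.stacksave : !llvm.ptr
//     llvm.br ^body
//   ^body:
//     ... inlined alloca_scope body ...
//     llvm.intr.stackrestore %ss : !llvm.ptr
//     llvm.br ^continue
//   ^continue:
//     ... code following the alloca_scope ...
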
struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
  explicit AssumeAlignmentOpLowering(const LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::AssumeAlignmentOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value memref = adaptor.getMemref();
    unsigned alignment = op.getAlignment();
    auto loc = op.getLoc();

    auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
    Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, /*indices=*/{},
                                     rewriter);

    // Emit llvm.assume(true) ["align"(memref, alignment)].
    // This is more direct than ptrtoint-based checks, is explicitly supported,
    // and works with non-integral address spaces.
    Value trueCond =
        rewriter.create<LLVM::ConstantOp>(loc, rewriter.getBoolAttr(true));
    Value alignmentConst =
        createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
    rewriter.create<LLVM::AssumeOp>(loc, trueCond, LLVM::AssumeAlignTag(), ptr,
                                    alignmentConst);

    rewriter.eraseOp(op);
    return success();
  }
};

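// Illustrative sketch for AssumeAlignmentOpLowering above:
// `memref.assume_alignment %m, 16 : memref<?xf32>` computes the strided
// element pointer %p of %m and emits the equivalent of the LLVM IR
//
//   call void @llvm.assume(i1 true) [ "align"(ptr %p, i64 16) ]
//
// i.e. an assumption carrying an "align" operand bundle rather than a
// ptrtoint-based alignment check.
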
// A `dealloc` is converted into a call to `free` on the underlying data buffer.
// The memref descriptor being an SSA value, there is no need to clean it up
// in any way.
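//
// For example (sketch; the struct layout follows the descriptor conventions of
// the type converter): `memref.dealloc %buf : memref<16xf32>` becomes roughly
//
//   %ptr = llvm.extractvalue %desc[0]
//        : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
//   llvm.call @free(%ptr) : (!llvm.ptr) -> ()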
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;

  explicit DeallocOpLowering(const LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Insert the `free` declaration if it is not already present.
    LLVM::LLVMFuncOp freeFunc =
        getFreeFn(getTypeConverter(), op->getParentOfType<ModuleOp>());
    Value allocatedPtr;
    if (auto unrankedTy =
            llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
      auto elementPtrTy = LLVM::LLVMPointerType::get(
          rewriter.getContext(), unrankedTy.getMemorySpaceAsInt());
      allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
          rewriter, op.getLoc(),
          UnrankedMemRefDescriptor(adaptor.getMemref())
              .memRefDescPtr(rewriter, op.getLoc()),
          elementPtrTy);
    } else {
      allocatedPtr = MemRefDescriptor(adaptor.getMemref())
                         .allocatedPtr(rewriter, op.getLoc());
    }
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc, allocatedPtr);
    return success();
  }
};

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
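//
// For example (sketch): given `memref.dim %m, %c : memref<4x?xf32>`, dimension
// 0 folds to `llvm.mlir.constant(4 : index)`, while dimension 1 is read from
// the descriptor, roughly `llvm.extractvalue %desc[3, 1]`.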
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.getSource().getType();
    if (isa<UnrankedMemRefType>(operandType)) {
      FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
          operandType, dimOp, adaptor.getOperands(), rewriter);
      if (failed(extractedSize))
        return failure();
      rewriter.replaceOp(dimOp, {*extractedSize});
      return success();
    }
    if (isa<MemRefType>(operandType)) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
                                            adaptor.getOperands(), rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  FailureOr<Value>
  extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                              OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    auto unrankedMemRefType = cast<UnrankedMemRefType>(operandType);
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
    if (failed(maybeAddressSpace)) {
      dimOp.emitOpError("memref memory space must be convertible to an integer "
                        "address space");
      return failure();
    }
    unsigned addressSpace = *maybeAddressSpace;

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // memref<element_type> descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);

    Type elementType = typeConverter->convertType(scalarMemRefType);

    // Get pointer to offset field of memref<element_type> descriptor.
    auto indexPtrTy =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
    Value offsetPtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, elementType, underlyingRankedDesc,
        ArrayRef<LLVM::GEPArg>{0, 2});

    // The size value that we have to extract can be obtained using GEPop with
    // `dimOp.index() + 1` index argument.
    Value idxPlusOne = rewriter.create<LLVM::AddOp>(
        loc, createIndexAttrConstant(rewriter, loc, getIndexType(), 1),
        adaptor.getIndex());
    Value sizePtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr,
        idxPlusOne);
    return rewriter
        .create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr)
        .getResult();
  }

  std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (auto idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
      return cast<IntegerAttr>(constantOp.getValue()).getValue().getSExtValue();

    return std::nullopt;
  }

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    // Take advantage if index is constant.
    MemRefType memRefType = cast<MemRefType>(operandType);
    Type indexType = getIndexType();
    if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = *index;
      if (i >= 0 && i < memRefType.getRank()) {
        if (memRefType.isDynamicDim(i)) {
          // extract dynamic size from the memref descriptor.
          MemRefDescriptor descriptor(adaptor.getSource());
          return descriptor.size(rewriter, loc, i);
        }
        // Use constant for static size.
        int64_t dimSize = memRefType.getDimSize(i);
        return createIndexAttrConstant(rewriter, loc, indexType, dimSize);
      }
    }
    Value index = adaptor.getIndex();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(adaptor.getSource());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};

/// Common base for load and store operations on MemRefs. Restricts the match
/// to supported MemRef types. Provides functionality to emit code accessing a
/// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;

  LogicalResult match(Derived op) const override {
    MemRefType type = op.getMemRefType();
    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
  }
};

/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   cf.br loop(%loaded)           |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +--------------------------------+
///  |   | loop(%loaded):                 |
///  |   |   <body contents>              |
///  |   |   %pair = cmpxchg              |
///  |   |   %ok = %pair[0]               |
///  |   |   %new = %pair[1]              |
///  |   |   cf.cond_br %ok, end, loop(%new) |
///  |   +--------------------------------+
///  |          |        |
///  |----------|        |
///                      v
///      +--------------------------------+
///      | end:                           |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType = typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock = rewriter.splitBlock(initBlock, Block::iterator(atomicOp));
    loopBlock->addArgument(valueType, loc);
    auto *endBlock =
        rewriter.splitBlock(loopBlock, Block::iterator(atomicOp)++);

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
                                        adaptor.getIndices(), rewriter);
    Value init = rewriter.create<LLVM::LoadOp>(
        loc, typeConverter->convertType(memRefType.getElementType()), dataPtr);
    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    auto loopArgument = loopBlock->getArgument(0);
    IRMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, dataPtr, loopArgument, result, successOrdering, failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 0);
    Value ok = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 1);

    // Conditionally branch to the end or back to the loop depending on %ok.
    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                    loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
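///
/// For instance, `memref<4x8xf32>` maps to `!llvm.array<4 x array<8 x f32>>`,
/// and a rank-0 `memref<f32>` maps to plain `f32`.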
static Type
convertGlobalMemrefTypeToLLVM(MemRefType type,
                              const LLVMTypeConverter &typeConverter) {
  // LLVM type for a global memref will be a multi-dimension array. For
  // declarations or uninitialized global memrefs, we can potentially flatten
  // this to a 1D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimension array.
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // Shape has the outermost dim at index 0, so need to walk it backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}

/// GlobalMemrefOp is lowered to a LLVM Global Variable.
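//
// For example (sketch; the exact printed form may differ):
//
//   memref.global "private" constant @g : memref<2xf32> = dense<[1.0, 2.0]>
//
// becomes roughly
//
//   llvm.mlir.global private constant @g(dense<[1.0, 2.0]> : tensor<2xf32>)
//       : !llvm.array<2 x f32>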
struct GlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.getType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    Attribute initialValue = nullptr;
    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element type,
      // so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getSplatValue<Attribute>();
    }

    uint64_t alignment = global.getAlignment().value_or(0);
    FailureOr<unsigned> addressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    if (failed(addressSpace))
      return global.emitOpError(
          "memory space cannot be converted to an integer address space");
    auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.getConstant(), linkage, global.getSymName(),
        initialValue, alignment, *addressSpace);
    if (!global.isExternal() && global.isUninitialized()) {
      rewriter.createBlock(&newGlobal.getInitializerRegion());
      Value undef[] = {
          rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
      rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
    }
    return success();
  }
};

/// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to
/// the first element stashed into the descriptor. This reuses
/// `AllocLikeOpLowering` to reuse the Memref descriptor construction.
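//
// For example (sketch; GEP syntax abbreviated):
// `%0 = memref.get_global @g : memref<2xf32>` becomes roughly
//
//   %addr = llvm.mlir.addressof @g : !llvm.ptr
//   %base = llvm.getelementptr %addr[0, 0]
//           : (!llvm.ptr) -> !llvm.ptr, !llvm.array<2 x f32>
//
// and %base is used as the aligned pointer of the resulting descriptor.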
struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
  GetGlobalMemrefOpLowering(const LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
                                converter) {}

  /// Buffer "allocation" for memref.get_global op is getting the address of
  /// the global variable referenced.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    auto getGlobalOp = cast<memref::GetGlobalOp>(op);
    MemRefType type = cast<MemRefType>(getGlobalOp.getResult().getType());

    // This is called after a type conversion, which would have failed if this
    // call fails.
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    if (failed(maybeAddressSpace))
      return std::make_tuple(Value(), Value());
    unsigned memSpace = *maybeAddressSpace;

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace);
    auto addressOf =
        rewriter.create<LLVM::AddressOfOp>(loc, ptrTy, getGlobalOp.getName());

    // Get the address of the first element in the array by creating a GEP with
    // the address of the GV as the base, and (rank + 1) number of 0 indices.
    auto gep = rewriter.create<LLVM::GEPOp>(
        loc, ptrTy, arrayTy, addressOf,
        SmallVector<LLVM::GEPArg>(type.getRank() + 1, 0));

    // We do not expect the memref obtained using `memref.get_global` to be
    // ever deallocated. Set the allocated pointer to be known bad value to
    // help debug if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
    auto deadBeefPtr =
        rewriter.create<LLVM::IntToPtrOp>(loc, ptrTy, deadBeefConst);

    // Both allocated and aligned pointers are same. We could potentially stash
    // a nullptr for the allocated pointer since we do not expect any dealloc.
    return std::make_tuple(deadBeefPtr, gep);
  }
};

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
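//
// For example (sketch): `%v = memref.load %m[%i] : memref<?xf32>` becomes
// roughly
//
//   %p = llvm.getelementptr %aligned[%idx] : (!llvm.ptr, i64) -> !llvm.ptr, f32
//   %v = llvm.load %p : !llvm.ptr -> f32
//
// where %aligned and the linearized index are derived from the descriptor.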
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();

    Value dataPtr =
        getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
                             adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
        loadOp, typeConverter->convertType(type.getElementType()), dataPtr, 0,
        false, loadOp.getNontemporal());
    return success();
  }
};

// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();

    Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
                                         adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
                                               0, false, op.getNontemporal());
    return success();
  }
};

// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
                                         adaptor.getIndices(), rewriter);

    // Replace with llvm.prefetch.
    IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite());
    IntegerAttr localityHint = prefetchOp.getLocalityHintAttr();
    IntegerAttr isData =
        rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache());
    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};

struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
  using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type operandType = op.getMemref().getType();
    if (dyn_cast<UnrankedMemRefType>(operandType)) {
      UnrankedMemRefDescriptor desc(adaptor.getMemref());
      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
      return success();
    }
    if (auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) {
      Type indexType = getIndexType();
      rewriter.replaceOp(op,
                         {createIndexAttrConstant(rewriter, loc, indexType,
                                                  rankedMemRefType.getRank())});
      return success();
    }
    return failure();
  }
};

struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult match(memref::CastOp memRefCastOp) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduce to bitcast in the ranked MemRef case and can be
    // used for type erasure. For now they must preserve underlying element type
    // and require source and result type to have the same rank. Therefore,
    // perform a sanity check that the underlying structs are the same. Once op
    // semantics are relaxed we can revisit.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
      return success(typeConverter->convertType(srcType) ==
                     typeConverter->convertType(dstType));

    // At least one of the operands is unranked type.
    assert(isa<UnrankedMemRefType>(srcType) ||
           isa<UnrankedMemRefType>(dstType));

    // Unranked to unranked cast is disallowed.
    return !(isa<UnrankedMemRefType>(srcType) &&
             isa<UnrankedMemRefType>(dstType))
               ? success()
               : failure();
  }

  void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
               ConversionPatternRewriter &rewriter) const override {
    auto srcType = memRefCastOp.getOperand().getType();
    auto dstType = memRefCastOp.getType();
    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For ranked/ranked case, just keep the original descriptor.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
      return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});

    if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
      // Casting ranked to unranked memref type:
      // Set the rank in the destination from the memref type.
      // Allocate space on the stack and copy the src memref descriptor.
      // Set the ptr in the destination to the stack space.
      auto srcMemRefType = cast<MemRefType>(srcType);
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, adaptor.getSource(), rewriter);

      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(rank));
      // undef = UndefOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
      // d1 = InsertValueOp undef, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, ptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
      // Casting from unranked type to ranked.
      // The operation is assumed to be doing a correct cast. If the destination
      // type mismatches the unranked type, it is undefined behavior.
      UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);

      // struct = LoadOp ptr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, targetStructType, ptr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }
  }
};

/// Pattern to lower a `memref.copy` to llvm.
///
/// For memrefs with identity layouts, the copy is lowered to the llvm
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
/// to the generic `MemrefCopyFn`.
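///
/// For example (sketch): copying between two `memref<4x4xf32>` values with
/// identity layout becomes a single `llvm.intr.memcpy` of 64 bytes, whereas a
/// copy involving a non-contiguous (e.g. strided) layout goes through the
/// `memrefCopy` runtime function instead.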
struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = dyn_cast<MemRefType>(op.getSource().getType());

    MemRefDescriptor srcDesc(adaptor.getSource());

    // Compute number of elements.
    Value numElements = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(1));
    for (int pos = 0; pos < srcType.getRank(); ++pos) {
      auto size = srcDesc.size(rewriter, loc, pos);
      numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
    }

    // Get element size.
    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    // Compute total.
    Value totalSize =
        rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);

    Type elementType = typeConverter->convertType(srcType.getElementType());

    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
    Value srcOffset = srcDesc.offset(rewriter, loc);
    Value srcPtr = rewriter.create<LLVM::GEPOp>(
        loc, srcBasePtr.getType(), elementType, srcBasePtr, srcOffset);
    MemRefDescriptor targetDesc(adaptor.getTarget());
    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
    Value targetOffset = targetDesc.offset(rewriter, loc);
    Value targetPtr = rewriter.create<LLVM::GEPOp>(
        loc, targetBasePtr.getType(), elementType, targetBasePtr, targetOffset);
    rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
                                    /*isVolatile=*/false);
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, MemRefType type) {
      auto rank = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                    type.getRank());
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);

      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(
          rewriter, loc, *typeConverter, unrankedType, ValueRange{rank, ptr});
    };

    // Save stack position before promoting descriptors.
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());

    auto srcMemRefType = dyn_cast<MemRefType>(srcType);
    Value unrankedSource =
        srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
                      : adaptor.getSource();
    auto targetMemRefType = dyn_cast<MemRefType>(targetType);
    Value unrankedTarget =
        targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
                         : adaptor.getTarget();

    // Now promote the unranked descriptors to the stack.
    auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
      auto allocated =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, desc.getType(), one);
      rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    // Derive size from llvm.getelementptr which will account for any
    // potential alignment.
    auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
    rewriter.create<LLVM::CallOp>(loc, copyFn,
                                  ValueRange{elemSize, sourcePtr, targetPtr});

    // Restore stack used for descriptors.
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    auto isContiguousMemrefType = [&](BaseMemRefType type) {
      auto memrefType = dyn_cast<mlir::MemRefType>(type);
      // We can use memcpy for memrefs if they have an identity layout or are
      // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
      // special case handled by memrefCopy.
      return memrefType &&
             (memrefType.getLayout().isIdentity() ||
              (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
               memref::isStaticShapeAndContiguousRowMajor(memrefType)));
    };

    if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);

    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
  }
};

struct MemorySpaceCastOpLowering
    : public ConvertOpToLLVMPattern<memref::MemorySpaceCastOp> {
  using ConvertOpToLLVMPattern<
      memref::MemorySpaceCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    Type resultType = op.getDest().getType();
    if (auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
      auto resultDescType =
          cast<LLVM::LLVMStructType>(typeConverter->convertType(resultTypeR));
      Type newPtrType = resultDescType.getBody()[0];

      SmallVector<Value> descVals;
      MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR,
                               descVals);
      descVals[0] =
          rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[0]);
      descVals[1] =
          rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[1]);
      Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
                                            resultTypeR, descVals);
      rewriter.replaceOp(op, result);
      return success();
    }
    if (auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
      // Since the type converter won't be doing this for us, get the address
      // spaces.
      auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType());
      FailureOr<unsigned> maybeSourceAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(sourceType);
      if (failed(maybeSourceAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer source address space");
      unsigned sourceAddrSpace = *maybeSourceAddrSpace;
      FailureOr<unsigned> maybeResultAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(resultTypeU);
      if (failed(maybeResultAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer result address space");
      unsigned resultAddrSpace = *maybeResultAddrSpace;

      UnrankedMemRefDescriptor sourceDesc(adaptor.getSource());
      Value rank = sourceDesc.rank(rewriter, loc);
      Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);

      // Create and allocate storage for new memref descriptor.
      auto result = UnrankedMemRefDescriptor::undef(
          rewriter, loc, typeConverter->convertType(resultTypeU));
      result.setRank(rewriter, loc, rank);
      SmallVector<Value, 1> sizes;
      UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                             result, resultAddrSpace, sizes);
      Value resultUnderlyingSize = sizes.front();
      Value resultUnderlyingDesc = rewriter.create<LLVM::AllocaOp>(
          loc, getVoidPtrType(), rewriter.getI8Type(), resultUnderlyingSize);
      result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);

      // Copy pointers, performing address space casts.
      auto sourceElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), sourceAddrSpace);
      auto resultElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), resultAddrSpace);

      Value allocatedPtr = sourceDesc.allocatedPtr(
          rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
      Value alignedPtr =
          sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
                                sourceUnderlyingDesc, sourceElemPtrType);
      allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, resultElemPtrType, allocatedPtr);
      alignedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, resultElemPtrType, alignedPtr);

      result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
                             resultElemPtrType, allocatedPtr);
      result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
                           resultUnderlyingDesc, resultElemPtrType, alignedPtr);

      // Copy all the index-valued operands.
      Value sourceIndexVals =
          sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                                   sourceUnderlyingDesc, sourceElemPtrType);
      Value resultIndexVals =
          result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                               resultUnderlyingDesc, resultElemPtrType);

      int64_t bytesToSkip =
          2 * llvm::divideCeil(
                  getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
      Value bytesToSkipConst = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
      Value copySize = rewriter.create<LLVM::SubOp>(
          loc, getIndexType(), resultUnderlyingSize, bytesToSkipConst);
      rewriter.create<LLVM::MemcpyOp>(loc, resultIndexVals, sourceIndexVals,
                                      copySize, /*isVolatile=*/false);

      rewriter.replaceOp(op, ValueRange{result});
      return success();
    }
    return rewriter.notifyMatchFailure(loc, "unexpected memref type");
  }
};

/// Extracts allocated, aligned pointers and offset from a ranked or unranked
/// memref type. In unranked case, the fields are extracted from the underlying
/// ranked descriptor.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     const LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (isa<MemRefType>(operandType)) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  // These will all cause assert()s on unconvertible types.
  unsigned memorySpace = *typeConverter.getMemRefAddressSpace(
      cast<UnrankedMemRefType>(operandType));
  auto elementPtrType =
      LLVM::LLVMPointerType::get(rewriter.getContext(), memorySpace);

  // Extract pointer to the underlying ranked memref descriptor and cast it to
  // ElemType**.
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
  }
}

struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
  using ConvertOpToLLVMPattern<
      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = castOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  LogicalResult convertSourceMemRefToDescriptor(
      ConversionPatternRewriter &rewriter, Type srcType,
      memref::ReinterpretCastOp castOp,
      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
    MemRefType targetMemRefType =
        cast<MemRefType>(castOp.getResult().getType());
    auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
        typeConverter->convertType(targetMemRefType));
    if (!llvmTargetDescriptorTy)
      return failure();

    // Create descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

    // Set allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set sizes and strides.
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};

struct MemRefReshapeOpLowering
    : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = reshapeOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(reshapeOp, {descriptor});
    return success();
  }

private:
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, memref::ReshapeOp reshapeOp,
                                  memref::ReshapeOp::Adaptor adaptor,
                                  Value *descriptor) const {
    auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
    if (shapeMemRefType.hasStaticShape()) {
      MemRefType targetMemRefType =
          cast<MemRefType>(reshapeOp.getResult().getType());
      auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
          typeConverter->convertType(targetMemRefType));
      if (!llvmTargetDescriptorTy)
        return failure();

      // Create descriptor.
      Location loc = reshapeOp.getLoc();
      auto desc =
          MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

      // Set allocated and aligned pointers.
      Value allocatedPtr, alignedPtr;
      extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                               reshapeOp.getSource(), adaptor.getSource(),
                               &allocatedPtr, &alignedPtr);
      desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
      desc.setAlignedPtr(rewriter, loc, alignedPtr);

      // Extract the offset and strides from the type.
      int64_t offset;
      SmallVector<int64_t> strides;
      if (failed(getStridesAndOffset(targetMemRefType, strides, offset)))
        return rewriter.notifyMatchFailure(
            reshapeOp, "failed to get stride and offset exprs");

      if (!isStaticStrideOrOffset(offset))
        return rewriter.notifyMatchFailure(reshapeOp,
                                           "dynamic offset is unsupported");

      desc.setConstantOffset(rewriter, loc, offset);

      assert(targetMemRefType.getLayout().isIdentity() &&
             "Identity layout map is a precondition of a valid reshape op");

      Type indexType = getIndexType();
      Value stride = nullptr;
      int64_t targetRank = targetMemRefType.getRank();
      for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
        if (!ShapedType::isDynamic(strides[i])) {
          // The stride for this dimension is static; materialize it directly.
          stride =
              createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
        } else if (!stride) {
          // `stride` is null only in the first iteration of the loop. However,
          // since the target memref has an identity layout, we can safely set
          // the innermost stride to 1.
          stride = createIndexAttrConstant(rewriter, loc, indexType, 1);
        }

        Value dimSize;
        // If the size of this dimension is dynamic, then load it at runtime
        // from the shape operand.
        if (!targetMemRefType.isDynamicDim(i)) {
          dimSize = createIndexAttrConstant(rewriter, loc, indexType,
                                            targetMemRefType.getDimSize(i));
        } else {
          Value shapeOp = reshapeOp.getShape();
          Value index = createIndexAttrConstant(rewriter, loc, indexType, i);
          dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
          Type indexType = getIndexType();
          if (dimSize.getType() != indexType)
            dimSize = typeConverter->materializeTargetConversion(
                rewriter, loc, indexType, dimSize);
          assert(dimSize && "Invalid memref element type");
        }

        desc.setSize(rewriter, loc, i, dimSize);
        desc.setStride(rewriter, loc, i, stride);

        // Prepare the stride value for the next dimension.
        stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
      }

      *descriptor = desc;
      return success();
    }

    // The shape is a rank-1 tensor with unknown length.
    Location loc = reshapeOp.getLoc();
    MemRefDescriptor shapeDesc(adaptor.getShape());
    Value resultRank = shapeDesc.size(rewriter, loc, 0);

    // Extract address space and element type.
    auto targetType = cast<UnrankedMemRefType>(reshapeOp.getResult().getType());
    unsigned addressSpace =
        *getTypeConverter()->getMemRefAddressSpace(targetType);

    // Create the unranked memref descriptor that holds the ranked one. The
    // inner descriptor is allocated on stack.
    auto targetDesc = UnrankedMemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(targetType));
    targetDesc.setRank(rewriter, loc, resultRank);
    SmallVector<Value, 4> sizes;
    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                           targetDesc, addressSpace, sizes);
    Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
        loc, getVoidPtrType(), IntegerType::get(getContext(), 8),
        sizes.front());
    targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);

    // Extract pointers and offset from the source memref.
    Value allocatedPtr, alignedPtr, offset;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             reshapeOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr, &offset);

    // Set pointers and offset.
    auto elementPtrType =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);

    UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
                                              elementPtrType, allocatedPtr);
    UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
                                            underlyingDescPtr, elementPtrType,
                                            alignedPtr);
    UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
                                        underlyingDescPtr, elementPtrType,
                                        offset);

    // Use the offset pointer as base for further addressing. Copy over the new
    // shape and compute strides. For this, we create a loop from rank-1 to 0.
    Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
        rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrType);
    Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
        rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
    Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
    Value oneIndex = createIndexAttrConstant(rewriter, loc, getIndexType(), 1);
    Value resultRankMinusOne =
        rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);

    Block *initBlock = rewriter.getInsertionBlock();
    Type indexType = getTypeConverter()->getIndexType();
    Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());

    Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
                                            {indexType, indexType}, {loc, loc});

    // Move the remaining initBlock ops to condBlock.
    Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
    rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());

    rewriter.setInsertionPointToEnd(initBlock);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
                                condBlock);
    rewriter.setInsertionPointToStart(condBlock);
    Value indexArg = condBlock->getArgument(0);
    Value strideArg = condBlock->getArgument(1);

    Value zeroIndex = createIndexAttrConstant(rewriter, loc, indexType, 0);
    Value pred = rewriter.create<LLVM::ICmpOp>(
        loc, IntegerType::get(rewriter.getContext(), 1),
        LLVM::ICmpPredicate::sge, indexArg, zeroIndex);

    Block *bodyBlock =
        rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
    rewriter.setInsertionPointToStart(bodyBlock);

    // Copy size from shape to descriptor.
    auto llvmIndexPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
    Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
        loc, llvmIndexPtrType,
        typeConverter->convertType(shapeMemRefType.getElementType()),
        shapeOperandPtr, indexArg);
    Value size = rewriter.create<LLVM::LoadOp>(loc, indexType, sizeLoadGep);
    UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
                                      targetSizesBase, indexArg, size);

    // Write stride value and compute next one.
    UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
                                        targetStridesBase, indexArg, strideArg);
    Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);

    // Decrement loop counter and branch back.
    Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
                                condBlock);

    Block *remainder =
        rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());

    // Hook up the cond exit to the remainder.
    rewriter.setInsertionPointToEnd(condBlock);
    rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, std::nullopt,
                                    remainder, std::nullopt);

    // Reset position to beginning of new remainder block.
    rewriter.setInsertionPointToStart(remainder);

    *descriptor = targetDesc;
    return success();
  }
};

/// ReassociatingReshapeOp must be expanded before we reach this stage.
/// Report that information.
template <typename ReshapeOp>
class ReassociatingReshapeOpConversion
    : public ConvertOpToLLVMPattern<ReshapeOp> {
public:
  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;

  LogicalResult
  matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    return rewriter.notifyMatchFailure(
        reshapeOp,
        "reassociation operations should have been expanded beforehand");
  }
};

/// Subviews must be expanded before we reach this stage.
/// Report that information.
struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
  using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    return rewriter.notifyMatchFailure(
        subViewOp, "subview operations should have been expanded beforehand");
  }
};

/// Conversion pattern that transforms a transpose op into:
///   1. A function entry `alloca` operation to allocate a ViewDescriptor.
///   2. A load of the ViewDescriptor from the pointer allocated in 1.
///   3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
///      and stride. Size and stride are permutations of the original values.
///   4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
/// The transpose op is replaced by the alloca'ed pointer.
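///
/// For example (sketch): for a transpose with permutation (i, j) -> (j, i) on
/// a strided 2-D memref, the produced descriptor reuses the original pointers
/// and offset but swaps the size/stride entries of the two dimensions.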
class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
public:
  using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = transposeOp.getLoc();
    MemRefDescriptor viewMemRef(adaptor.getIn());

    // No permutation, early exit.
    if (transposeOp.getPermutation().isIdentity())
      return rewriter.replaceOp(transposeOp, {viewMemRef}), success();

    auto targetMemRef = MemRefDescriptor::undef(
        rewriter, loc,
        typeConverter->convertType(transposeOp.getIn().getType()));

    // Copy the base and aligned pointers from the old descriptor to the new
    // one.
    targetMemRef.setAllocatedPtr(rewriter, loc,
                                 viewMemRef.allocatedPtr(rewriter, loc));
    targetMemRef.setAlignedPtr(rewriter, loc,
                               viewMemRef.alignedPtr(rewriter, loc));

    // Copy the offset pointer from the old descriptor to the new one.
    targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));

    // Iterate over the dimensions and apply size/stride permutation:
    // When enumerating the results of the permutation map, the enumeration
    // index is the index into the target dimensions and the DimExpr points to
    // the dimension of the source memref.
    for (const auto &en :
         llvm::enumerate(transposeOp.getPermutation().getResults())) {
      int targetPos = en.index();
      int sourcePos = cast<AffineDimExpr>(en.value()).getPosition();
      targetMemRef.setSize(rewriter, loc, targetPos,
                           viewMemRef.size(rewriter, loc, sourcePos));
      targetMemRef.setStride(rewriter, loc, targetPos,
                             viewMemRef.stride(rewriter, loc, sourcePos));
    }

    rewriter.replaceOp(transposeOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms an op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The view op is replaced by the descriptor.
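///
/// For example (sketch):
/// `%v = memref.view %src[%shift][] : memref<64xi8> to memref<4x4xf32>`
/// produces a descriptor whose aligned pointer is the source aligned pointer
/// advanced by %shift bytes, whose offset is 0, and whose sizes/strides are
/// {4, 4} / {4, 1}.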
1402 struct ViewOpLowering
: public ConvertOpToLLVMPattern
<memref::ViewOp
> {
1403 using ConvertOpToLLVMPattern
<memref::ViewOp
>::ConvertOpToLLVMPattern
;
1405 // Build and return the value for the idx^th shape dimension, either by
1406 // returning the constant shape dimension or counting the proper dynamic size.
1407 Value
getSize(ConversionPatternRewriter
&rewriter
, Location loc
,
1408 ArrayRef
<int64_t> shape
, ValueRange dynamicSizes
, unsigned idx
,
1409 Type indexType
) const {
1410 assert(idx
< shape
.size());
1411 if (!ShapedType::isDynamic(shape
[idx
]))
1412 return createIndexAttrConstant(rewriter
, loc
, indexType
, shape
[idx
]);
1413 // Count the number of dynamic dims in range [0, idx]
1415 llvm::count_if(shape
.take_front(idx
), ShapedType::isDynamic
);
1416 return dynamicSizes
[nDynamic
];
1419 // Build and return the idx^th stride, either by returning the constant stride
1420 // or by computing the dynamic stride from the current `runningStride` and
1421 // `nextSize`. The caller should keep a running stride and update it with the
1422 // result returned by this function.
1423 Value
getStride(ConversionPatternRewriter
&rewriter
, Location loc
,
1424 ArrayRef
<int64_t> strides
, Value nextSize
,
1425 Value runningStride
, unsigned idx
, Type indexType
) const {
1426 assert(idx
< strides
.size());
1427 if (!ShapedType::isDynamic(strides
[idx
]))
1428 return createIndexAttrConstant(rewriter
, loc
, indexType
, strides
[idx
]);
1430 return runningStride
1431 ? rewriter
.create
<LLVM::MulOp
>(loc
, runningStride
, nextSize
)
1433 assert(!runningStride
);
1434 return createIndexAttrConstant(rewriter
, loc
, indexType
, 1);
  LogicalResult
  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = viewOp.getLoc();

    auto viewMemRefType = viewOp.getType();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
             failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
    if (failed(successStrides))
      return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
    assert(offset == 0 && "expected offset to be 0");

    // Target memref must be contiguous in memory (innermost stride is 1), or
    // empty (special case when at least one of the memref dimensions is 0).
    if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
      return viewOp.emitWarning("cannot cast to non-contiguous shape"),
             failure();

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.getSource());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Field 1: Copy the allocated pointer, used for malloc/free.
    Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
    auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType());
    targetMemRef.setAllocatedPtr(rewriter, loc, allocatedPtr);

    // Field 2: Copy the actual aligned pointer to payload.
    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
    alignedPtr = rewriter.create<LLVM::GEPOp>(
        loc, alignedPtr.getType(),
        typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr,
        adaptor.getByteShift());

    targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr);

    Type indexType = getIndexType();
    // Field 3: The offset in the resulting type must be 0. This is
    // because of the type change: an offset on srcType* may not be
    // expressible as an offset on dstType*.
    targetMemRef.setOffset(
        rewriter, loc,
        createIndexAttrConstant(rewriter, loc, indexType, offset));

    // Early exit for 0-D corner case.
    if (viewMemRefType.getRank() == 0)
      return rewriter.replaceOp(viewOp, {targetMemRef}), success();

    // Fields 4 and 5: Update sizes and strides.
    Value stride = nullptr, nextSize = nullptr;
    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
      // Update size.
      Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
                           adaptor.getSizes(), i, indexType);
      targetMemRef.setSize(rewriter, loc, i, size);
      // Update stride.
      stride =
          getStride(rewriter, loc, strides, nextSize, stride, i, indexType);
      targetMemRef.setStride(rewriter, loc, i, stride);
      nextSize = size;
    }

    rewriter.replaceOp(viewOp, {targetMemRef});
    return success();
  }
};
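// For illustration only (not part of the original source): a minimal sketch of
// the view lowering. Given roughly
//
//   %v = memref.view %src[%byte_shift][%s0, %s1]
//       : memref<2048xi8> to memref<?x?xf32>
//
// the new descriptor reuses %src's allocated pointer, advances the aligned
// pointer by %byte_shift i8 elements via llvm.getelementptr, sets the offset
// to 0, takes the sizes from (%s0, %s1), and computes the strides
// innermost-first as [%s1, 1].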
//===----------------------------------------------------------------------===//
// AtomicRMWOpLowering
//===----------------------------------------------------------------------===//

/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
/// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
static std::optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
  switch (atomicOp.getKind()) {
  case arith::AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case arith::AtomicRMWKind::addi:
    return LLVM::AtomicBinOp::add;
  case arith::AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case arith::AtomicRMWKind::maximumf:
    return LLVM::AtomicBinOp::fmax;
  case arith::AtomicRMWKind::maxs:
    return LLVM::AtomicBinOp::max;
  case arith::AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case arith::AtomicRMWKind::minimumf:
    return LLVM::AtomicBinOp::fmin;
  case arith::AtomicRMWKind::mins:
    return LLVM::AtomicBinOp::min;
  case arith::AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  case arith::AtomicRMWKind::ori:
    return LLVM::AtomicBinOp::_or;
  case arith::AtomicRMWKind::andi:
    return LLVM::AtomicBinOp::_and;
  default:
    return std::nullopt;
  }
  llvm_unreachable("Invalid AtomicRMWKind");
}
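// For illustration only (not part of the original source): a kind with a
// direct llvm.atomicrmw equivalent, e.g.
//
//   %old = memref.atomic_rmw addf %value, %m[%i] : (f32, memref<8xf32>) -> f32
//
// maps to LLVM::AtomicBinOp::fadd and is rewritten by the pattern below into
// an llvm.atomicrmw with acq_rel ordering; kinds without such an equivalent
// (e.g. mulf) return std::nullopt, so that pattern bails out.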
struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto maybeKind = matchSimpleAtomicOp(atomicOp);
    if (!maybeKind)
      return failure();
    auto memRefType = atomicOp.getMemRefType();
    SmallVector<int64_t> strides;
    int64_t offset;
    if (failed(getStridesAndOffset(memRefType, strides, offset)))
      return failure();
    auto dataPtr =
        getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
                             adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
        atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
        LLVM::AtomicOrdering::acq_rel);
    return success();
  }
};
/// Unpack the pointer returned by a memref.extract_aligned_pointer_as_index.
class ConvertExtractAlignedPointerAsIndex
    : public ConvertOpToLLVMPattern<memref::ExtractAlignedPointerAsIndexOp> {
public:
  using ConvertOpToLLVMPattern<
      memref::ExtractAlignedPointerAsIndexOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    BaseMemRefType sourceTy = extractOp.getSource().getType();

    Value alignedPtr;
    if (sourceTy.hasRank()) {
      MemRefDescriptor desc(adaptor.getSource());
      alignedPtr = desc.alignedPtr(rewriter, extractOp->getLoc());
    } else {
      auto elementPtrTy = LLVM::LLVMPointerType::get(
          rewriter.getContext(), sourceTy.getMemorySpaceAsInt());

      UnrankedMemRefDescriptor desc(adaptor.getSource());
      Value descPtr = desc.memRefDescPtr(rewriter, extractOp->getLoc());

      alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
          rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr,
          elementPtrTy);
    }

    rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(
        extractOp, getTypeConverter()->getIndexType(), alignedPtr);
    return success();
  }
};
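// For illustration only (not part of the original source): the ranked case
//
//   %p = memref.extract_aligned_pointer_as_index %m : memref<?xf32> -> index
//
// simply reads the aligned-pointer field of %m's descriptor and converts it
// with llvm.ptrtoint into the converted index type; the unranked case first
// loads the aligned pointer through the descriptor pointer.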
/// Materialize the MemRef descriptor represented by the results of
/// ExtractStridedMetadataOp.
class ExtractStridedMetadataOpLowering
    : public ConvertOpToLLVMPattern<memref::ExtractStridedMetadataOp> {
public:
  using ConvertOpToLLVMPattern<
      memref::ExtractStridedMetadataOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
      return failure();

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.getSource());
    Location loc = extractStridedMetadataOp.getLoc();
    Value source = extractStridedMetadataOp.getSource();

    auto sourceMemRefType = cast<MemRefType>(source.getType());
    int64_t rank = sourceMemRefType.getRank();
    SmallVector<Value> results;
    results.reserve(2 + rank * 2);

    // Base buffer.
    Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc);
    Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc);
    MemRefDescriptor dstMemRef = MemRefDescriptor::fromStaticShape(
        rewriter, loc, *getTypeConverter(),
        cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()),
        baseBuffer, alignedBuffer);
    results.push_back((Value)dstMemRef);

    // Offset.
    results.push_back(sourceMemRef.offset(rewriter, loc));

    // Sizes.
    for (unsigned i = 0; i < rank; ++i)
      results.push_back(sourceMemRef.size(rewriter, loc, i));
    // Strides.
    for (unsigned i = 0; i < rank; ++i)
      results.push_back(sourceMemRef.stride(rewriter, loc, i));

    rewriter.replaceOp(extractStridedMetadataOp, results);
    return success();
  }
};
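// For illustration only (not part of the original source): given roughly
//
//   %base, %offset, %size, %stride = memref.extract_strided_metadata %m
//       : memref<?xf32, strided<[?], offset: ?>>
//       -> memref<f32>, index, index, index
//
// the pattern unpacks %m's descriptor into a rank-0 descriptor for %base
// (rebuilt from the allocated/aligned pointers), followed by the offset and
// one size/stride value per dimension.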
void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<
      AllocaScopeOpLowering,
      AtomicRMWOpLowering,
      AssumeAlignmentOpLowering,
      ConvertExtractAlignedPointerAsIndex,
      ExtractStridedMetadataOpLowering,
      GenericAtomicRMWOpLowering,
      GlobalMemrefOpLowering,
      GetGlobalMemrefOpLowering,
      MemRefCastOpLowering,
      MemRefCopyOpLowering,
      MemorySpaceCastOpLowering,
      MemRefReinterpretCastOpLowering,
      MemRefReshapeOpLowering,
      ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
      ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
      TransposeOpLowering,
      ViewOpLowering>(converter);
  auto allocLowering = converter.getOptions().allocLowering;
  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
    patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
    patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
}
struct FinalizeMemRefToLLVMConversionPass
    : public impl::FinalizeMemRefToLLVMConversionPassBase<
          FinalizeMemRefToLLVMConversionPass> {
  using FinalizeMemRefToLLVMConversionPassBase::
      FinalizeMemRefToLLVMConversionPassBase;

  void runOnOperation() override {
    Operation *op = getOperation();
    const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
    LowerToLLVMOptions options(&getContext(),
                               dataLayoutAnalysis.getAtOrAbove(op));
    options.allocLowering =
        (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
                         : LowerToLLVMOptions::AllocLowering::Malloc);

    options.useGenericFunctions = useGenericFunctions;

    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    LLVMTypeConverter typeConverter(&getContext(), options,
                                    &dataLayoutAnalysis);
    RewritePatternSet patterns(&getContext());
    populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalOp<func::FuncOp>();
    if (failed(applyPartialConversion(op, target, std::move(patterns))))
      signalPassFailure();
  }
};
/// Implement the interface to convert MemRef to LLVM.
struct MemRefToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
  void loadDependentDialects(MLIRContext *context) const final {
    context->loadDialect<LLVM::LLVMDialect>();
  }

  /// Hook for derived dialect interface to provide conversion patterns
  /// and mark dialect legal for the conversion target.
  void populateConvertToLLVMConversionPatterns(
      ConversionTarget &target, LLVMTypeConverter &typeConverter,
      RewritePatternSet &patterns) const final {
    populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
  }
};

void mlir::registerConvertMemRefToLLVMInterface(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
    dialect->addInterfaces<MemRefToLLVMDialectInterface>();