//===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert MemRef dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Visitors.h"
#include "llvm/Support/Debug.h"

using namespace mlir;

#define DEBUG_TYPE "memref-to-spirv-pattern"

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Returns the offset of the value in `targetBits` representation.
///
/// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
/// It's assumed to be non-negative.
///
/// When accessing an element in the array as if it had elements of
/// `targetBits`, multiple values are loaded at the same time. The method
/// returns the offset of `srcIdx` within the loaded value. For example, if
/// `sourceBits` equals 8 and `targetBits` equals 32, the x-th element is
/// located at (x % 4) * 8, because there are four 8-bit elements in one i32.
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
                                  int targetBits, OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  Type type = srcIdx.getType();
  IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits);
  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
  IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits);
  auto srcBitsValue =
      builder.createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
  auto m = builder.createOrFold<spirv::UModOp>(loc, srcIdx, idx);
  return builder.createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
}
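
// Illustrative sketch (not verbatim compiler output): with `sourceBits` = 8 and
// `targetBits` = 32, the ops built above compute (%srcIdx umod 4) * 8, roughly:
//   %four   = spirv.Constant 4 : i32
//   %eight  = spirv.Constant 8 : i32
//   %rem    = spirv.UMod %srcIdx, %four : i32
//   %offset = spirv.IMul %rem, %eight : i32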

/// Returns an adjusted spirv::AccessChainOp. Based on the
/// extension/capabilities, certain integer bitwidths `sourceBits` might not be
/// supported. During conversion, if a memref of an unsupported type is used,
/// loads/stores to this memref need to be modified to use a supported higher
/// bitwidth `targetBits` and to extract the required bits. For an access into a
/// 1-D array (spirv.array or spirv.rtarray), the last index is modified to load
/// the bits needed. The extraction of the actual bits needed is handled
/// separately. Note that this only works for a 1-D tensor.
static Value
adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
                             spirv::AccessChainOp op, int sourceBits,
                             int targetBits, OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  const auto loc = op.getLoc();
  Value lastDim = op->getOperand(op.getNumOperands() - 1);
  Type type = lastDim.getType();
  IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits);
  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, attr);
  auto indices = llvm::to_vector<4>(op.getIndices());
  // There are two elements if this is a 1-D tensor.
  assert(indices.size() == 2);
  indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx);
  Type t = typeConverter.convertType(op.getComponentPtr().getType());
  return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
}
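
// Illustrative sketch (not verbatim IR): for i8 elements accessed through i32
// words (targetBits / sourceBits == 4), an access chain ending in index %i is
// rebuilt so that its last index becomes %i sdiv 4:
//   spirv.AccessChain %base[%zero, %i] -> spirv.AccessChain %base[%zero, %i4]
// where %i4 = spirv.SDiv %i, %four.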

/// Casts the given `srcBool` into an integer of `dstType`.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
                            OpBuilder &builder) {
  assert(srcBool.getType().isInteger(1));
  if (dstType.isInteger(1))
    return srcBool;
  Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
  Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
  return builder.createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
                                               zero);
}

/// Returns the `targetBits`-bit value shifted by the given `offset`, cast to
/// the destination type, and masked.
static Value shiftValue(Location loc, Value value, Value offset, Value mask,
                        OpBuilder &builder) {
  IntegerType dstType = cast<IntegerType>(mask.getType());
  int targetBits = static_cast<int>(dstType.getWidth());
  int valueBits = value.getType().getIntOrFloatBitWidth();
  assert(valueBits <= targetBits);

  if (valueBits == 1) {
    value = castBoolToIntN(loc, value, dstType, builder);
  } else {
    if (valueBits < targetBits) {
      value = builder.create<spirv::UConvertOp>(
          loc, builder.getIntegerType(targetBits), value);
    }
  }

  value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
  return builder.createOrFold<spirv::ShiftLeftLogicalOp>(loc, value.getType(),
                                                         value, offset);
}
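
// Worked example (illustrative): storing an i8 value %v into byte 1 of an i32
// word, with mask = 0xFF and offset = 8, this computes
//   (zext %v to i32 & 0xFF) << 8
// i.e. the store value positioned within the wider word.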

/// Returns true if the allocations of memref `type` generated from `allocOp`
/// can be lowered to SPIR-V.
static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
  if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
    auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
    if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
      return false;
  } else if (isa<memref::AllocaOp>(allocOp)) {
    auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
    if (!sc || sc.getValue() != spirv::StorageClass::Function)
      return false;
  } else {
    return false;
  }

  // Currently only support static shape and int or float or vector of int or
  // float element type.
  if (!type.hasStaticShape())
    return false;

  Type elementType = type.getElementType();
  if (auto vecType = dyn_cast<VectorType>(elementType))
    elementType = vecType.getElementType();
  return elementType.isIntOrFloat();
}
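
// For example (illustrative), a memref.alloc of type
// memref<4x4xf32, #spirv.storage_class<Workgroup>> passes these checks, while
// a dynamically shaped memref<?xf32, ...> or a non-Workgroup memref.alloc does
// not.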

/// Returns the scope to use for atomic operations used for emulating store
/// operations of unsupported integer bitwidths, based on the memref
/// type. Returns std::nullopt on failure.
static std::optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
  auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
  switch (sc.getValue()) {
  case spirv::StorageClass::StorageBuffer:
    return spirv::Scope::Device;
  case spirv::StorageClass::Workgroup:
    return spirv::Scope::Workgroup;
  default:
    return std::nullopt;
  }
}

/// Casts the given `srcInt` into a boolean value.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
  if (srcInt.getType().isInteger(1))
    return srcInt;

  auto zero = spirv::ConstantOp::getZero(srcInt.getType(), loc, builder);
  return builder.createOrFold<spirv::INotEqualOp>(loc, srcInt, zero);
}

//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//

// Note that DRR cannot be used for the patterns in this file: we may need to
// convert the type along the way, which requires a ConversionPattern. DRR
// generates normal RewritePatterns.

namespace {

/// Converts memref.alloca to SPIR-V Function variables.
class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> {
public:
  using OpConversionPattern<memref::AllocaOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts an allocation operation to SPIR-V. Currently only supports lowering
/// to Workgroup memory when the size is constant. Note that this pattern needs
/// to be applied in a pass that runs at least at spirv.module scope since it
/// will add global variables into the spirv.module.
class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
public:
  using OpConversionPattern<memref::AllocOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.atomic_rmw operations to SPIR-V atomic operations.
class AtomicRMWOpPattern final
    : public OpConversionPattern<memref::AtomicRMWOp> {
public:
  using OpConversionPattern<memref::AtomicRMWOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Removes a deallocation if it is a supported allocation. Currently only
/// removes deallocations if the memory space is workgroup memory.
class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
public:
  using OpConversionPattern<memref::DeallocOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spirv.Load + spirv.AccessChain on integers.
class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using OpConversionPattern<memref::LoadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spirv.Load + spirv.AccessChain.
class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using OpConversionPattern<memref::LoadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spirv.Store on integers.
class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using OpConversionPattern<memref::StoreOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.memory_space_cast to the appropriate spirv cast operations.
class MemorySpaceCastOpPattern final
    : public OpConversionPattern<memref::MemorySpaceCastOp> {
public:
  using OpConversionPattern<memref::MemorySpaceCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spirv.Store.
class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using OpConversionPattern<memref::StoreOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class ReinterpretCastPattern final
    : public OpConversionPattern<memref::ReinterpretCastOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class CastPattern final : public OpConversionPattern<memref::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value src = adaptor.getSource();
    Type srcType = src.getType();

    const TypeConverter *converter = getTypeConverter();
    Type dstType = converter->convertType(op.getType());
    if (srcType != dstType)
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "types don't match: " << srcType << " and " << dstType;
      });

    rewriter.replaceOp(op, src);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// AllocaOp
//===----------------------------------------------------------------------===//

LogicalResult
AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  MemRefType allocType = allocaOp.getType();
  if (!isAllocationSupported(allocaOp, allocType))
    return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type");

  // Get the SPIR-V type for the allocation.
  Type spirvType = getTypeConverter()->convertType(allocType);
  if (!spirvType)
    return rewriter.notifyMatchFailure(allocaOp, "type conversion failed");

  rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
                                                 spirv::StorageClass::Function,
                                                 /*initializer=*/nullptr);
  return success();
}
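
// Illustrative conversion (approximate, not verbatim output):
//   %0 = memref.alloca() : memref<4xf32, #spirv.storage_class<Function>>
// becomes roughly
//   %0 = spirv.Variable : !spirv.ptr<!spirv.array<4 x f32>, Function>
// where the exact pointee type is whatever the type converter produces.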

//===----------------------------------------------------------------------===//
// AllocOp
//===----------------------------------------------------------------------===//

LogicalResult
AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  MemRefType allocType = operation.getType();
  if (!isAllocationSupported(operation, allocType))
    return rewriter.notifyMatchFailure(operation, "unhandled allocation type");

  // Get the SPIR-V type for the allocation.
  Type spirvType = getTypeConverter()->convertType(allocType);
  if (!spirvType)
    return rewriter.notifyMatchFailure(operation, "type conversion failed");

  // Insert spirv.GlobalVariable for this allocation.
  Operation *parent =
      SymbolTable::getNearestSymbolTable(operation->getParentOp());
  if (!parent)
    return failure();
  Location loc = operation.getLoc();
  spirv::GlobalVariableOp varOp;
  {
    OpBuilder::InsertionGuard guard(rewriter);
    Block &entryBlock = *parent->getRegion(0).begin();
    rewriter.setInsertionPointToStart(&entryBlock);
    auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
    std::string varName =
        std::string("__workgroup_mem__") +
        std::to_string(std::distance(varOps.begin(), varOps.end()));
    varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName,
                                                     /*initializer=*/nullptr);
  }

  // Get pointer to global variable at the current scope.
  rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
  return success();
}
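
// Illustrative conversion (approximate, not verbatim output): a Workgroup
// memref.alloc becomes a module-level global plus an address-of at the use
// site, roughly:
//   spirv.GlobalVariable @__workgroup_mem__0 : !spirv.ptr<..., Workgroup>
//   ...
//   %0 = spirv.mlir.addressof @__workgroup_mem__0 : !spirv.ptr<..., Workgroup>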

//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//

LogicalResult
AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
                                    OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  if (isa<FloatType>(atomicOp.getType()))
    return rewriter.notifyMatchFailure(atomicOp,
                                       "unimplemented floating-point case");

  auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return rewriter.notifyMatchFailure(atomicOp,
                                       "unsupported memref memory space");

  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Type resultType = typeConverter.convertType(atomicOp.getType());
  if (!resultType)
    return rewriter.notifyMatchFailure(atomicOp,
                                       "failed to convert result type");

  auto loc = atomicOp.getLoc();
  Value ptr =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);
  if (!ptr)
    return failure();

#define ATOMIC_CASE(kind, spirvOp)                                            \
  case arith::AtomicRMWKind::kind:                                            \
    rewriter.replaceOpWithNewOp<spirv::spirvOp>(                              \
        atomicOp, resultType, ptr, *scope,                                    \
        spirv::MemorySemantics::AcquireRelease, adaptor.getValue());          \
    break

  switch (atomicOp.getKind()) {
    ATOMIC_CASE(addi, AtomicIAddOp);
    ATOMIC_CASE(maxs, AtomicSMaxOp);
    ATOMIC_CASE(maxu, AtomicUMaxOp);
    ATOMIC_CASE(mins, AtomicSMinOp);
    ATOMIC_CASE(minu, AtomicUMinOp);
    ATOMIC_CASE(ori, AtomicOrOp);
    ATOMIC_CASE(andi, AtomicAndOp);
  default:
    return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
  }

#undef ATOMIC_CASE

  return success();
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult
DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
  if (!isAllocationSupported(operation, deallocType))
    return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
  rewriter.eraseOp(operation);
  return success();
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

struct MemoryRequirements {
  spirv::MemoryAccessAttr memoryAccess;
  IntegerAttr alignment;
};

/// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
/// any.
static FailureOr<MemoryRequirements>
calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
  MLIRContext *ctx = accessedPtr.getContext();

  auto memoryAccess = spirv::MemoryAccess::None;
  if (isNontemporal) {
    memoryAccess = spirv::MemoryAccess::Nontemporal;
  }

  auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
  if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) {
    if (memoryAccess == spirv::MemoryAccess::None) {
      return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}};
    }
    return MemoryRequirements{spirv::MemoryAccessAttr::get(ctx, memoryAccess),
                              IntegerAttr{}};
  }

  // PhysicalStorageBuffers require the `Aligned` attribute.
  auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
  if (!pointeeType)
    return failure();

  // For scalar types, the alignment is determined by their size.
  std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
  if (!sizeInBytes.has_value())
    return failure();

  memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
  auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
  auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes);
  return MemoryRequirements{memAccessAttr, alignment};
}
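
// For example (illustrative): an access through
// !spirv.ptr<f32, PhysicalStorageBuffer> yields memory access "Aligned" with
// alignment 4, while a non-PhysicalStorageBuffer pointer with no nontemporal
// hint yields empty attributes.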

/// Given an accessed SPIR-V pointer and the original memref load/store
/// `memAccess` op, calculates the alignment requirements, if any. Takes into
/// account the alignment attributes applied to the load/store op.
template <class LoadOrStoreOp>
static FailureOr<MemoryRequirements>
calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
  static_assert(
      llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
      "Must be called on either memref::LoadOp or memref::StoreOp");

  Operation *memrefAccessOp = loadOrStoreOp.getOperation();
  auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
      spirv::attributeName<spirv::MemoryAccess>());
  auto memrefAlignment =
      memrefAccessOp->getAttrOfType<IntegerAttr>("alignment");
  if (memrefMemAccess && memrefAlignment)
    return MemoryRequirements{memrefMemAccess, memrefAlignment};

  return calculateMemoryRequirements(accessedPtr,
                                     loadOrStoreOp.getNontemporal());
}

LogicalResult
IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  auto loc = loadOp.getLoc();
  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
  if (!memrefType.getElementType().isSignlessInteger())
    return failure();

  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Value accessChain =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);
  if (!accessChain)
    return failure();

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;

  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
  if (!pointerType)
    return rewriter.notifyMatchFailure(loadOp, "failed to convert memref type");

  Type pointeeType = pointerType.getPointeeType();
  Type dstType;
  if (typeConverter.allows(spirv::Capability::Kernel)) {
    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
      dstType = arrayType.getElementType();
    else
      dstType = pointeeType;
  } else {
    // For Vulkan we need to extract the element from the wrapping struct and
    // array.
    Type structElemType =
        cast<spirv::StructType>(pointeeType).getElementType(0);
    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
      dstType = arrayType.getElementType();
    else
      dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
  }
  int dstBits = dstType.getIntOrFloatBitWidth();
  assert(dstBits % srcBits == 0);

  // If the rewritten load op has the same bit width, use the loaded value
  // directly.
  if (srcBits == dstBits) {
    auto memoryRequirements = calculateMemoryRequirements(accessChain, loadOp);
    if (failed(memoryRequirements))
      return rewriter.notifyMatchFailure(
          loadOp, "failed to determine memory requirements");

    auto [memoryAccess, alignment] = *memoryRequirements;
    Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain,
                                                   memoryAccess, alignment);
    if (isBool)
      loadVal = castIntNToBool(loc, loadVal, rewriter);
    rewriter.replaceOp(loadOp, loadVal);
    return success();
  }

  // Bitcasting is currently unsupported for Kernel capability /
  // spirv.PtrAccessChain.
  if (typeConverter.allows(spirv::Capability::Kernel))
    return failure();

  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
  if (!accessChainOp)
    return failure();

  // Assume that getElementPtr() works linearly. If it's a scalar, the method
  // still returns a linearized access. If the access is not linearized, there
  // will be offset issues.
  assert(accessChainOp.getIndices().size() == 2);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  auto memoryRequirements = calculateMemoryRequirements(adjustedPtr, loadOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        loadOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  Value spvLoadOp = rewriter.create<spirv::LoadOp>(loc, dstType, adjustedPtr,
                                                   memoryAccess, alignment);

  // Shift the bits to the rightmost.
  // ____XXXX________ -> ____________XXXX
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
  Value result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
      loc, spvLoadOp.getType(), spvLoadOp, offset);

  // Apply the mask to extract the corresponding bits.
  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  result =
      rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);

  // Apply sign extension on the loaded value unconditionally. The signedness
  // semantics are carried in the operator itself; we rely on other patterns to
  // handle the casting.
  IntegerAttr shiftValueAttr =
      rewriter.getIntegerAttr(dstType, dstBits - srcBits);
  Value shiftValue =
      rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
  result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
                                                            result, shiftValue);
  result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
      loc, dstType, result, shiftValue);

  rewriter.replaceOp(loadOp, result);

  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}
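
// Illustrative end-to-end sketch (approximate, not verbatim output): loading
// the i-th i8 element from a buffer emulated with i32 words becomes roughly
//   %word   = spirv.Load (access chain ending in %i / 4)
//   %offset = (%i umod 4) * 8
//   %byte   = (%word >> %offset) & 0xFF
//   %result = (%byte << 24) >> 24   // arithmetic shift for sign extension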

LogicalResult
LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
  if (memrefType.getElementType().isSignlessInteger())
    return failure();
  Value loadPtr = spirv::getElementPtr(
      *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
      adaptor.getIndices(), loadOp.getLoc(), rewriter);
  if (!loadPtr)
    return failure();

  auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        loadOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memoryAccess,
                                             alignment);
  return success();
}

LogicalResult
IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
  if (!memrefType.getElementType().isSignlessInteger())
    return rewriter.notifyMatchFailure(storeOp,
                                       "element type is not a signless int");

  auto loc = storeOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Value accessChain =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);
  if (!accessChain)
    return rewriter.notifyMatchFailure(
        storeOp, "failed to convert element pointer type");

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();

  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;

  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
  if (!pointerType)
    return rewriter.notifyMatchFailure(storeOp,
                                       "failed to convert memref type");

  Type pointeeType = pointerType.getPointeeType();
  IntegerType dstType;
  if (typeConverter.allows(spirv::Capability::Kernel)) {
    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
    else
      dstType = dyn_cast<IntegerType>(pointeeType);
  } else {
    // For Vulkan we need to extract the element from the wrapping struct and
    // array.
    Type structElemType =
        cast<spirv::StructType>(pointeeType).getElementType(0);
    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
    else
      dstType = dyn_cast<IntegerType>(
          cast<spirv::RuntimeArrayType>(structElemType).getElementType());
  }

  if (!dstType)
    return rewriter.notifyMatchFailure(
        storeOp, "failed to determine destination element type");

  int dstBits = static_cast<int>(dstType.getWidth());
  assert(dstBits % srcBits == 0);

  if (srcBits == dstBits) {
    auto memoryRequirements = calculateMemoryRequirements(accessChain, storeOp);
    if (failed(memoryRequirements))
      return rewriter.notifyMatchFailure(
          storeOp, "failed to determine memory requirements");

    auto [memoryAccess, alignment] = *memoryRequirements;
    Value storeVal = adaptor.getValue();
    if (isBool)
      storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
                                                memoryAccess, alignment);
    return success();
  }

  // Bitcasting is currently unsupported for Kernel capability /
  // spirv.PtrAccessChain.
  if (typeConverter.allows(spirv::Capability::Kernel))
    return failure();

  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
  if (!accessChainOp)
    return failure();

  // Since multiple threads may be storing to adjacent elements within the same
  // word, the emulation is done with atomic operations. E.g., if the stored
  // value is i8, rewrite the StoreOp to:
  // 1) load a 32-bit integer
  // 2) clear 8 bits in the loaded value
  // 3) set 8 bits in the loaded value
  // 4) store the 32-bit value back
  //
  // Step 2 is done with AtomicAnd, and step 3 is done with AtomicOr (of the
  // loaded 32-bit value and the shifted 8-bit store value) as another atomic
  // step.
  assert(accessChainOp.getIndices().size() == 2);
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);

  // Create a mask to clear the destination. E.g., if it is the second i8 in
  // i32, 0xFFFF00FF is created.
  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
      loc, dstType, mask, offset);
  clearBitsMask =
      rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);

  Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return rewriter.notifyMatchFailure(storeOp, "atomic scope not available");

  Value result = rewriter.create<spirv::AtomicAndOp>(
      loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
      clearBitsMask);
  result = rewriter.create<spirv::AtomicOrOp>(
      loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
      storeVal);

  // The AtomicOrOp above already performs the store. Since it is already
  // inserted, we can just remove the original StoreOp. Note that
  // rewriter.replaceOp() doesn't work here because it requires the numbers of
  // results to match.
  rewriter.eraseOp(storeOp);

  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}
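
// Illustrative sketch of the emulated i8 store into an i32 word (approximate,
// not verbatim output), where %off = (%i umod 4) * 8:
//   %clear = spirv.Not (0xFF << %off)             // 0xFFFF00FF for byte 1
//   %old   = spirv.AtomicAnd %ptr, %clear         // clear destination bits
//   %new   = spirv.AtomicOr  %ptr, (%val << %off) // set destination bits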

//===----------------------------------------------------------------------===//
// MemorySpaceCastOp
//===----------------------------------------------------------------------===//

LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
    memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = addrCastOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  if (!typeConverter.allows(spirv::Capability::Kernel))
    return rewriter.notifyMatchFailure(
        loc, "address space casts require kernel capability");

  auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
  if (!sourceType)
    return rewriter.notifyMatchFailure(
        loc, "SPIR-V lowering requires ranked memref types");
  auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());

  auto sourceStorageClassAttr =
      dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
  if (!sourceStorageClassAttr)
    return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &diag) {
      diag << "source address space " << sourceType.getMemorySpace()
           << " must be a SPIR-V storage class";
    });
  auto resultStorageClassAttr =
      dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
  if (!resultStorageClassAttr)
    return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &diag) {
      diag << "result address space " << resultType.getMemorySpace()
           << " must be a SPIR-V storage class";
    });

  spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
  spirv::StorageClass resultSc = resultStorageClassAttr.getValue();

  Value result = adaptor.getSource();
  Type resultPtrType = typeConverter.convertType(resultType);
  if (!resultPtrType)
    return rewriter.notifyMatchFailure(addrCastOp,
                                       "failed to convert memref type");

  Type genericPtrType = resultPtrType;
  // SPIR-V doesn't have a general address space cast operation. Instead, it has
  // conversions to and from generic pointers. To implement the general case,
  // we use specific-to-generic conversions when the source class is not
  // generic. Then when the result storage class is not generic, we convert the
  // generic pointer (either the input or an intermediate result) to that
  // class. This also means that we'll need the intermediate generic pointer
  // type if neither the source nor the destination have it.
  if (sourceSc != spirv::StorageClass::Generic &&
      resultSc != spirv::StorageClass::Generic) {
    Type intermediateType =
        MemRefType::get(sourceType.getShape(), sourceType.getElementType(),
                        sourceType.getLayout(),
                        rewriter.getAttr<spirv::StorageClassAttr>(
                            spirv::StorageClass::Generic));
    genericPtrType = typeConverter.convertType(intermediateType);
  }
  if (sourceSc != spirv::StorageClass::Generic) {
    result =
        rewriter.create<spirv::PtrCastToGenericOp>(loc, genericPtrType, result);
  }
  if (resultSc != spirv::StorageClass::Generic) {
    result =
        rewriter.create<spirv::GenericCastToPtrOp>(loc, resultPtrType, result);
  }
  rewriter.replaceOp(addrCastOp, result);
  return success();
}
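
// Illustrative conversion (approximate, not verbatim output): casting between
// two non-generic storage classes goes through a Generic intermediate, e.g.
//   %g = spirv.PtrCastToGeneric %src : ... to !spirv.ptr<f32, Generic>
//   %r = spirv.GenericCastToPtr %g : ... to !spirv.ptr<f32, Function>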

LogicalResult
StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
  if (memrefType.getElementType().isSignlessInteger())
    return rewriter.notifyMatchFailure(storeOp, "signless int");
  auto storePtr = spirv::getElementPtr(
      *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
      adaptor.getIndices(), storeOp.getLoc(), rewriter);
  if (!storePtr)
    return rewriter.notifyMatchFailure(storeOp, "type conversion failed");

  auto memoryRequirements = calculateMemoryRequirements(storePtr, storeOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        storeOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  rewriter.replaceOpWithNewOp<spirv::StoreOp>(
      storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
  return success();
}

LogicalResult ReinterpretCastPattern::matchAndRewrite(
    memref::ReinterpretCastOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value src = adaptor.getSource();
  auto srcType = dyn_cast<spirv::PointerType>(src.getType());
  if (!srcType)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "invalid src type " << src.getType();
    });

  const TypeConverter *converter = getTypeConverter();

  auto dstType = converter->convertType<spirv::PointerType>(op.getType());
  if (dstType != srcType)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "invalid dst type " << op.getType();
    });

  OpFoldResult offset =
      getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
          .front();
  if (isConstantIntValue(offset, 0)) {
    rewriter.replaceOp(op, src);
    return success();
  }

  Type intType = converter->convertType(rewriter.getIndexType());
  if (!intType)
    return rewriter.notifyMatchFailure(op, "failed to convert index type");

  Location loc = op.getLoc();
  auto offsetValue = [&]() -> Value {
    if (auto val = dyn_cast<Value>(offset))
      return val;

    int64_t attrVal = cast<IntegerAttr>(offset.get<Attribute>()).getInt();
    Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
    return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
  }();

  rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
      op, src, offsetValue, std::nullopt);
  return success();
}
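
// Illustrative conversion (approximate, not verbatim output): a
// memref.reinterpret_cast with a nonzero offset becomes a pointer bump:
//   %0 = spirv.InBoundsPtrAccessChain %src[%offset] : !spirv.ptr<f32, CrossWorkgroup>
// and with a static offset of 0 it folds away to the source pointer.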

//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//

namespace mlir {
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
                                   RewritePatternSet &patterns) {
  patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
               DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
               LoadOpPattern, MemorySpaceCastOpPattern, StoreOpPattern,
               ReinterpretCastPattern, CastPattern>(typeConverter,
                                                    patterns.getContext());
}
} // namespace mlir