//===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert MemRef dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Visitors.h"
#include "llvm/Support/Debug.h"
#include <cassert>
#include <optional>

#define DEBUG_TYPE "memref-to-spirv-pattern"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Returns the offset of the value in `targetBits` representation.
///
/// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
/// It's assumed to be non-negative.
///
/// When accessing an element in the array while treating it as having elements
/// of `targetBits`, multiple values are loaded at the same time. This method
/// returns the offset of `srcIdx` within the loaded value. For example, if
/// `sourceBits` is 8 and `targetBits` is 32, the x-th element is located at
/// (x % 4) * 8, because there are four 8-bit elements in one i32.
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
                                  int targetBits, OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  Type type = srcIdx.getType();
  IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits);
  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
  IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits);
  auto srcBitsValue =
      builder.createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
  auto m = builder.createOrFold<spirv::UModOp>(loc, srcIdx, idx);
  return builder.createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
}

/// Returns an adjusted spirv::AccessChainOp. Based on the
/// extension/capabilities, certain integer bitwidths `sourceBits` might not be
/// supported. During conversion, if a memref of an unsupported type is used,
/// loads/stores to this memref need to be modified to use a supported higher
/// bitwidth `targetBits` and to extract the required bits. For accessing a
/// 1-D array (spirv.array or spirv.rtarray), the last index is modified to
/// load the bits needed. The extraction of the actual bits needed is handled
/// separately. Note that this only works for a 1-D tensor.
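///
/// A rough sketch of the rewrite, assuming `sourceBits` = 8 and
/// `targetBits` = 32 (the IR below is illustrative, not verified output):
///
///   %chain = spirv.AccessChain %base[%c0, %i] : ...i8...   // original chain
///   // becomes, since each i32 word holds 32 / 8 = 4 source elements:
///   %word = spirv.SDiv %i, %c4 : i32
///   %chain = spirv.AccessChain %base[%c0, %word] : ...i32...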
static Value
adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
                             spirv::AccessChainOp op, int sourceBits,
                             int targetBits, OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  const auto loc = op.getLoc();
  Value lastDim = op->getOperand(op.getNumOperands() - 1);
  Type type = lastDim.getType();
  IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits);
  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, attr);
  auto indices = llvm::to_vector<4>(op.getIndices());
  // There are two elements if this is a 1-D tensor.
  assert(indices.size() == 2);
  indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx);
  Type t = typeConverter.convertType(op.getComponentPtr().getType());
  return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
}

/// Casts the given `srcBool` into an integer of `dstType`.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
                            OpBuilder &builder) {
  assert(srcBool.getType().isInteger(1));
  if (dstType.isInteger(1))
    return srcBool;
  Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
  Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
  return builder.createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
                                               zero);
}

/// Returns the `targetBits`-bit value shifted by the given `offset`, cast to
/// the destination type, and masked.
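///
/// A minimal sketch of the emitted ops, assuming `value` is i8, `mask` is
/// 0xFF : i32, and `offset` is 8 (illustrative IR, not verified output):
///
///   %ext = spirv.UConvert %value : i8 to i32
///   %and = spirv.BitwiseAnd %ext, %mask : i32
///   %res = spirv.ShiftLeftLogical %and, %offset : i32, i32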
static Value shiftValue(Location loc, Value value, Value offset, Value mask,
                        OpBuilder &builder) {
  IntegerType dstType = cast<IntegerType>(mask.getType());
  int targetBits = static_cast<int>(dstType.getWidth());
  int valueBits = value.getType().getIntOrFloatBitWidth();
  assert(valueBits <= targetBits);

  if (valueBits == 1) {
    value = castBoolToIntN(loc, value, dstType, builder);
  } else {
    if (valueBits < targetBits) {
      value = builder.create<spirv::UConvertOp>(
          loc, builder.getIntegerType(targetBits), value);
    }
  }
  value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask);

  return builder.createOrFold<spirv::ShiftLeftLogicalOp>(loc, value.getType(),
                                                         value, offset);
}

/// Returns true if the allocations of memref `type` generated from `allocOp`
/// can be lowered to SPIR-V.
static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
  if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
    auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
    if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
      return false;
  } else if (isa<memref::AllocaOp>(allocOp)) {
    auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
    if (!sc || sc.getValue() != spirv::StorageClass::Function)
      return false;
  } else {
    return false;
  }

  // Currently only supports static shapes, with an element type that is an
  // int, a float, or a vector of ints or floats.
  if (!type.hasStaticShape())
    return false;

  Type elementType = type.getElementType();
  if (auto vecType = dyn_cast<VectorType>(elementType))
    elementType = vecType.getElementType();
  return elementType.isIntOrFloat();
}

/// Returns the scope to use for atomic operations used for emulating store
/// operations of unsupported integer bitwidths, based on the memref
/// type. Returns std::nullopt on failure.
static std::optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
  auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
  if (!sc)
    return std::nullopt;
  switch (sc.getValue()) {
  case spirv::StorageClass::StorageBuffer:
    return spirv::Scope::Device;
  case spirv::StorageClass::Workgroup:
    return spirv::Scope::Workgroup;
  default:
    break;
  }
  return {};
}

/// Casts the given `srcInt` into a boolean value.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
  if (srcInt.getType().isInteger(1))
    return srcInt;

  auto zero = spirv::ConstantOp::getZero(srcInt.getType(), loc, builder);
  return builder.createOrFold<spirv::INotEqualOp>(loc, srcInt, zero);
}

//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//

// Note that DRR cannot be used for the patterns in this file: we may need to
// convert types along the way, which requires ConversionPattern. DRR generates
// normal RewritePatterns.

namespace {

/// Converts memref.alloca to SPIR-V Function variables.
class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> {
public:
  using OpConversionPattern<memref::AllocaOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts an allocation operation to SPIR-V. Currently only supports lowering
/// to Workgroup memory when the size is constant. Note that this pattern needs
/// to be applied in a pass that runs at least at spirv.module scope since it
/// will add global variables into the spirv.module.
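///
/// A rough sketch of the rewrite for a static workgroup allocation
/// (illustrative IR, not verified output):
///
///   %0 = memref.alloc() : memref<16xf32, #spirv.storage_class<Workgroup>>
///   // becomes, at spirv.module scope:
///   spirv.GlobalVariable @__workgroup_mem__0 : !spirv.ptr<..., Workgroup>
///   // and, at the use site:
///   %0 = spirv.mlir.addressof @__workgroup_mem__0 : !spirv.ptr<..., Workgroup>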
class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
public:
  using OpConversionPattern<memref::AllocOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.atomic_rmw operations to SPIR-V atomic operations.
class AtomicRMWOpPattern final
    : public OpConversionPattern<memref::AtomicRMWOp> {
public:
  using OpConversionPattern<memref::AtomicRMWOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Removes a deallocation if it is a supported allocation. Currently only
/// removes deallocations if the memory space is workgroup memory.
class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
public:
  using OpConversionPattern<memref::DeallocOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spirv.Load + spirv.AccessChain on integers.
class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using OpConversionPattern<memref::LoadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spirv.Load + spirv.AccessChain.
class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using OpConversionPattern<memref::LoadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spirv.Store on integers.
class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using OpConversionPattern<memref::StoreOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.memory_space_cast to the appropriate spirv cast operations.
class MemorySpaceCastOpPattern final
    : public OpConversionPattern<memref::MemorySpaceCastOp> {
public:
  using OpConversionPattern<memref::MemorySpaceCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spirv.Store.
class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using OpConversionPattern<memref::StoreOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class ReinterpretCastPattern final
    : public OpConversionPattern<memref::ReinterpretCastOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class CastPattern final : public OpConversionPattern<memref::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value src = adaptor.getSource();
    Type srcType = src.getType();

    const TypeConverter *converter = getTypeConverter();
    Type dstType = converter->convertType(op.getType());
    if (srcType != dstType)
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "types don't match: " << srcType << " and " << dstType;
      });

    rewriter.replaceOp(op, src);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// AllocaOp
//===----------------------------------------------------------------------===//

LogicalResult
AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  MemRefType allocType = allocaOp.getType();
  if (!isAllocationSupported(allocaOp, allocType))
    return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type");

  // Get the SPIR-V type for the allocation.
  Type spirvType = getTypeConverter()->convertType(allocType);
  if (!spirvType)
    return rewriter.notifyMatchFailure(allocaOp, "type conversion failed");

  rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
                                                 spirv::StorageClass::Function,
                                                 /*initializer=*/nullptr);
  return success();
}

//===----------------------------------------------------------------------===//
// AllocOp
//===----------------------------------------------------------------------===//

LogicalResult
AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  MemRefType allocType = operation.getType();
  if (!isAllocationSupported(operation, allocType))
    return rewriter.notifyMatchFailure(operation, "unhandled allocation type");

  // Get the SPIR-V type for the allocation.
  Type spirvType = getTypeConverter()->convertType(allocType);
  if (!spirvType)
    return rewriter.notifyMatchFailure(operation, "type conversion failed");

  // Insert spirv.GlobalVariable for this allocation.
  Operation *parent =
      SymbolTable::getNearestSymbolTable(operation->getParentOp());
  if (!parent)
    return failure();
  Location loc = operation.getLoc();
  spirv::GlobalVariableOp varOp;
  {
    OpBuilder::InsertionGuard guard(rewriter);
    Block &entryBlock = *parent->getRegion(0).begin();
    rewriter.setInsertionPointToStart(&entryBlock);
    auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
    std::string varName =
        std::string("__workgroup_mem__") +
        std::to_string(std::distance(varOps.begin(), varOps.end()));
    varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName,
                                                     /*initializer=*/nullptr);
  }

  // Get pointer to global variable at the current scope.
  rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
  return success();
}

//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//

LogicalResult
AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
                                    OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  if (isa<FloatType>(atomicOp.getType()))
    return rewriter.notifyMatchFailure(atomicOp,
                                       "unimplemented floating-point case");

  auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return rewriter.notifyMatchFailure(atomicOp,
                                       "unsupported memref memory space");

  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Type resultType = typeConverter.convertType(atomicOp.getType());
  if (!resultType)
    return rewriter.notifyMatchFailure(atomicOp,
                                       "failed to convert result type");

  auto loc = atomicOp.getLoc();
  Value ptr =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!ptr)
    return failure();

#define ATOMIC_CASE(kind, spirvOp)                                             \
  case arith::AtomicRMWKind::kind:                                             \
    rewriter.replaceOpWithNewOp<spirv::spirvOp>(                               \
        atomicOp, resultType, ptr, *scope,                                     \
        spirv::MemorySemantics::AcquireRelease, adaptor.getValue());           \
    break

  switch (atomicOp.getKind()) {
    ATOMIC_CASE(addi, AtomicIAddOp);
    ATOMIC_CASE(maxs, AtomicSMaxOp);
    ATOMIC_CASE(maxu, AtomicUMaxOp);
    ATOMIC_CASE(mins, AtomicSMinOp);
    ATOMIC_CASE(minu, AtomicUMinOp);
    ATOMIC_CASE(ori, AtomicOrOp);
    ATOMIC_CASE(andi, AtomicAndOp);
  default:
    return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
  }

#undef ATOMIC_CASE

  return success();
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult
DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
  if (!isAllocationSupported(operation, deallocType))
    return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
  rewriter.eraseOp(operation);
  return success();
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

struct MemoryRequirements {
  spirv::MemoryAccessAttr memoryAccess;
  IntegerAttr alignment;
};

/// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
/// any.
static FailureOr<MemoryRequirements>
calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
  MLIRContext *ctx = accessedPtr.getContext();

  auto memoryAccess = spirv::MemoryAccess::None;
  if (isNontemporal) {
    memoryAccess = spirv::MemoryAccess::Nontemporal;
  }

  auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
  if (ptrType.getStorageClass() !=
      spirv::StorageClass::PhysicalStorageBuffer) {
    if (memoryAccess == spirv::MemoryAccess::None) {
      return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}};
    }
    return MemoryRequirements{spirv::MemoryAccessAttr::get(ctx, memoryAccess),
                              IntegerAttr{}};
  }

  // PhysicalStorageBuffers require the `Aligned` attribute.
  auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
  if (!pointeeType)
    return failure();

  // For scalar types, the alignment is determined by their size.
  std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
  if (!sizeInBytes.has_value())
    return failure();

  memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
  auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
  auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes);
  return MemoryRequirements{memAccessAttr, alignment};
}

/// Given an accessed SPIR-V pointer and the original memref load/store
/// `loadOrStoreOp`, calculates the alignment requirements, if any. Takes into
/// account the alignment attributes applied to the load/store op.
template <class LoadOrStoreOp>
static FailureOr<MemoryRequirements>
calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
  static_assert(
      llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
      "Must be called on either memref::LoadOp or memref::StoreOp");

  Operation *memrefAccessOp = loadOrStoreOp.getOperation();
  auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
      spirv::attributeName<spirv::MemoryAccess>());
  auto memrefAlignment =
      memrefAccessOp->getAttrOfType<IntegerAttr>("alignment");
  if (memrefMemAccess && memrefAlignment)
    return MemoryRequirements{memrefMemAccess, memrefAlignment};

  return calculateMemoryRequirements(accessedPtr,
                                     loadOrStoreOp.getNontemporal());
}

LogicalResult
IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  auto loc = loadOp.getLoc();
  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
  if (!memrefType.getElementType().isSignlessInteger())
    return failure();

  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Value accessChain =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!accessChain)
    return failure();

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;

  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
  if (!pointerType)
    return rewriter.notifyMatchFailure(loadOp, "failed to convert memref type");

  Type pointeeType = pointerType.getPointeeType();
  Type dstType;
  if (typeConverter.allows(spirv::Capability::Kernel)) {
    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
      dstType = arrayType.getElementType();
    else
      dstType = pointeeType;
  } else {
    // For Vulkan we need to extract element from wrapping struct and array.
    Type structElemType =
        cast<spirv::StructType>(pointeeType).getElementType(0);
    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
      dstType = arrayType.getElementType();
    else
      dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
  }
  int dstBits = dstType.getIntOrFloatBitWidth();
  assert(dstBits % srcBits == 0);

  // If the rewritten load op has the same bit width, use the loaded value
  // directly.
  if (srcBits == dstBits) {
    auto memoryRequirements = calculateMemoryRequirements(accessChain, loadOp);
    if (failed(memoryRequirements))
      return rewriter.notifyMatchFailure(
          loadOp, "failed to determine memory requirements");

    auto [memoryAccess, alignment] = *memoryRequirements;
    Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain,
                                                   memoryAccess, alignment);
    if (isBool)
      loadVal = castIntNToBool(loc, loadVal, rewriter);
    rewriter.replaceOp(loadOp, loadVal);
    return success();
  }

  // Bitcasting is currently unsupported for Kernel capability /
  // spirv.PtrAccessChain.
  if (typeConverter.allows(spirv::Capability::Kernel))
    return failure();

  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
  if (!accessChainOp)
    return failure();

  // Assume that getElementPtr() linearizes the access. Even for a scalar, the
  // method still returns a linearized access. If the access is not linearized,
  // there will be offset issues.
  assert(accessChainOp.getIndices().size() == 2);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  auto memoryRequirements = calculateMemoryRequirements(adjustedPtr, loadOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        loadOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  Value spvLoadOp = rewriter.create<spirv::LoadOp>(loc, dstType, adjustedPtr,
                                                   memoryAccess, alignment);

  // Shift the bits to the rightmost.
  // ____XXXX________ -> ____________XXXX
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
  Value result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
      loc, spvLoadOp.getType(), spvLoadOp, offset);

  // Apply the mask to extract corresponding bits.
  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  result =
      rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);

  // Apply sign extension to the loaded value unconditionally. The signedness
  // semantics are carried by the operations themselves; we rely on other
  // patterns to handle the casting.
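  //
  // A worked instance, assuming srcBits = 8 and dstBits = 32, with the byte
  // already isolated in the low bits: shifting 0x000000FF left by 24 gives
  // 0xFF000000, and an arithmetic shift right by 24 then replicates the sign
  // bit, yielding 0xFFFFFFFF (i.e. i8 -1 sign-extended to i32).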
  IntegerAttr shiftValueAttr =
      rewriter.getIntegerAttr(dstType, dstBits - srcBits);
  Value shiftValue =
      rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
  result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
                                                            result, shiftValue);
  result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
      loc, dstType, result, shiftValue);

  rewriter.replaceOp(loadOp, result);

  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}

LogicalResult
LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
  if (memrefType.getElementType().isSignlessInteger())
    return failure();
  Value loadPtr = spirv::getElementPtr(
      *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
      adaptor.getIndices(), loadOp.getLoc(), rewriter);

  if (!loadPtr)
    return failure();

  auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        loadOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memoryAccess,
                                             alignment);
  return success();
}

LogicalResult
IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
  if (!memrefType.getElementType().isSignlessInteger())
    return rewriter.notifyMatchFailure(storeOp,
                                       "element type is not a signless int");

  auto loc = storeOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Value accessChain =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!accessChain)
    return rewriter.notifyMatchFailure(
        storeOp, "failed to convert element pointer type");

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();

  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;

  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
  if (!pointerType)
    return rewriter.notifyMatchFailure(storeOp,
                                       "failed to convert memref type");

  Type pointeeType = pointerType.getPointeeType();
  IntegerType dstType;
  if (typeConverter.allows(spirv::Capability::Kernel)) {
    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
    else
      dstType = dyn_cast<IntegerType>(pointeeType);
  } else {
    // For Vulkan we need to extract element from wrapping struct and array.
    Type structElemType =
        cast<spirv::StructType>(pointeeType).getElementType(0);
    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
    else
      dstType = dyn_cast<IntegerType>(
          cast<spirv::RuntimeArrayType>(structElemType).getElementType());
  }

  if (!dstType)
    return rewriter.notifyMatchFailure(
        storeOp, "failed to determine destination element type");

  int dstBits = static_cast<int>(dstType.getWidth());
  assert(dstBits % srcBits == 0);

  if (srcBits == dstBits) {
    auto memoryRequirements = calculateMemoryRequirements(accessChain, storeOp);
    if (failed(memoryRequirements))
      return rewriter.notifyMatchFailure(
          storeOp, "failed to determine memory requirements");

    auto [memoryAccess, alignment] = *memoryRequirements;
    Value storeVal = adaptor.getValue();
    if (isBool)
      storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
                                                memoryAccess, alignment);
    return success();
  }

  // Bitcasting is currently unsupported for Kernel capability /
  // spirv.PtrAccessChain.
  if (typeConverter.allows(spirv::Capability::Kernel))
    return failure();

  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
  if (!accessChainOp)
    return failure();

  // Since multiple threads may write to adjacent elements within the same
  // destination word concurrently, the emulation is done with atomic
  // operations. E.g., if the stored value is i8, rewrite the StoreOp to:
  // 1) load a 32-bit integer
  // 2) clear 8 bits in the loaded value
  // 3) set 8 bits in the loaded value
  // 4) store the 32-bit value back
  //
  // Step 2 is done with AtomicAnd, and step 3 is done with AtomicOr (of the
  // loaded 32-bit value and the shifted 8-bit store value) as another atomic
  // step.
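  //
  // A sketch of the resulting atomics, assuming the i8 value lands in byte 1
  // of a 32-bit word (illustrative IR, not verified output):
  //
  //   %old = spirv.AtomicAnd <Device> <AcquireRelease> %ptr, %c0xFFFF00FF
  //   %new = spirv.AtomicOr <Device> <AcquireRelease> %ptr, %shifted_val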
  assert(accessChainOp.getIndices().size() == 2);
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);

  // Create a mask to clear the destination. E.g., if it is the second i8 in
  // i32, 0xFFFF00FF is created.
  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
      loc, dstType, mask, offset);
  clearBitsMask =
      rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);

  Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return rewriter.notifyMatchFailure(storeOp, "atomic scope not available");

  Value result = rewriter.create<spirv::AtomicAndOp>(
      loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
      clearBitsMask);
  result = rewriter.create<spirv::AtomicOrOp>(
      loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
      storeVal);

  // The results of the atomic ops are unused. Since the atomics have already
  // been inserted, we can just erase the original StoreOp. Note that
  // rewriter.replaceOp() doesn't work here because it requires the numbers of
  // results to match.
  rewriter.eraseOp(storeOp);

  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}

//===----------------------------------------------------------------------===//
// MemorySpaceCastOp
//===----------------------------------------------------------------------===//

LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
    memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = addrCastOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  if (!typeConverter.allows(spirv::Capability::Kernel))
    return rewriter.notifyMatchFailure(
        loc, "address space casts require kernel capability");

  auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
  if (!sourceType)
    return rewriter.notifyMatchFailure(
        loc, "SPIR-V lowering requires ranked memref types");
  auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());

  auto sourceStorageClassAttr =
      dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
  if (!sourceStorageClassAttr)
    return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &diag) {
      diag << "source address space " << sourceType.getMemorySpace()
           << " must be a SPIR-V storage class";
    });
  auto resultStorageClassAttr =
      dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
  if (!resultStorageClassAttr)
    return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &diag) {
      diag << "result address space " << resultType.getMemorySpace()
           << " must be a SPIR-V storage class";
    });

  spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
  spirv::StorageClass resultSc = resultStorageClassAttr.getValue();

  Value result = adaptor.getSource();
  Type resultPtrType = typeConverter.convertType(resultType);
  if (!resultPtrType)
    return rewriter.notifyMatchFailure(addrCastOp,
                                       "failed to convert memref type");

  Type genericPtrType = resultPtrType;
  // SPIR-V doesn't have a general address space cast operation. Instead, it
  // has conversions to and from generic pointers. To implement the general
  // case, we use a specific-to-generic conversion when the source class is
  // not generic, and then, when the result storage class is not generic, we
  // convert the generic pointer (either the input or an intermediate result)
  // to that class. This also means that we'll need the intermediate generic
  // pointer type if neither the source nor the destination has it.
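  //
  // For example, a Workgroup-to-CrossWorkgroup cast becomes (a sketch with
  // illustrative IR, not verified output):
  //
  //   %generic = spirv.PtrCastToGeneric %src : ... to !spirv.ptr<..., Generic>
  //   %result = spirv.GenericCastToPtr %generic
  //       : ... to !spirv.ptr<..., CrossWorkgroup>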
  if (sourceSc != spirv::StorageClass::Generic &&
      resultSc != spirv::StorageClass::Generic) {
    Type intermediateType =
        MemRefType::get(sourceType.getShape(), sourceType.getElementType(),
                        sourceType.getLayout(),
                        rewriter.getAttr<spirv::StorageClassAttr>(
                            spirv::StorageClass::Generic));
    genericPtrType = typeConverter.convertType(intermediateType);
  }
  if (sourceSc != spirv::StorageClass::Generic) {
    result =
        rewriter.create<spirv::PtrCastToGenericOp>(loc, genericPtrType, result);
  }
  if (resultSc != spirv::StorageClass::Generic) {
    result =
        rewriter.create<spirv::GenericCastToPtrOp>(loc, resultPtrType, result);
  }
  rewriter.replaceOp(addrCastOp, result);
  return success();
}

LogicalResult
StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
  if (memrefType.getElementType().isSignlessInteger())
    return rewriter.notifyMatchFailure(storeOp, "signless int");
  auto storePtr = spirv::getElementPtr(
      *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
      adaptor.getIndices(), storeOp.getLoc(), rewriter);

  if (!storePtr)
    return rewriter.notifyMatchFailure(storeOp, "type conversion failed");

  auto memoryRequirements = calculateMemoryRequirements(storePtr, storeOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        storeOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  rewriter.replaceOpWithNewOp<spirv::StoreOp>(
      storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
  return success();
}

LogicalResult ReinterpretCastPattern::matchAndRewrite(
    memref::ReinterpretCastOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value src = adaptor.getSource();
  auto srcType = dyn_cast<spirv::PointerType>(src.getType());

  if (!srcType)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "invalid src type " << src.getType();
    });

  const TypeConverter *converter = getTypeConverter();

  auto dstType = converter->convertType<spirv::PointerType>(op.getType());
  if (dstType != srcType)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "invalid dst type " << op.getType();
    });

  OpFoldResult offset =
      getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
          .front();
  if (isConstantIntValue(offset, 0)) {
    rewriter.replaceOp(op, src);
    return success();
  }

  Type intType = converter->convertType(rewriter.getIndexType());
  if (!intType)
    return rewriter.notifyMatchFailure(op, "failed to convert index type");

  Location loc = op.getLoc();
  auto offsetValue = [&]() -> Value {
    if (auto val = dyn_cast<Value>(offset))
      return val;

    int64_t attrVal = cast<IntegerAttr>(offset.get<Attribute>()).getInt();
    Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
    return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
  }();

  rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
      op, src, offsetValue, std::nullopt);
  return success();
}

//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//

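// A minimal usage sketch for these patterns (a hypothetical pass body; the
// surrounding setup is an assumption, not something defined in this file):
//
//   SPIRVConversionOptions options;
//   SPIRVTypeConverter typeConverter(targetEnvAttr, options);
//   RewritePatternSet patterns(context);
//   populateMemRefToSPIRVPatterns(typeConverter, patterns);
//   std::unique_ptr<ConversionTarget> target =
//       SPIRVConversionTarget::get(targetEnvAttr);
//   if (failed(applyPartialConversion(module, *target, std::move(patterns))))
//     signalPassFailure();
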
namespace mlir {
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
                                   RewritePatternSet &patterns) {
  patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
               DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
               LoadOpPattern, MemorySpaceCastOpPattern, StoreOpPattern,
               ReinterpretCastPattern, CastPattern>(typeConverter,
                                                    patterns.getContext());
}
} // namespace mlir