Revert "[HLSL] Add `Increment`/`DecrementCounter` methods to structured buffers ...
[llvm-project.git] / mlir / lib / Conversion / SPIRVToLLVM / SPIRVToLLVM.cpp
blobb11511f21d03d4d1b2c3bf5255aa1143d3b81836
1 //===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert SPIR-V dialect to LLVM dialect.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h"
14 #include "mlir/Conversion/LLVMCommon/Pattern.h"
15 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
16 #include "mlir/Conversion/SPIRVCommon/AttrToLLVMConverter.h"
17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
19 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
20 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
21 #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/FormatVariadic.h"
29 #define DEBUG_TYPE "spirv-to-llvm-pattern"
31 using namespace mlir;
33 //===----------------------------------------------------------------------===//
34 // Utility functions
35 //===----------------------------------------------------------------------===//
37 /// Returns true if the given type is a signed integer or vector type.
38 static bool isSignedIntegerOrVector(Type type) {
39 if (type.isSignedInteger())
40 return true;
41 if (auto vecType = dyn_cast<VectorType>(type))
42 return vecType.getElementType().isSignedInteger();
43 return false;
46 /// Returns true if the given type is an unsigned integer or vector type
47 static bool isUnsignedIntegerOrVector(Type type) {
48 if (type.isUnsignedInteger())
49 return true;
50 if (auto vecType = dyn_cast<VectorType>(type))
51 return vecType.getElementType().isUnsignedInteger();
52 return false;
55 /// Returns the width of an integer or of the element type of an integer vector,
56 /// if applicable.
57 static std::optional<uint64_t> getIntegerOrVectorElementWidth(Type type) {
58 if (auto intType = dyn_cast<IntegerType>(type))
59 return intType.getWidth();
60 if (auto vecType = dyn_cast<VectorType>(type))
61 if (auto intType = dyn_cast<IntegerType>(vecType.getElementType()))
62 return intType.getWidth();
63 return std::nullopt;
66 /// Returns the bit width of integer, float or vector of float or integer values
67 static unsigned getBitWidth(Type type) {
68 assert((type.isIntOrFloat() || isa<VectorType>(type)) &&
69 "bitwidth is not supported for this type");
70 if (type.isIntOrFloat())
71 return type.getIntOrFloatBitWidth();
72 auto vecType = dyn_cast<VectorType>(type);
73 auto elementType = vecType.getElementType();
74 assert(elementType.isIntOrFloat() &&
75 "only integers and floats have a bitwidth");
76 return elementType.getIntOrFloatBitWidth();
79 /// Returns the bit width of LLVMType integer or vector.
80 static unsigned getLLVMTypeBitWidth(Type type) {
81 return cast<IntegerType>((LLVM::isCompatibleVectorType(type)
82 ? LLVM::getVectorElementType(type)
83 : type))
84 .getWidth();
87 /// Creates `IntegerAttribute` with all bits set for given type
88 static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
89 if (auto vecType = dyn_cast<VectorType>(type)) {
90 auto integerType = cast<IntegerType>(vecType.getElementType());
91 return builder.getIntegerAttr(integerType, -1);
93 auto integerType = cast<IntegerType>(type);
94 return builder.getIntegerAttr(integerType, -1);
97 /// Creates `llvm.mlir.constant` with all bits set for the given type.
98 static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
99 PatternRewriter &rewriter) {
100 if (isa<VectorType>(srcType)) {
101 return rewriter.create<LLVM::ConstantOp>(
102 loc, dstType,
103 SplatElementsAttr::get(cast<ShapedType>(srcType),
104 minusOneIntegerAttribute(srcType, rewriter)));
106 return rewriter.create<LLVM::ConstantOp>(
107 loc, dstType, minusOneIntegerAttribute(srcType, rewriter));
110 /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
111 static Value createFPConstant(Location loc, Type srcType, Type dstType,
112 PatternRewriter &rewriter, double value) {
113 if (auto vecType = dyn_cast<VectorType>(srcType)) {
114 auto floatType = cast<FloatType>(vecType.getElementType());
115 return rewriter.create<LLVM::ConstantOp>(
116 loc, dstType,
117 SplatElementsAttr::get(vecType,
118 rewriter.getFloatAttr(floatType, value)));
120 auto floatType = cast<FloatType>(srcType);
121 return rewriter.create<LLVM::ConstantOp>(
122 loc, dstType, rewriter.getFloatAttr(floatType, value));
125 /// Utility function for bitfield ops:
126 /// - `BitFieldInsert`
127 /// - `BitFieldSExtract`
128 /// - `BitFieldUExtract`
129 /// Truncates or extends the value. If the bitwidth of the value is the same as
130 /// `llvmType` bitwidth, the value remains unchanged.
131 static Value optionallyTruncateOrExtend(Location loc, Value value,
132 Type llvmType,
133 PatternRewriter &rewriter) {
134 auto srcType = value.getType();
135 unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType);
136 unsigned valueBitWidth = LLVM::isCompatibleType(srcType)
137 ? getLLVMTypeBitWidth(srcType)
138 : getBitWidth(srcType);
140 if (valueBitWidth < targetBitWidth)
141 return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value);
142 // If the bit widths of `Count` and `Offset` are greater than the bit width
143 // of the target type, they are truncated. Truncation is safe since `Count`
144 // and `Offset` must be no more than 64 for op behaviour to be defined. Hence,
145 // both values can be expressed in 8 bits.
146 if (valueBitWidth > targetBitWidth)
147 return rewriter.create<LLVM::TruncOp>(loc, llvmType, value);
148 return value;
151 /// Broadcasts the value to vector with `numElements` number of elements.
152 static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
153 const TypeConverter &typeConverter,
154 ConversionPatternRewriter &rewriter) {
155 auto vectorType = VectorType::get(numElements, toBroadcast.getType());
156 auto llvmVectorType = typeConverter.convertType(vectorType);
157 auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
158 Value broadcasted = rewriter.create<LLVM::UndefOp>(loc, llvmVectorType);
159 for (unsigned i = 0; i < numElements; ++i) {
160 auto index = rewriter.create<LLVM::ConstantOp>(
161 loc, llvmI32Type, rewriter.getI32IntegerAttr(i));
162 broadcasted = rewriter.create<LLVM::InsertElementOp>(
163 loc, llvmVectorType, broadcasted, toBroadcast, index);
165 return broadcasted;
168 /// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged.
169 static Value optionallyBroadcast(Location loc, Value value, Type srcType,
170 const TypeConverter &typeConverter,
171 ConversionPatternRewriter &rewriter) {
172 if (auto vectorType = dyn_cast<VectorType>(srcType)) {
173 unsigned numElements = vectorType.getNumElements();
174 return broadcast(loc, value, numElements, typeConverter, rewriter);
176 return value;
179 /// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and
180 /// `BitFieldUExtract`.
181 /// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of
182 /// a vector type, construct a vector that has:
183 /// - same number of elements as `Base`
184 /// - each element has the type that is the same as the type of `Offset` or
185 /// `Count`
186 /// - each element has the same value as `Offset` or `Count`
187 /// Then cast `Offset` and `Count` if their bit width is different
188 /// from `Base` bit width.
189 static Value processCountOrOffset(Location loc, Value value, Type srcType,
190 Type dstType, const TypeConverter &converter,
191 ConversionPatternRewriter &rewriter) {
192 Value broadcasted =
193 optionallyBroadcast(loc, value, srcType, converter, rewriter);
194 return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
197 /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
198 /// offset to LLVM struct. Otherwise, the conversion is not supported.
199 static Type convertStructTypeWithOffset(spirv::StructType type,
200 const TypeConverter &converter) {
201 if (type != VulkanLayoutUtils::decorateType(type))
202 return nullptr;
204 SmallVector<Type> elementsVector;
205 if (failed(converter.convertTypes(type.getElementTypes(), elementsVector)))
206 return nullptr;
207 return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
208 /*isPacked=*/false);
211 /// Converts SPIR-V struct with no offset to packed LLVM struct.
212 static Type convertStructTypePacked(spirv::StructType type,
213 const TypeConverter &converter) {
214 SmallVector<Type> elementsVector;
215 if (failed(converter.convertTypes(type.getElementTypes(), elementsVector)))
216 return nullptr;
217 return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
218 /*isPacked=*/true);
221 /// Creates LLVM dialect constant with the given value.
222 static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter,
223 unsigned value) {
224 return rewriter.create<LLVM::ConstantOp>(
225 loc, IntegerType::get(rewriter.getContext(), 32),
226 rewriter.getIntegerAttr(rewriter.getI32Type(), value));
229 /// Utility for `spirv.Load` and `spirv.Store` conversion.
230 static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands,
231 ConversionPatternRewriter &rewriter,
232 const TypeConverter &typeConverter,
233 unsigned alignment, bool isVolatile,
234 bool isNonTemporal) {
235 if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
236 auto dstType = typeConverter.convertType(loadOp.getType());
237 if (!dstType)
238 return rewriter.notifyMatchFailure(op, "type conversion failed");
239 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
240 loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
241 isVolatile, isNonTemporal);
242 return success();
244 auto storeOp = cast<spirv::StoreOp>(op);
245 spirv::StoreOpAdaptor adaptor(operands);
246 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValue(),
247 adaptor.getPtr(), alignment,
248 isVolatile, isNonTemporal);
249 return success();
252 //===----------------------------------------------------------------------===//
253 // Type conversion
254 //===----------------------------------------------------------------------===//
256 /// Converts SPIR-V array type to LLVM array. Natural stride (according to
257 /// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected
258 /// when converting ops that manipulate array types.
259 static std::optional<Type> convertArrayType(spirv::ArrayType type,
260 TypeConverter &converter) {
261 unsigned stride = type.getArrayStride();
262 Type elementType = type.getElementType();
263 auto sizeInBytes = cast<spirv::SPIRVType>(elementType).getSizeInBytes();
264 if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride))
265 return std::nullopt;
267 auto llvmElementType = converter.convertType(elementType);
268 unsigned numElements = type.getNumElements();
269 return LLVM::LLVMArrayType::get(llvmElementType, numElements);
272 /// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
273 /// modelled at the moment.
274 static Type convertPointerType(spirv::PointerType type,
275 const TypeConverter &converter,
276 spirv::ClientAPI clientAPI) {
277 unsigned addressSpace =
278 storageClassToAddressSpace(clientAPI, type.getStorageClass());
279 return LLVM::LLVMPointerType::get(type.getContext(), addressSpace);
282 /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
283 /// the bounds, the runtime array is converted to a 0-sized LLVM array. There is
284 /// no modelling of array stride at the moment.
285 static std::optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
286 TypeConverter &converter) {
287 if (type.getArrayStride() != 0)
288 return std::nullopt;
289 auto elementType = converter.convertType(type.getElementType());
290 return LLVM::LLVMArrayType::get(elementType, 0);
293 /// Converts SPIR-V struct to LLVM struct. There is no support of structs with
294 /// member decorations. Also, only natural offset is supported.
295 static Type convertStructType(spirv::StructType type,
296 const TypeConverter &converter) {
297 SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
298 type.getMemberDecorations(memberDecorations);
299 if (!memberDecorations.empty())
300 return nullptr;
301 if (type.hasOffset())
302 return convertStructTypeWithOffset(type, converter);
303 return convertStructTypePacked(type, converter);
306 //===----------------------------------------------------------------------===//
307 // Operation conversion
308 //===----------------------------------------------------------------------===//
310 namespace {
312 class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
313 public:
314 using SPIRVToLLVMConversion<spirv::AccessChainOp>::SPIRVToLLVMConversion;
316 LogicalResult
317 matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
318 ConversionPatternRewriter &rewriter) const override {
319 auto dstType =
320 getTypeConverter()->convertType(op.getComponentPtr().getType());
321 if (!dstType)
322 return rewriter.notifyMatchFailure(op, "type conversion failed");
323 // To use GEP we need to add a first 0 index to go through the pointer.
324 auto indices = llvm::to_vector<4>(adaptor.getIndices());
325 Type indexType = op.getIndices().front().getType();
326 auto llvmIndexType = getTypeConverter()->convertType(indexType);
327 if (!llvmIndexType)
328 return rewriter.notifyMatchFailure(op, "type conversion failed");
329 Value zero = rewriter.create<LLVM::ConstantOp>(
330 op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
331 indices.insert(indices.begin(), zero);
333 auto elementType = getTypeConverter()->convertType(
334 cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType());
335 if (!elementType)
336 return rewriter.notifyMatchFailure(op, "type conversion failed");
337 rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, elementType,
338 adaptor.getBasePtr(), indices);
339 return success();
343 class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
344 public:
345 using SPIRVToLLVMConversion<spirv::AddressOfOp>::SPIRVToLLVMConversion;
347 LogicalResult
348 matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
349 ConversionPatternRewriter &rewriter) const override {
350 auto dstType = getTypeConverter()->convertType(op.getPointer().getType());
351 if (!dstType)
352 return rewriter.notifyMatchFailure(op, "type conversion failed");
353 rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType,
354 op.getVariable());
355 return success();
359 class BitFieldInsertPattern
360 : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> {
361 public:
362 using SPIRVToLLVMConversion<spirv::BitFieldInsertOp>::SPIRVToLLVMConversion;
364 LogicalResult
365 matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
366 ConversionPatternRewriter &rewriter) const override {
367 auto srcType = op.getType();
368 auto dstType = getTypeConverter()->convertType(srcType);
369 if (!dstType)
370 return rewriter.notifyMatchFailure(op, "type conversion failed");
371 Location loc = op.getLoc();
373 // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
374 Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
375 *getTypeConverter(), rewriter);
376 Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
377 *getTypeConverter(), rewriter);
379 // Create a mask with bits set outside [Offset, Offset + Count - 1].
380 Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
381 Value maskShiftedByCount =
382 rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
383 Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType,
384 maskShiftedByCount, minusOne);
385 Value maskShiftedByCountAndOffset =
386 rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset);
387 Value mask = rewriter.create<LLVM::XOrOp>(
388 loc, dstType, maskShiftedByCountAndOffset, minusOne);
390 // Extract unchanged bits from the `Base` that are outside of
391 // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
392 Value baseAndMask =
393 rewriter.create<LLVM::AndOp>(loc, dstType, op.getBase(), mask);
394 Value insertShiftedByOffset =
395 rewriter.create<LLVM::ShlOp>(loc, dstType, op.getInsert(), offset);
396 rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
397 insertShiftedByOffset);
398 return success();
402 /// Converts SPIR-V ConstantOp with scalar or vector type.
403 class ConstantScalarAndVectorPattern
404 : public SPIRVToLLVMConversion<spirv::ConstantOp> {
405 public:
406 using SPIRVToLLVMConversion<spirv::ConstantOp>::SPIRVToLLVMConversion;
408 LogicalResult
409 matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
410 ConversionPatternRewriter &rewriter) const override {
411 auto srcType = constOp.getType();
412 if (!isa<VectorType>(srcType) && !srcType.isIntOrFloat())
413 return failure();
415 auto dstType = getTypeConverter()->convertType(srcType);
416 if (!dstType)
417 return rewriter.notifyMatchFailure(constOp, "type conversion failed");
419 // SPIR-V constant can be a signed/unsigned integer, which has to be
420 // casted to signless integer when converting to LLVM dialect. Removing the
421 // sign bit may have unexpected behaviour. However, it is better to handle
422 // it case-by-case, given that the purpose of the conversion is not to
423 // cover all possible corner cases.
424 if (isSignedIntegerOrVector(srcType) ||
425 isUnsignedIntegerOrVector(srcType)) {
426 auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
428 if (isa<VectorType>(srcType)) {
429 auto dstElementsAttr = cast<DenseIntElementsAttr>(constOp.getValue());
430 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
431 constOp, dstType,
432 dstElementsAttr.mapValues(
433 signlessType, [&](const APInt &value) { return value; }));
434 return success();
436 auto srcAttr = cast<IntegerAttr>(constOp.getValue());
437 auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
438 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
439 return success();
441 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
442 constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
443 return success();
447 class BitFieldSExtractPattern
448 : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> {
449 public:
450 using SPIRVToLLVMConversion<spirv::BitFieldSExtractOp>::SPIRVToLLVMConversion;
452 LogicalResult
453 matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
454 ConversionPatternRewriter &rewriter) const override {
455 auto srcType = op.getType();
456 auto dstType = getTypeConverter()->convertType(srcType);
457 if (!dstType)
458 return rewriter.notifyMatchFailure(op, "type conversion failed");
459 Location loc = op.getLoc();
461 // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
462 Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
463 *getTypeConverter(), rewriter);
464 Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
465 *getTypeConverter(), rewriter);
467 // Create a constant that holds the size of the `Base`.
468 IntegerType integerType;
469 if (auto vecType = dyn_cast<VectorType>(srcType))
470 integerType = cast<IntegerType>(vecType.getElementType());
471 else
472 integerType = cast<IntegerType>(srcType);
474 auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType));
475 Value size =
476 isa<VectorType>(srcType)
477 ? rewriter.create<LLVM::ConstantOp>(
478 loc, dstType,
479 SplatElementsAttr::get(cast<ShapedType>(srcType), baseSize))
480 : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize);
482 // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit
483 // at Offset + Count - 1 is the most significant bit now.
484 Value countPlusOffset =
485 rewriter.create<LLVM::AddOp>(loc, dstType, count, offset);
486 Value amountToShiftLeft =
487 rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
488 Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>(
489 loc, dstType, op.getBase(), amountToShiftLeft);
491 // Shift the result right, filling the bits with the sign bit.
492 Value amountToShiftRight =
493 rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
494 rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft,
495 amountToShiftRight);
496 return success();
500 class BitFieldUExtractPattern
501 : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> {
502 public:
503 using SPIRVToLLVMConversion<spirv::BitFieldUExtractOp>::SPIRVToLLVMConversion;
505 LogicalResult
506 matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
507 ConversionPatternRewriter &rewriter) const override {
508 auto srcType = op.getType();
509 auto dstType = getTypeConverter()->convertType(srcType);
510 if (!dstType)
511 return rewriter.notifyMatchFailure(op, "type conversion failed");
512 Location loc = op.getLoc();
514 // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
515 Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
516 *getTypeConverter(), rewriter);
517 Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
518 *getTypeConverter(), rewriter);
520 // Create a mask with bits set at [0, Count - 1].
521 Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
522 Value maskShiftedByCount =
523 rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
524 Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
525 minusOne);
527 // Shift `Base` by `Offset` and apply the mask on it.
528 Value shiftedBase =
529 rewriter.create<LLVM::LShrOp>(loc, dstType, op.getBase(), offset);
530 rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
531 return success();
535 class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> {
536 public:
537 using SPIRVToLLVMConversion<spirv::BranchOp>::SPIRVToLLVMConversion;
539 LogicalResult
540 matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
541 ConversionPatternRewriter &rewriter) const override {
542 rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, adaptor.getOperands(),
543 branchOp.getTarget());
544 return success();
548 class BranchConditionalConversionPattern
549 : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> {
550 public:
551 using SPIRVToLLVMConversion<
552 spirv::BranchConditionalOp>::SPIRVToLLVMConversion;
554 LogicalResult
555 matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
556 ConversionPatternRewriter &rewriter) const override {
557 // If branch weights exist, map them to 32-bit integer vector.
558 DenseI32ArrayAttr branchWeights = nullptr;
559 if (auto weights = op.getBranchWeights()) {
560 SmallVector<int32_t> weightValues;
561 for (auto weight : weights->getAsRange<IntegerAttr>())
562 weightValues.push_back(weight.getInt());
563 branchWeights = DenseI32ArrayAttr::get(getContext(), weightValues);
566 rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
567 op, op.getCondition(), op.getTrueBlockArguments(),
568 op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
569 op.getFalseBlock());
570 return success();
574 /// Converts `spirv.getCompositeExtract` to `llvm.extractvalue` if the container
575 /// type is an aggregate type (struct or array). Otherwise, converts to
576 /// `llvm.extractelement` that operates on vectors.
577 class CompositeExtractPattern
578 : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> {
579 public:
580 using SPIRVToLLVMConversion<spirv::CompositeExtractOp>::SPIRVToLLVMConversion;
582 LogicalResult
583 matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
584 ConversionPatternRewriter &rewriter) const override {
585 auto dstType = this->getTypeConverter()->convertType(op.getType());
586 if (!dstType)
587 return rewriter.notifyMatchFailure(op, "type conversion failed");
589 Type containerType = op.getComposite().getType();
590 if (isa<VectorType>(containerType)) {
591 Location loc = op.getLoc();
592 IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
593 Value index = createI32ConstantOf(loc, rewriter, value.getInt());
594 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
595 op, dstType, adaptor.getComposite(), index);
596 return success();
599 rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
600 op, adaptor.getComposite(),
601 LLVM::convertArrayToIndices(op.getIndices()));
602 return success();
606 /// Converts `spirv.getCompositeInsert` to `llvm.insertvalue` if the container
607 /// type is an aggregate type (struct or array). Otherwise, converts to
608 /// `llvm.insertelement` that operates on vectors.
609 class CompositeInsertPattern
610 : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> {
611 public:
612 using SPIRVToLLVMConversion<spirv::CompositeInsertOp>::SPIRVToLLVMConversion;
614 LogicalResult
615 matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
616 ConversionPatternRewriter &rewriter) const override {
617 auto dstType = this->getTypeConverter()->convertType(op.getType());
618 if (!dstType)
619 return rewriter.notifyMatchFailure(op, "type conversion failed");
621 Type containerType = op.getComposite().getType();
622 if (isa<VectorType>(containerType)) {
623 Location loc = op.getLoc();
624 IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
625 Value index = createI32ConstantOf(loc, rewriter, value.getInt());
626 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
627 op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
628 return success();
631 rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
632 op, adaptor.getComposite(), adaptor.getObject(),
633 LLVM::convertArrayToIndices(op.getIndices()));
634 return success();
638 /// Converts SPIR-V operations that have straightforward LLVM equivalent
639 /// into LLVM dialect operations.
640 template <typename SPIRVOp, typename LLVMOp>
641 class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
642 public:
643 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
645 LogicalResult
646 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
647 ConversionPatternRewriter &rewriter) const override {
648 auto dstType = this->getTypeConverter()->convertType(op.getType());
649 if (!dstType)
650 return rewriter.notifyMatchFailure(op, "type conversion failed");
651 rewriter.template replaceOpWithNewOp<LLVMOp>(
652 op, dstType, adaptor.getOperands(), op->getAttrs());
653 return success();
657 /// Converts `spirv.ExecutionMode` into a global struct constant that holds
658 /// execution mode information.
659 class ExecutionModePattern
660 : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> {
661 public:
662 using SPIRVToLLVMConversion<spirv::ExecutionModeOp>::SPIRVToLLVMConversion;
664 LogicalResult
665 matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
666 ConversionPatternRewriter &rewriter) const override {
667 // First, create the global struct's name that would be associated with
668 // this entry point's execution mode. We set it to be:
669 // __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode}
670 ModuleOp module = op->getParentOfType<ModuleOp>();
671 spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
672 std::string moduleName;
673 if (module.getName().has_value())
674 moduleName = "_" + module.getName()->str();
675 else
676 moduleName = "";
677 std::string executionModeInfoName = llvm::formatv(
678 "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),
679 static_cast<uint32_t>(executionModeAttr.getValue()));
681 MLIRContext *context = rewriter.getContext();
682 OpBuilder::InsertionGuard guard(rewriter);
683 rewriter.setInsertionPointToStart(module.getBody());
685 // Create a struct type, corresponding to the C struct below.
686 // struct {
687 // int32_t executionMode;
688 // int32_t values[]; // optional values
689 // };
690 auto llvmI32Type = IntegerType::get(context, 32);
691 SmallVector<Type, 2> fields;
692 fields.push_back(llvmI32Type);
693 ArrayAttr values = op.getValues();
694 if (!values.empty()) {
695 auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size());
696 fields.push_back(arrayType);
698 auto structType = LLVM::LLVMStructType::getLiteral(context, fields);
700 // Create `llvm.mlir.global` with initializer region containing one block.
701 auto global = rewriter.create<LLVM::GlobalOp>(
702 UnknownLoc::get(context), structType, /*isConstant=*/true,
703 LLVM::Linkage::External, executionModeInfoName, Attribute(),
704 /*alignment=*/0);
705 Location loc = global.getLoc();
706 Region &region = global.getInitializerRegion();
707 Block *block = rewriter.createBlock(&region);
709 // Initialize the struct and set the execution mode value.
710 rewriter.setInsertionPointToStart(block);
711 Value structValue = rewriter.create<LLVM::UndefOp>(loc, structType);
712 Value executionMode = rewriter.create<LLVM::ConstantOp>(
713 loc, llvmI32Type,
714 rewriter.getI32IntegerAttr(
715 static_cast<uint32_t>(executionModeAttr.getValue())));
716 structValue = rewriter.create<LLVM::InsertValueOp>(loc, structValue,
717 executionMode, 0);
719 // Insert extra operands if they exist into execution mode info struct.
720 for (unsigned i = 0, e = values.size(); i < e; ++i) {
721 auto attr = values.getValue()[i];
722 Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
723 structValue = rewriter.create<LLVM::InsertValueOp>(
724 loc, structValue, entry, ArrayRef<int64_t>({1, i}));
726 rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue}));
727 rewriter.eraseOp(op);
728 return success();
732 /// Converts `spirv.GlobalVariable` to `llvm.mlir.global`. Note that SPIR-V
733 /// global returns a pointer, whereas in LLVM dialect the global holds an actual
734 /// value. This difference is handled by `spirv.mlir.addressof` and
735 /// `llvm.mlir.addressof`ops that both return a pointer.
736 class GlobalVariablePattern
737 : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> {
738 public:
739 template <typename... Args>
740 GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args)
741 : SPIRVToLLVMConversion<spirv::GlobalVariableOp>(
742 std::forward<Args>(args)...),
743 clientAPI(clientAPI) {}
745 LogicalResult
746 matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
747 ConversionPatternRewriter &rewriter) const override {
748 // Currently, there is no support of initialization with a constant value in
749 // SPIR-V dialect. Specialization constants are not considered as well.
750 if (op.getInitializer())
751 return failure();
753 auto srcType = cast<spirv::PointerType>(op.getType());
754 auto dstType = getTypeConverter()->convertType(srcType.getPointeeType());
755 if (!dstType)
756 return rewriter.notifyMatchFailure(op, "type conversion failed");
758 // Limit conversion to the current invocation only or `StorageBuffer`
759 // required by SPIR-V runner.
760 // This is okay because multiple invocations are not supported yet.
761 auto storageClass = srcType.getStorageClass();
762 switch (storageClass) {
763 case spirv::StorageClass::Input:
764 case spirv::StorageClass::Private:
765 case spirv::StorageClass::Output:
766 case spirv::StorageClass::StorageBuffer:
767 case spirv::StorageClass::UniformConstant:
768 break;
769 default:
770 return failure();
773 // LLVM dialect spec: "If the global value is a constant, storing into it is
774 // not allowed.". This corresponds to SPIR-V 'Input' and 'UniformConstant'
775 // storage class that is read-only.
776 bool isConstant = (storageClass == spirv::StorageClass::Input) ||
777 (storageClass == spirv::StorageClass::UniformConstant);
778 // SPIR-V spec: "By default, functions and global variables are private to a
779 // module and cannot be accessed by other modules. However, a module may be
780 // written to export or import functions and global (module scope)
781 // variables.". Therefore, map 'Private' storage class to private linkage,
782 // 'Input' and 'Output' to external linkage.
783 auto linkage = storageClass == spirv::StorageClass::Private
784 ? LLVM::Linkage::Private
785 : LLVM::Linkage::External;
786 auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
787 op, dstType, isConstant, linkage, op.getSymName(), Attribute(),
788 /*alignment=*/0, storageClassToAddressSpace(clientAPI, storageClass));
790 // Attach location attribute if applicable
791 if (op.getLocationAttr())
792 newGlobalOp->setAttr(op.getLocationAttrName(), op.getLocationAttr());
794 return success();
797 private:
798 spirv::ClientAPI clientAPI;
801 /// Converts SPIR-V cast ops that do not have straightforward LLVM
802 /// equivalent in LLVM dialect.
803 template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp>
804 class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
805 public:
806 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
808 LogicalResult
809 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
810 ConversionPatternRewriter &rewriter) const override {
812 Type fromType = op.getOperand().getType();
813 Type toType = op.getType();
815 auto dstType = this->getTypeConverter()->convertType(toType);
816 if (!dstType)
817 return rewriter.notifyMatchFailure(op, "type conversion failed");
819 if (getBitWidth(fromType) < getBitWidth(toType)) {
820 rewriter.template replaceOpWithNewOp<LLVMExtOp>(op, dstType,
821 adaptor.getOperands());
822 return success();
824 if (getBitWidth(fromType) > getBitWidth(toType)) {
825 rewriter.template replaceOpWithNewOp<LLVMTruncOp>(op, dstType,
826 adaptor.getOperands());
827 return success();
829 return failure();
833 class FunctionCallPattern
834 : public SPIRVToLLVMConversion<spirv::FunctionCallOp> {
835 public:
836 using SPIRVToLLVMConversion<spirv::FunctionCallOp>::SPIRVToLLVMConversion;
838 LogicalResult
839 matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
840 ConversionPatternRewriter &rewriter) const override {
841 if (callOp.getNumResults() == 0) {
842 auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
843 callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs());
844 newOp.getProperties().operandSegmentSizes = {
845 static_cast<int32_t>(adaptor.getOperands().size()), 0};
846 newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
847 return success();
850 // Function returns a single result.
851 auto dstType = getTypeConverter()->convertType(callOp.getType(0));
852 if (!dstType)
853 return rewriter.notifyMatchFailure(callOp, "type conversion failed");
854 auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
855 callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
856 newOp.getProperties().operandSegmentSizes = {
857 static_cast<int32_t>(adaptor.getOperands().size()), 0};
858 newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
859 return success();
863 /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate"
864 template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
865 class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
866 public:
867 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
869 LogicalResult
870 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
871 ConversionPatternRewriter &rewriter) const override {
873 auto dstType = this->getTypeConverter()->convertType(op.getType());
874 if (!dstType)
875 return rewriter.notifyMatchFailure(op, "type conversion failed");
877 rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
878 op, dstType, predicate, op.getOperand1(), op.getOperand2());
879 return success();
883 /// Converts SPIR-V integer comparisons to llvm.icmp "predicate"
884 template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
885 class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
886 public:
887 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
889 LogicalResult
890 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
891 ConversionPatternRewriter &rewriter) const override {
893 auto dstType = this->getTypeConverter()->convertType(op.getType());
894 if (!dstType)
895 return rewriter.notifyMatchFailure(op, "type conversion failed");
897 rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
898 op, dstType, predicate, op.getOperand1(), op.getOperand2());
899 return success();
903 class InverseSqrtPattern
904 : public SPIRVToLLVMConversion<spirv::GLInverseSqrtOp> {
905 public:
906 using SPIRVToLLVMConversion<spirv::GLInverseSqrtOp>::SPIRVToLLVMConversion;
908 LogicalResult
909 matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,
910 ConversionPatternRewriter &rewriter) const override {
911 auto srcType = op.getType();
912 auto dstType = getTypeConverter()->convertType(srcType);
913 if (!dstType)
914 return rewriter.notifyMatchFailure(op, "type conversion failed");
916 Location loc = op.getLoc();
917 Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
918 Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.getOperand());
919 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
920 return success();
924 /// Converts `spirv.Load` and `spirv.Store` to LLVM dialect.
925 template <typename SPIRVOp>
926 class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> {
927 public:
928 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
930 LogicalResult
931 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
932 ConversionPatternRewriter &rewriter) const override {
933 if (!op.getMemoryAccess()) {
934 return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
935 *this->getTypeConverter(), /*alignment=*/0,
936 /*isVolatile=*/false,
937 /*isNonTemporal=*/false);
939 auto memoryAccess = *op.getMemoryAccess();
940 switch (memoryAccess) {
941 case spirv::MemoryAccess::Aligned:
942 case spirv::MemoryAccess::None:
943 case spirv::MemoryAccess::Nontemporal:
944 case spirv::MemoryAccess::Volatile: {
945 unsigned alignment =
946 memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;
947 bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
948 bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
949 return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
950 *this->getTypeConverter(), alignment,
951 isVolatile, isNonTemporal);
953 default:
954 // There is no support of other memory access attributes.
955 return failure();
960 /// Converts `spirv.Not` and `spirv.LogicalNot` into LLVM dialect.
961 template <typename SPIRVOp>
962 class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
963 public:
964 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
966 LogicalResult
967 matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor,
968 ConversionPatternRewriter &rewriter) const override {
969 auto srcType = notOp.getType();
970 auto dstType = this->getTypeConverter()->convertType(srcType);
971 if (!dstType)
972 return rewriter.notifyMatchFailure(notOp, "type conversion failed");
974 Location loc = notOp.getLoc();
975 IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
976 auto mask =
977 isa<VectorType>(srcType)
978 ? rewriter.create<LLVM::ConstantOp>(
979 loc, dstType,
980 SplatElementsAttr::get(cast<VectorType>(srcType), minusOne))
981 : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
982 rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
983 notOp.getOperand(), mask);
984 return success();
988 /// A template pattern that erases the given `SPIRVOp`.
989 template <typename SPIRVOp>
990 class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> {
991 public:
992 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
994 LogicalResult
995 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
996 ConversionPatternRewriter &rewriter) const override {
997 rewriter.eraseOp(op);
998 return success();
1002 class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
1003 public:
1004 using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion;
1006 LogicalResult
1007 matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
1008 ConversionPatternRewriter &rewriter) const override {
1009 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(),
1010 ArrayRef<Value>());
1011 return success();
1015 class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
1016 public:
1017 using SPIRVToLLVMConversion<spirv::ReturnValueOp>::SPIRVToLLVMConversion;
1019 LogicalResult
1020 matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
1021 ConversionPatternRewriter &rewriter) const override {
1022 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(),
1023 adaptor.getOperands());
1024 return success();
1028 static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
1029 StringRef name,
1030 ArrayRef<Type> paramTypes,
1031 Type resultType,
1032 bool convergent = true) {
1033 auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
1034 SymbolTable::lookupSymbolIn(symbolTable, name));
1035 if (func)
1036 return func;
1038 OpBuilder b(symbolTable->getRegion(0));
1039 func = b.create<LLVM::LLVMFuncOp>(
1040 symbolTable->getLoc(), name,
1041 LLVM::LLVMFunctionType::get(resultType, paramTypes));
1042 func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
1043 func.setConvergent(convergent);
1044 func.setNoUnwind(true);
1045 func.setWillReturn(true);
1046 return func;
1049 static LLVM::CallOp createSPIRVBuiltinCall(Location loc, OpBuilder &builder,
1050 LLVM::LLVMFuncOp func,
1051 ValueRange args) {
1052 auto call = builder.create<LLVM::CallOp>(loc, func, args);
1053 call.setCConv(func.getCConv());
1054 call.setConvergentAttr(func.getConvergentAttr());
1055 call.setNoUnwindAttr(func.getNoUnwindAttr());
1056 call.setWillReturnAttr(func.getWillReturnAttr());
1057 return call;
1060 template <typename BarrierOpTy>
1061 class ControlBarrierPattern : public SPIRVToLLVMConversion<BarrierOpTy> {
1062 public:
1063 using OpAdaptor = typename SPIRVToLLVMConversion<BarrierOpTy>::OpAdaptor;
1065 using SPIRVToLLVMConversion<BarrierOpTy>::SPIRVToLLVMConversion;
1067 static constexpr StringRef getFuncName();
1069 LogicalResult
1070 matchAndRewrite(BarrierOpTy controlBarrierOp, OpAdaptor adaptor,
1071 ConversionPatternRewriter &rewriter) const override {
1072 constexpr StringRef funcName = getFuncName();
1073 Operation *symbolTable =
1074 controlBarrierOp->template getParentWithTrait<OpTrait::SymbolTable>();
1076 Type i32 = rewriter.getI32Type();
1078 Type voidTy = rewriter.getType<LLVM::LLVMVoidType>();
1079 LLVM::LLVMFuncOp func =
1080 lookupOrCreateSPIRVFn(symbolTable, funcName, {i32, i32, i32}, voidTy);
1082 Location loc = controlBarrierOp->getLoc();
1083 Value execution = rewriter.create<LLVM::ConstantOp>(
1084 loc, i32, static_cast<int32_t>(adaptor.getExecutionScope()));
1085 Value memory = rewriter.create<LLVM::ConstantOp>(
1086 loc, i32, static_cast<int32_t>(adaptor.getMemoryScope()));
1087 Value semantics = rewriter.create<LLVM::ConstantOp>(
1088 loc, i32, static_cast<int32_t>(adaptor.getMemorySemantics()));
1090 auto call = createSPIRVBuiltinCall(loc, rewriter, func,
1091 {execution, memory, semantics});
1093 rewriter.replaceOp(controlBarrierOp, call);
1094 return success();
1098 namespace {
1100 StringRef getTypeMangling(Type type, bool isSigned) {
1101 return llvm::TypeSwitch<Type, StringRef>(type)
1102 .Case<Float16Type>([](auto) { return "Dh"; })
1103 .Case<Float32Type>([](auto) { return "f"; })
1104 .Case<Float64Type>([](auto) { return "d"; })
1105 .Case<IntegerType>([isSigned](IntegerType intTy) {
1106 switch (intTy.getWidth()) {
1107 case 1:
1108 return "b";
1109 case 8:
1110 return (isSigned) ? "a" : "c";
1111 case 16:
1112 return (isSigned) ? "s" : "t";
1113 case 32:
1114 return (isSigned) ? "i" : "j";
1115 case 64:
1116 return (isSigned) ? "l" : "m";
1117 default:
1118 llvm_unreachable("Unsupported integer width");
1121 .Default([](auto) {
1122 llvm_unreachable("No mangling defined");
1123 return "";
1127 template <typename ReduceOp>
1128 constexpr StringLiteral getGroupFuncName();
1130 template <>
1131 constexpr StringLiteral getGroupFuncName<spirv::GroupIAddOp>() {
1132 return "_Z17__spirv_GroupIAddii";
1134 template <>
1135 constexpr StringLiteral getGroupFuncName<spirv::GroupFAddOp>() {
1136 return "_Z17__spirv_GroupFAddii";
1138 template <>
1139 constexpr StringLiteral getGroupFuncName<spirv::GroupSMinOp>() {
1140 return "_Z17__spirv_GroupSMinii";
1142 template <>
1143 constexpr StringLiteral getGroupFuncName<spirv::GroupUMinOp>() {
1144 return "_Z17__spirv_GroupUMinii";
1146 template <>
1147 constexpr StringLiteral getGroupFuncName<spirv::GroupFMinOp>() {
1148 return "_Z17__spirv_GroupFMinii";
1150 template <>
1151 constexpr StringLiteral getGroupFuncName<spirv::GroupSMaxOp>() {
1152 return "_Z17__spirv_GroupSMaxii";
1154 template <>
1155 constexpr StringLiteral getGroupFuncName<spirv::GroupUMaxOp>() {
1156 return "_Z17__spirv_GroupUMaxii";
1158 template <>
1159 constexpr StringLiteral getGroupFuncName<spirv::GroupFMaxOp>() {
1160 return "_Z17__spirv_GroupFMaxii";
1162 template <>
1163 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIAddOp>() {
1164 return "_Z27__spirv_GroupNonUniformIAddii";
1166 template <>
1167 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFAddOp>() {
1168 return "_Z27__spirv_GroupNonUniformFAddii";
1170 template <>
1171 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIMulOp>() {
1172 return "_Z27__spirv_GroupNonUniformIMulii";
1174 template <>
1175 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMulOp>() {
1176 return "_Z27__spirv_GroupNonUniformFMulii";
1178 template <>
1179 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMinOp>() {
1180 return "_Z27__spirv_GroupNonUniformSMinii";
1182 template <>
1183 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMinOp>() {
1184 return "_Z27__spirv_GroupNonUniformUMinii";
1186 template <>
1187 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMinOp>() {
1188 return "_Z27__spirv_GroupNonUniformFMinii";
1190 template <>
1191 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMaxOp>() {
1192 return "_Z27__spirv_GroupNonUniformSMaxii";
1194 template <>
1195 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMaxOp>() {
1196 return "_Z27__spirv_GroupNonUniformUMaxii";
1198 template <>
1199 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMaxOp>() {
1200 return "_Z27__spirv_GroupNonUniformFMaxii";
1202 template <>
1203 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseAndOp>() {
1204 return "_Z33__spirv_GroupNonUniformBitwiseAndii";
1206 template <>
1207 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseOrOp>() {
1208 return "_Z32__spirv_GroupNonUniformBitwiseOrii";
1210 template <>
1211 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseXorOp>() {
1212 return "_Z33__spirv_GroupNonUniformBitwiseXorii";
1214 template <>
1215 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalAndOp>() {
1216 return "_Z33__spirv_GroupNonUniformLogicalAndii";
1218 template <>
1219 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalOrOp>() {
1220 return "_Z32__spirv_GroupNonUniformLogicalOrii";
1222 template <>
1223 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalXorOp>() {
1224 return "_Z33__spirv_GroupNonUniformLogicalXorii";
1226 } // namespace
1228 template <typename ReduceOp, bool Signed = false, bool NonUniform = false>
1229 class GroupReducePattern : public SPIRVToLLVMConversion<ReduceOp> {
1230 public:
1231 using SPIRVToLLVMConversion<ReduceOp>::SPIRVToLLVMConversion;
1233 LogicalResult
1234 matchAndRewrite(ReduceOp op, typename ReduceOp::Adaptor adaptor,
1235 ConversionPatternRewriter &rewriter) const override {
1237 Type retTy = op.getResult().getType();
1238 if (!retTy.isIntOrFloat()) {
1239 return failure();
1241 SmallString<36> funcName = getGroupFuncName<ReduceOp>();
1242 funcName += getTypeMangling(retTy, false);
1244 Type i32Ty = rewriter.getI32Type();
1245 SmallVector<Type> paramTypes{i32Ty, i32Ty, retTy};
1246 if constexpr (NonUniform) {
1247 if (adaptor.getClusterSize()) {
1248 funcName += "j";
1249 paramTypes.push_back(i32Ty);
1253 Operation *symbolTable =
1254 op->template getParentWithTrait<OpTrait::SymbolTable>();
1256 LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
1257 symbolTable, funcName, paramTypes, retTy, !NonUniform);
1259 Location loc = op.getLoc();
1260 Value scope = rewriter.create<LLVM::ConstantOp>(
1261 loc, i32Ty, static_cast<int32_t>(adaptor.getExecutionScope()));
1262 Value groupOp = rewriter.create<LLVM::ConstantOp>(
1263 loc, i32Ty, static_cast<int32_t>(adaptor.getGroupOperation()));
1264 SmallVector<Value> operands{scope, groupOp};
1265 operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end());
1267 auto call = createSPIRVBuiltinCall(loc, rewriter, func, operands);
1268 rewriter.replaceOp(op, call);
1269 return success();
1273 template <>
1274 constexpr StringRef
1275 ControlBarrierPattern<spirv::ControlBarrierOp>::getFuncName() {
1276 return "_Z22__spirv_ControlBarrieriii";
1279 template <>
1280 constexpr StringRef
1281 ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>::getFuncName() {
1282 return "_Z33__spirv_ControlBarrierArriveINTELiii";
1285 template <>
1286 constexpr StringRef
1287 ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>::getFuncName() {
1288 return "_Z31__spirv_ControlBarrierWaitINTELiii";
/// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within the loop
/// should be reachable for conversion to succeed. The structure of the loop in
/// LLVM dialect will be the following:
///
///                     +------------------------------------+
///                     | <code before spirv.mlir.loop>      |
///                     | llvm.br ^header                    |
///                     +------------------------------------+
///                                      |
///                   +----------------+ |
///                   |                | |
///                   |                V V
///                   |  +------------------------------------+
///                   |  | ^header:                           |
///                   |  |   <header code>                    |
///                   |  |   llvm.cond_br %cond, ^body, ^exit |
///                   |  +------------------------------------+
///                   |                    |
///                   |                    |----------------------+
///                   |                    |                      |
///                   |                    V                      |
///                   |  +------------------------------------+   |
///                   |  | ^body:                             |   |
///                   |  |   <body code>                      |   |
///                   |  |   llvm.br ^continue                |   |
///                   |  +------------------------------------+   |
///                   |                    |                      |
///                   |                    V                      |
///                   |  +------------------------------------+   |
///                   |  | ^continue:                         |   |
///                   |  |   <continue code>                  |   |
///                   |  |   llvm.br ^header                  |   |
///                   |  +------------------------------------+   |
///                   |               |                           |
///                   +---------------+    +----------------------+
///                                        |
///                                        V
///                     +------------------------------------+
///                     | ^exit:                             |
///                     |   llvm.br ^remaining               |
///                     +------------------------------------+
///                                        |
///                                        V
///                     +------------------------------------+
///                     | ^remaining:                        |
///                     |   <code after spirv.mlir.loop>     |
///                     +------------------------------------+
///
1339 class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
1340 public:
1341 using SPIRVToLLVMConversion<spirv::LoopOp>::SPIRVToLLVMConversion;
1343 LogicalResult
1344 matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1345 ConversionPatternRewriter &rewriter) const override {
1346 // There is no support of loop control at the moment.
1347 if (loopOp.getLoopControl() != spirv::LoopControl::None)
1348 return failure();
1350 // `spirv.mlir.loop` with empty region is redundant and should be erased.
1351 if (loopOp.getBody().empty()) {
1352 rewriter.eraseOp(loopOp);
1353 return success();
1356 Location loc = loopOp.getLoc();
1358 // Split the current block after `spirv.mlir.loop`. The remaining ops will
1359 // be used in `endBlock`.
1360 Block *currentBlock = rewriter.getBlock();
1361 auto position = Block::iterator(loopOp);
1362 Block *endBlock = rewriter.splitBlock(currentBlock, position);
1364 // Remove entry block and create a branch in the current block going to the
1365 // header block.
1366 Block *entryBlock = loopOp.getEntryBlock();
1367 assert(entryBlock->getOperations().size() == 1);
1368 auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front());
1369 if (!brOp)
1370 return failure();
1371 Block *headerBlock = loopOp.getHeaderBlock();
1372 rewriter.setInsertionPointToEnd(currentBlock);
1373 rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
1374 rewriter.eraseBlock(entryBlock);
1376 // Branch from merge block to end block.
1377 Block *mergeBlock = loopOp.getMergeBlock();
1378 Operation *terminator = mergeBlock->getTerminator();
1379 ValueRange terminatorOperands = terminator->getOperands();
1380 rewriter.setInsertionPointToEnd(mergeBlock);
1381 rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
1383 rewriter.inlineRegionBefore(loopOp.getBody(), endBlock);
1384 rewriter.replaceOp(loopOp, endBlock->getArguments());
1385 return success();
1389 /// Converts `spirv.mlir.selection` with `spirv.BranchConditional` in its header
1390 /// block. All blocks within selection should be reachable for conversion to
1391 /// succeed.
1392 class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
1393 public:
1394 using SPIRVToLLVMConversion<spirv::SelectionOp>::SPIRVToLLVMConversion;
1396 LogicalResult
1397 matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1398 ConversionPatternRewriter &rewriter) const override {
1399 // There is no support for `Flatten` or `DontFlatten` selection control at
1400 // the moment. This are just compiler hints and can be performed during the
1401 // optimization passes.
1402 if (op.getSelectionControl() != spirv::SelectionControl::None)
1403 return failure();
1405 // `spirv.mlir.selection` should have at least two blocks: one selection
1406 // header block and one merge block. If no blocks are present, or control
1407 // flow branches straight to merge block (two blocks are present), the op is
1408 // redundant and it is erased.
1409 if (op.getBody().getBlocks().size() <= 2) {
1410 rewriter.eraseOp(op);
1411 return success();
1414 Location loc = op.getLoc();
1416 // Split the current block after `spirv.mlir.selection`. The remaining ops
1417 // will be used in `continueBlock`.
1418 auto *currentBlock = rewriter.getInsertionBlock();
1419 rewriter.setInsertionPointAfter(op);
1420 auto position = rewriter.getInsertionPoint();
1421 auto *continueBlock = rewriter.splitBlock(currentBlock, position);
1423 // Extract conditional branch information from the header block. By SPIR-V
1424 // dialect spec, it should contain `spirv.BranchConditional` or
1425 // `spirv.Switch` op. Note that `spirv.Switch op` is not supported at the
1426 // moment in the SPIR-V dialect. Remove this block when finished.
1427 auto *headerBlock = op.getHeaderBlock();
1428 assert(headerBlock->getOperations().size() == 1);
1429 auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1430 headerBlock->getOperations().front());
1431 if (!condBrOp)
1432 return failure();
1433 rewriter.eraseBlock(headerBlock);
1435 // Branch from merge block to continue block.
1436 auto *mergeBlock = op.getMergeBlock();
1437 Operation *terminator = mergeBlock->getTerminator();
1438 ValueRange terminatorOperands = terminator->getOperands();
1439 rewriter.setInsertionPointToEnd(mergeBlock);
1440 rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
1442 // Link current block to `true` and `false` blocks within the selection.
1443 Block *trueBlock = condBrOp.getTrueBlock();
1444 Block *falseBlock = condBrOp.getFalseBlock();
1445 rewriter.setInsertionPointToEnd(currentBlock);
1446 rewriter.create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock,
1447 condBrOp.getTrueTargetOperands(),
1448 falseBlock,
1449 condBrOp.getFalseTargetOperands());
1451 rewriter.inlineRegionBefore(op.getBody(), continueBlock);
1452 rewriter.replaceOp(op, continueBlock->getArguments());
1453 return success();
1457 /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
1458 /// puts a restriction on `Shift` and `Base` to have the same bit width,
1459 /// `Shift` is zero or sign extended to match this specification. Cases when
1460 /// `Shift` bit width > `Base` bit width are considered to be illegal.
1461 template <typename SPIRVOp, typename LLVMOp>
1462 class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
1463 public:
1464 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
1466 LogicalResult
1467 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
1468 ConversionPatternRewriter &rewriter) const override {
1470 auto dstType = this->getTypeConverter()->convertType(op.getType());
1471 if (!dstType)
1472 return rewriter.notifyMatchFailure(op, "type conversion failed");
1474 Type op1Type = op.getOperand1().getType();
1475 Type op2Type = op.getOperand2().getType();
1477 if (op1Type == op2Type) {
1478 rewriter.template replaceOpWithNewOp<LLVMOp>(op, dstType,
1479 adaptor.getOperands());
1480 return success();
1483 std::optional<uint64_t> dstTypeWidth =
1484 getIntegerOrVectorElementWidth(dstType);
1485 std::optional<uint64_t> op2TypeWidth =
1486 getIntegerOrVectorElementWidth(op2Type);
1488 if (!dstTypeWidth || !op2TypeWidth)
1489 return failure();
1491 Location loc = op.getLoc();
1492 Value extended;
1493 if (op2TypeWidth < dstTypeWidth) {
1494 if (isUnsignedIntegerOrVector(op2Type)) {
1495 extended = rewriter.template create<LLVM::ZExtOp>(
1496 loc, dstType, adaptor.getOperand2());
1497 } else {
1498 extended = rewriter.template create<LLVM::SExtOp>(
1499 loc, dstType, adaptor.getOperand2());
1501 } else if (op2TypeWidth == dstTypeWidth) {
1502 extended = adaptor.getOperand2();
1503 } else {
1504 return failure();
1507 Value result = rewriter.template create<LLVMOp>(
1508 loc, dstType, adaptor.getOperand1(), extended);
1509 rewriter.replaceOp(op, result);
1510 return success();
1514 class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> {
1515 public:
1516 using SPIRVToLLVMConversion<spirv::GLTanOp>::SPIRVToLLVMConversion;
1518 LogicalResult
1519 matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,
1520 ConversionPatternRewriter &rewriter) const override {
1521 auto dstType = getTypeConverter()->convertType(tanOp.getType());
1522 if (!dstType)
1523 return rewriter.notifyMatchFailure(tanOp, "type conversion failed");
1525 Location loc = tanOp.getLoc();
1526 Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.getOperand());
1527 Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.getOperand());
1528 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
1529 return success();
1533 /// Convert `spirv.Tanh` to
1535 /// exp(2x) - 1
1536 /// -----------
1537 /// exp(2x) + 1
1539 class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {
1540 public:
1541 using SPIRVToLLVMConversion<spirv::GLTanhOp>::SPIRVToLLVMConversion;
1543 LogicalResult
1544 matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,
1545 ConversionPatternRewriter &rewriter) const override {
1546 auto srcType = tanhOp.getType();
1547 auto dstType = getTypeConverter()->convertType(srcType);
1548 if (!dstType)
1549 return rewriter.notifyMatchFailure(tanhOp, "type conversion failed");
1551 Location loc = tanhOp.getLoc();
1552 Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
1553 Value multiplied =
1554 rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.getOperand());
1555 Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied);
1556 Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
1557 Value numerator =
1558 rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one);
1559 Value denominator =
1560 rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one);
1561 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
1562 denominator);
1563 return success();
1567 class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
1568 public:
1569 using SPIRVToLLVMConversion<spirv::VariableOp>::SPIRVToLLVMConversion;
1571 LogicalResult
1572 matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1573 ConversionPatternRewriter &rewriter) const override {
1574 auto srcType = varOp.getType();
1575 // Initialization is supported for scalars and vectors only.
1576 auto pointerTo = cast<spirv::PointerType>(srcType).getPointeeType();
1577 auto init = varOp.getInitializer();
1578 if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo))
1579 return failure();
1581 auto dstType = getTypeConverter()->convertType(srcType);
1582 if (!dstType)
1583 return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1585 Location loc = varOp.getLoc();
1586 Value size = createI32ConstantOf(loc, rewriter, 1);
1587 if (!init) {
1588 auto elementType = getTypeConverter()->convertType(pointerTo);
1589 if (!elementType)
1590 return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1591 rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, elementType,
1592 size);
1593 return success();
1595 auto elementType = getTypeConverter()->convertType(pointerTo);
1596 if (!elementType)
1597 return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1598 Value allocated =
1599 rewriter.create<LLVM::AllocaOp>(loc, dstType, elementType, size);
1600 rewriter.create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated);
1601 rewriter.replaceOp(varOp, allocated);
1602 return success();
1606 //===----------------------------------------------------------------------===//
1607 // BitcastOp conversion
1608 //===----------------------------------------------------------------------===//
1610 class BitcastConversionPattern
1611 : public SPIRVToLLVMConversion<spirv::BitcastOp> {
1612 public:
1613 using SPIRVToLLVMConversion<spirv::BitcastOp>::SPIRVToLLVMConversion;
1615 LogicalResult
1616 matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor,
1617 ConversionPatternRewriter &rewriter) const override {
1618 auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
1619 if (!dstType)
1620 return rewriter.notifyMatchFailure(bitcastOp, "type conversion failed");
1622 // LLVM's opaque pointers do not require bitcasts.
1623 if (isa<LLVM::LLVMPointerType>(dstType)) {
1624 rewriter.replaceOp(bitcastOp, adaptor.getOperand());
1625 return success();
1628 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1629 bitcastOp, dstType, adaptor.getOperands(), bitcastOp->getAttrs());
1630 return success();
1634 //===----------------------------------------------------------------------===//
1635 // FuncOp conversion
1636 //===----------------------------------------------------------------------===//
1638 class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
1639 public:
1640 using SPIRVToLLVMConversion<spirv::FuncOp>::SPIRVToLLVMConversion;
1642 LogicalResult
1643 matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1644 ConversionPatternRewriter &rewriter) const override {
1646 // Convert function signature. At the moment LLVMType converter is enough
1647 // for currently supported types.
1648 auto funcType = funcOp.getFunctionType();
1649 TypeConverter::SignatureConversion signatureConverter(
1650 funcType.getNumInputs());
1651 auto llvmType = static_cast<const LLVMTypeConverter *>(getTypeConverter())
1652 ->convertFunctionSignature(
1653 funcType, /*isVariadic=*/false,
1654 /*useBarePtrCallConv=*/false, signatureConverter);
1655 if (!llvmType)
1656 return failure();
1658 // Create a new `LLVMFuncOp`
1659 Location loc = funcOp.getLoc();
1660 StringRef name = funcOp.getName();
1661 auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType);
1663 // Convert SPIR-V Function Control to equivalent LLVM function attribute
1664 MLIRContext *context = funcOp.getContext();
1665 switch (funcOp.getFunctionControl()) {
1666 case spirv::FunctionControl::Inline:
1667 newFuncOp.setAlwaysInline(true);
1668 break;
1669 case spirv::FunctionControl::DontInline:
1670 newFuncOp.setNoInline(true);
1671 break;
1673 #define DISPATCH(functionControl, llvmAttr) \
1674 case functionControl: \
1675 newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
1676 break;
1678 DISPATCH(spirv::FunctionControl::Pure,
1679 StringAttr::get(context, "readonly"));
1680 DISPATCH(spirv::FunctionControl::Const,
1681 StringAttr::get(context, "readnone"));
1683 #undef DISPATCH
1685 // Default: if `spirv::FunctionControl::None`, then no attributes are
1686 // needed.
1687 default:
1688 break;
1691 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1692 newFuncOp.end());
1693 if (failed(rewriter.convertRegionTypes(
1694 &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) {
1695 return failure();
1697 rewriter.eraseOp(funcOp);
1698 return success();
1702 //===----------------------------------------------------------------------===//
1703 // ModuleOp conversion
1704 //===----------------------------------------------------------------------===//
1706 class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
1707 public:
1708 using SPIRVToLLVMConversion<spirv::ModuleOp>::SPIRVToLLVMConversion;
1710 LogicalResult
1711 matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1712 ConversionPatternRewriter &rewriter) const override {
1714 auto newModuleOp =
1715 rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
1716 rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody());
1718 // Remove the terminator block that was automatically added by builder
1719 rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());
1720 rewriter.eraseOp(spvModuleOp);
1721 return success();
1725 //===----------------------------------------------------------------------===//
1726 // VectorShuffleOp conversion
1727 //===----------------------------------------------------------------------===//
1729 class VectorShufflePattern
1730 : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> {
1731 public:
1732 using SPIRVToLLVMConversion<spirv::VectorShuffleOp>::SPIRVToLLVMConversion;
1733 LogicalResult
1734 matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
1735 ConversionPatternRewriter &rewriter) const override {
1736 Location loc = op.getLoc();
1737 auto components = adaptor.getComponents();
1738 auto vector1 = adaptor.getVector1();
1739 auto vector2 = adaptor.getVector2();
1740 int vector1Size = cast<VectorType>(vector1.getType()).getNumElements();
1741 int vector2Size = cast<VectorType>(vector2.getType()).getNumElements();
1742 if (vector1Size == vector2Size) {
1743 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1744 op, vector1, vector2,
1745 LLVM::convertArrayToIndices<int32_t>(components));
1746 return success();
1749 auto dstType = getTypeConverter()->convertType(op.getType());
1750 if (!dstType)
1751 return rewriter.notifyMatchFailure(op, "type conversion failed");
1752 auto scalarType = cast<VectorType>(dstType).getElementType();
1753 auto componentsArray = components.getValue();
1754 auto *context = rewriter.getContext();
1755 auto llvmI32Type = IntegerType::get(context, 32);
1756 Value targetOp = rewriter.create<LLVM::UndefOp>(loc, dstType);
1757 for (unsigned i = 0; i < componentsArray.size(); i++) {
1758 if (!isa<IntegerAttr>(componentsArray[i]))
1759 return op.emitError("unable to support non-constant component");
1761 int indexVal = cast<IntegerAttr>(componentsArray[i]).getInt();
1762 if (indexVal == -1)
1763 continue;
1765 int offsetVal = 0;
1766 Value baseVector = vector1;
1767 if (indexVal >= vector1Size) {
1768 offsetVal = vector1Size;
1769 baseVector = vector2;
1772 Value dstIndex = rewriter.create<LLVM::ConstantOp>(
1773 loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i));
1774 Value index = rewriter.create<LLVM::ConstantOp>(
1775 loc, llvmI32Type,
1776 rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal));
1778 auto extractOp = rewriter.create<LLVM::ExtractElementOp>(
1779 loc, scalarType, baseVector, index);
1780 targetOp = rewriter.create<LLVM::InsertElementOp>(loc, dstType, targetOp,
1781 extractOp, dstIndex);
1783 rewriter.replaceOp(op, targetOp);
1784 return success();
1787 } // namespace
1789 //===----------------------------------------------------------------------===//
1790 // Pattern population
1791 //===----------------------------------------------------------------------===//
1793 void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter,
1794 spirv::ClientAPI clientAPI) {
1795 typeConverter.addConversion([&](spirv::ArrayType type) {
1796 return convertArrayType(type, typeConverter);
1798 typeConverter.addConversion([&, clientAPI](spirv::PointerType type) {
1799 return convertPointerType(type, typeConverter, clientAPI);
1801 typeConverter.addConversion([&](spirv::RuntimeArrayType type) {
1802 return convertRuntimeArrayType(type, typeConverter);
1804 typeConverter.addConversion([&](spirv::StructType type) {
1805 return convertStructType(type, typeConverter);
1809 void mlir::populateSPIRVToLLVMConversionPatterns(
1810 const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
1811 spirv::ClientAPI clientAPI) {
1812 patterns.add<
1813 // Arithmetic ops
1814 DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1815 DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1816 DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1817 DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1818 DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1819 DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1820 DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1821 DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1822 DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1823 DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1824 DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1825 DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1826 DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1828 // Bitwise ops
1829 BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1830 DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1831 DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1832 DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1833 DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1834 DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1835 NotPattern<spirv::NotOp>,
1837 // Cast ops
1838 BitcastConversionPattern,
1839 DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1840 DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1841 DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1842 DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1843 IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1844 IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1845 IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1847 // Comparison ops
1848 IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1849 IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1850 FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1851 FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1852 FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1853 FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1854 FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1855 FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1856 FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1857 FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1858 FComparePattern<spirv::FUnordGreaterThanEqualOp,
1859 LLVM::FCmpPredicate::uge>,
1860 FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1861 FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1862 FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1863 IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1864 IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1865 IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1866 IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1867 IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1868 IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1869 IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1870 IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1872 // Constant op
1873 ConstantScalarAndVectorPattern,
1875 // Control Flow ops
1876 BranchConversionPattern, BranchConditionalConversionPattern,
1877 FunctionCallPattern, LoopPattern, SelectionPattern,
1878 ErasePattern<spirv::MergeOp>,
1880 // Entry points and execution mode are handled separately.
1881 ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1883 // GLSL extended instruction set ops
1884 DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>,
1885 DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>,
1886 DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>,
1887 DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>,
1888 DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>,
1889 DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>,
1890 DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>,
1891 DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>,
1892 DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>,
1893 DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>,
1894 DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>,
1895 DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>,
1896 InverseSqrtPattern, TanPattern, TanhPattern,
1898 // Logical ops
1899 DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1900 DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1901 IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1902 IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1903 NotPattern<spirv::LogicalNotOp>,
1905 // Memory ops
1906 AccessChainPattern, AddressOfPattern, LoadStorePattern<spirv::LoadOp>,
1907 LoadStorePattern<spirv::StoreOp>, VariablePattern,
1909 // Miscellaneous ops
1910 CompositeExtractPattern, CompositeInsertPattern,
1911 DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1912 DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1913 VectorShufflePattern,
1915 // Shift ops
1916 ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1917 ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1918 ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1920 // Return ops
1921 ReturnPattern, ReturnValuePattern,
1923 // Barrier ops
1924 ControlBarrierPattern<spirv::ControlBarrierOp>,
1925 ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>,
1926 ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>,
1928 // Group reduction operations
1929 GroupReducePattern<spirv::GroupIAddOp>,
1930 GroupReducePattern<spirv::GroupFAddOp>,
1931 GroupReducePattern<spirv::GroupFMinOp>,
1932 GroupReducePattern<spirv::GroupUMinOp>,
1933 GroupReducePattern<spirv::GroupSMinOp, /*Signed=*/true>,
1934 GroupReducePattern<spirv::GroupFMaxOp>,
1935 GroupReducePattern<spirv::GroupUMaxOp>,
1936 GroupReducePattern<spirv::GroupSMaxOp, /*Signed=*/true>,
1937 GroupReducePattern<spirv::GroupNonUniformIAddOp, /*Signed=*/false,
1938 /*NonUniform=*/true>,
1939 GroupReducePattern<spirv::GroupNonUniformFAddOp, /*Signed=*/false,
1940 /*NonUniform=*/true>,
1941 GroupReducePattern<spirv::GroupNonUniformIMulOp, /*Signed=*/false,
1942 /*NonUniform=*/true>,
1943 GroupReducePattern<spirv::GroupNonUniformFMulOp, /*Signed=*/false,
1944 /*NonUniform=*/true>,
1945 GroupReducePattern<spirv::GroupNonUniformSMinOp, /*Signed=*/true,
1946 /*NonUniform=*/true>,
1947 GroupReducePattern<spirv::GroupNonUniformUMinOp, /*Signed=*/false,
1948 /*NonUniform=*/true>,
1949 GroupReducePattern<spirv::GroupNonUniformFMinOp, /*Signed=*/false,
1950 /*NonUniform=*/true>,
1951 GroupReducePattern<spirv::GroupNonUniformSMaxOp, /*Signed=*/true,
1952 /*NonUniform=*/true>,
1953 GroupReducePattern<spirv::GroupNonUniformUMaxOp, /*Signed=*/false,
1954 /*NonUniform=*/true>,
1955 GroupReducePattern<spirv::GroupNonUniformFMaxOp, /*Signed=*/false,
1956 /*NonUniform=*/true>,
1957 GroupReducePattern<spirv::GroupNonUniformBitwiseAndOp, /*Signed=*/false,
1958 /*NonUniform=*/true>,
1959 GroupReducePattern<spirv::GroupNonUniformBitwiseOrOp, /*Signed=*/false,
1960 /*NonUniform=*/true>,
1961 GroupReducePattern<spirv::GroupNonUniformBitwiseXorOp, /*Signed=*/false,
1962 /*NonUniform=*/true>,
1963 GroupReducePattern<spirv::GroupNonUniformLogicalAndOp, /*Signed=*/false,
1964 /*NonUniform=*/true>,
1965 GroupReducePattern<spirv::GroupNonUniformLogicalOrOp, /*Signed=*/false,
1966 /*NonUniform=*/true>,
1967 GroupReducePattern<spirv::GroupNonUniformLogicalXorOp, /*Signed=*/false,
1968 /*NonUniform=*/true>>(patterns.getContext(),
1969 typeConverter);
1971 patterns.add<GlobalVariablePattern>(clientAPI, patterns.getContext(),
1972 typeConverter);
1975 void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
1976 const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1977 patterns.add<FuncConversionPattern>(patterns.getContext(), typeConverter);
1980 void mlir::populateSPIRVToLLVMModuleConversionPatterns(
1981 const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1982 patterns.add<ModuleConversionPattern>(patterns.getContext(), typeConverter);
1985 //===----------------------------------------------------------------------===//
1986 // Pre-conversion hooks
1987 //===----------------------------------------------------------------------===//
1989 /// Hook for descriptor set and binding number encoding.
1990 static constexpr StringRef kBinding = "binding";
1991 static constexpr StringRef kDescriptorSet = "descriptor_set";
1992 void mlir::encodeBindAttribute(ModuleOp module) {
1993 auto spvModules = module.getOps<spirv::ModuleOp>();
1994 for (auto spvModule : spvModules) {
1995 spvModule.walk([&](spirv::GlobalVariableOp op) {
1996 IntegerAttr descriptorSet =
1997 op->getAttrOfType<IntegerAttr>(kDescriptorSet);
1998 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding);
1999 // For every global variable in the module, get the ones with descriptor
2000 // set and binding numbers.
2001 if (descriptorSet && binding) {
2002 // Encode these numbers into the variable's symbolic name. If the
2003 // SPIR-V module has a name, add it at the beginning.
2004 auto moduleAndName =
2005 spvModule.getName().has_value()
2006 ? spvModule.getName()->str() + "_" + op.getSymName().str()
2007 : op.getSymName().str();
2008 std::string name =
2009 llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
2010 std::to_string(descriptorSet.getInt()),
2011 std::to_string(binding.getInt()));
2012 auto nameAttr = StringAttr::get(op->getContext(), name);
2014 // Replace all symbol uses and set the new symbol name. Finally, remove
2015 // descriptor set and binding attributes.
2016 if (failed(SymbolTable::replaceAllSymbolUses(op, nameAttr, spvModule)))
2017 op.emitError("unable to replace all symbol uses for ") << name;
2018 SymbolTable::setSymbolName(op, nameAttr);
2019 op->removeAttr(kDescriptorSet);
2020 op->removeAttr(kBinding);