//===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert SPIR-V dialect to LLVM dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/SPIRVCommon/AttrToLLVMConverter.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"

#define DEBUG_TYPE "spirv-to-llvm-pattern"

using namespace mlir;
33 //===----------------------------------------------------------------------===//
35 //===----------------------------------------------------------------------===//
37 /// Returns true if the given type is a signed integer or vector type.
38 static bool isSignedIntegerOrVector(Type type
) {
39 if (type
.isSignedInteger())
41 if (auto vecType
= dyn_cast
<VectorType
>(type
))
42 return vecType
.getElementType().isSignedInteger();
46 /// Returns true if the given type is an unsigned integer or vector type
47 static bool isUnsignedIntegerOrVector(Type type
) {
48 if (type
.isUnsignedInteger())
50 if (auto vecType
= dyn_cast
<VectorType
>(type
))
51 return vecType
.getElementType().isUnsignedInteger();
55 /// Returns the width of an integer or of the element type of an integer vector,
57 static std::optional
<uint64_t> getIntegerOrVectorElementWidth(Type type
) {
58 if (auto intType
= dyn_cast
<IntegerType
>(type
))
59 return intType
.getWidth();
60 if (auto vecType
= dyn_cast
<VectorType
>(type
))
61 if (auto intType
= dyn_cast
<IntegerType
>(vecType
.getElementType()))
62 return intType
.getWidth();
66 /// Returns the bit width of integer, float or vector of float or integer values
67 static unsigned getBitWidth(Type type
) {
68 assert((type
.isIntOrFloat() || isa
<VectorType
>(type
)) &&
69 "bitwidth is not supported for this type");
70 if (type
.isIntOrFloat())
71 return type
.getIntOrFloatBitWidth();
72 auto vecType
= dyn_cast
<VectorType
>(type
);
73 auto elementType
= vecType
.getElementType();
74 assert(elementType
.isIntOrFloat() &&
75 "only integers and floats have a bitwidth");
76 return elementType
.getIntOrFloatBitWidth();
79 /// Returns the bit width of LLVMType integer or vector.
80 static unsigned getLLVMTypeBitWidth(Type type
) {
81 return cast
<IntegerType
>((LLVM::isCompatibleVectorType(type
)
82 ? LLVM::getVectorElementType(type
)
87 /// Creates `IntegerAttribute` with all bits set for given type
88 static IntegerAttr
minusOneIntegerAttribute(Type type
, Builder builder
) {
89 if (auto vecType
= dyn_cast
<VectorType
>(type
)) {
90 auto integerType
= cast
<IntegerType
>(vecType
.getElementType());
91 return builder
.getIntegerAttr(integerType
, -1);
93 auto integerType
= cast
<IntegerType
>(type
);
94 return builder
.getIntegerAttr(integerType
, -1);
97 /// Creates `llvm.mlir.constant` with all bits set for the given type.
98 static Value
createConstantAllBitsSet(Location loc
, Type srcType
, Type dstType
,
99 PatternRewriter
&rewriter
) {
100 if (isa
<VectorType
>(srcType
)) {
101 return rewriter
.create
<LLVM::ConstantOp
>(
103 SplatElementsAttr::get(cast
<ShapedType
>(srcType
),
104 minusOneIntegerAttribute(srcType
, rewriter
)));
106 return rewriter
.create
<LLVM::ConstantOp
>(
107 loc
, dstType
, minusOneIntegerAttribute(srcType
, rewriter
));
110 /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
111 static Value
createFPConstant(Location loc
, Type srcType
, Type dstType
,
112 PatternRewriter
&rewriter
, double value
) {
113 if (auto vecType
= dyn_cast
<VectorType
>(srcType
)) {
114 auto floatType
= cast
<FloatType
>(vecType
.getElementType());
115 return rewriter
.create
<LLVM::ConstantOp
>(
117 SplatElementsAttr::get(vecType
,
118 rewriter
.getFloatAttr(floatType
, value
)));
120 auto floatType
= cast
<FloatType
>(srcType
);
121 return rewriter
.create
<LLVM::ConstantOp
>(
122 loc
, dstType
, rewriter
.getFloatAttr(floatType
, value
));
125 /// Utility function for bitfield ops:
126 /// - `BitFieldInsert`
127 /// - `BitFieldSExtract`
128 /// - `BitFieldUExtract`
129 /// Truncates or extends the value. If the bitwidth of the value is the same as
130 /// `llvmType` bitwidth, the value remains unchanged.
131 static Value
optionallyTruncateOrExtend(Location loc
, Value value
,
133 PatternRewriter
&rewriter
) {
134 auto srcType
= value
.getType();
135 unsigned targetBitWidth
= getLLVMTypeBitWidth(llvmType
);
136 unsigned valueBitWidth
= LLVM::isCompatibleType(srcType
)
137 ? getLLVMTypeBitWidth(srcType
)
138 : getBitWidth(srcType
);
140 if (valueBitWidth
< targetBitWidth
)
141 return rewriter
.create
<LLVM::ZExtOp
>(loc
, llvmType
, value
);
142 // If the bit widths of `Count` and `Offset` are greater than the bit width
143 // of the target type, they are truncated. Truncation is safe since `Count`
144 // and `Offset` must be no more than 64 for op behaviour to be defined. Hence,
145 // both values can be expressed in 8 bits.
146 if (valueBitWidth
> targetBitWidth
)
147 return rewriter
.create
<LLVM::TruncOp
>(loc
, llvmType
, value
);
151 /// Broadcasts the value to vector with `numElements` number of elements.
152 static Value
broadcast(Location loc
, Value toBroadcast
, unsigned numElements
,
153 const TypeConverter
&typeConverter
,
154 ConversionPatternRewriter
&rewriter
) {
155 auto vectorType
= VectorType::get(numElements
, toBroadcast
.getType());
156 auto llvmVectorType
= typeConverter
.convertType(vectorType
);
157 auto llvmI32Type
= typeConverter
.convertType(rewriter
.getIntegerType(32));
158 Value broadcasted
= rewriter
.create
<LLVM::UndefOp
>(loc
, llvmVectorType
);
159 for (unsigned i
= 0; i
< numElements
; ++i
) {
160 auto index
= rewriter
.create
<LLVM::ConstantOp
>(
161 loc
, llvmI32Type
, rewriter
.getI32IntegerAttr(i
));
162 broadcasted
= rewriter
.create
<LLVM::InsertElementOp
>(
163 loc
, llvmVectorType
, broadcasted
, toBroadcast
, index
);
168 /// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged.
169 static Value
optionallyBroadcast(Location loc
, Value value
, Type srcType
,
170 const TypeConverter
&typeConverter
,
171 ConversionPatternRewriter
&rewriter
) {
172 if (auto vectorType
= dyn_cast
<VectorType
>(srcType
)) {
173 unsigned numElements
= vectorType
.getNumElements();
174 return broadcast(loc
, value
, numElements
, typeConverter
, rewriter
);
179 /// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and
180 /// `BitFieldUExtract`.
181 /// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of
182 /// a vector type, construct a vector that has:
183 /// - same number of elements as `Base`
184 /// - each element has the type that is the same as the type of `Offset` or
186 /// - each element has the same value as `Offset` or `Count`
187 /// Then cast `Offset` and `Count` if their bit width is different
188 /// from `Base` bit width.
189 static Value
processCountOrOffset(Location loc
, Value value
, Type srcType
,
190 Type dstType
, const TypeConverter
&converter
,
191 ConversionPatternRewriter
&rewriter
) {
193 optionallyBroadcast(loc
, value
, srcType
, converter
, rewriter
);
194 return optionallyTruncateOrExtend(loc
, broadcasted
, dstType
, rewriter
);
197 /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
198 /// offset to LLVM struct. Otherwise, the conversion is not supported.
199 static Type
convertStructTypeWithOffset(spirv::StructType type
,
200 const TypeConverter
&converter
) {
201 if (type
!= VulkanLayoutUtils::decorateType(type
))
204 SmallVector
<Type
> elementsVector
;
205 if (failed(converter
.convertTypes(type
.getElementTypes(), elementsVector
)))
207 return LLVM::LLVMStructType::getLiteral(type
.getContext(), elementsVector
,
211 /// Converts SPIR-V struct with no offset to packed LLVM struct.
212 static Type
convertStructTypePacked(spirv::StructType type
,
213 const TypeConverter
&converter
) {
214 SmallVector
<Type
> elementsVector
;
215 if (failed(converter
.convertTypes(type
.getElementTypes(), elementsVector
)))
217 return LLVM::LLVMStructType::getLiteral(type
.getContext(), elementsVector
,
221 /// Creates LLVM dialect constant with the given value.
222 static Value
createI32ConstantOf(Location loc
, PatternRewriter
&rewriter
,
224 return rewriter
.create
<LLVM::ConstantOp
>(
225 loc
, IntegerType::get(rewriter
.getContext(), 32),
226 rewriter
.getIntegerAttr(rewriter
.getI32Type(), value
));
229 /// Utility for `spirv.Load` and `spirv.Store` conversion.
230 static LogicalResult
replaceWithLoadOrStore(Operation
*op
, ValueRange operands
,
231 ConversionPatternRewriter
&rewriter
,
232 const TypeConverter
&typeConverter
,
233 unsigned alignment
, bool isVolatile
,
234 bool isNonTemporal
) {
235 if (auto loadOp
= dyn_cast
<spirv::LoadOp
>(op
)) {
236 auto dstType
= typeConverter
.convertType(loadOp
.getType());
238 return rewriter
.notifyMatchFailure(op
, "type conversion failed");
239 rewriter
.replaceOpWithNewOp
<LLVM::LoadOp
>(
240 loadOp
, dstType
, spirv::LoadOpAdaptor(operands
).getPtr(), alignment
,
241 isVolatile
, isNonTemporal
);
244 auto storeOp
= cast
<spirv::StoreOp
>(op
);
245 spirv::StoreOpAdaptor
adaptor(operands
);
246 rewriter
.replaceOpWithNewOp
<LLVM::StoreOp
>(storeOp
, adaptor
.getValue(),
247 adaptor
.getPtr(), alignment
,
248 isVolatile
, isNonTemporal
);
252 //===----------------------------------------------------------------------===//
254 //===----------------------------------------------------------------------===//
256 /// Converts SPIR-V array type to LLVM array. Natural stride (according to
257 /// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected
258 /// when converting ops that manipulate array types.
259 static std::optional
<Type
> convertArrayType(spirv::ArrayType type
,
260 TypeConverter
&converter
) {
261 unsigned stride
= type
.getArrayStride();
262 Type elementType
= type
.getElementType();
263 auto sizeInBytes
= cast
<spirv::SPIRVType
>(elementType
).getSizeInBytes();
264 if (stride
!= 0 && (!sizeInBytes
|| *sizeInBytes
!= stride
))
267 auto llvmElementType
= converter
.convertType(elementType
);
268 unsigned numElements
= type
.getNumElements();
269 return LLVM::LLVMArrayType::get(llvmElementType
, numElements
);
272 /// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
273 /// modelled at the moment.
274 static Type
convertPointerType(spirv::PointerType type
,
275 const TypeConverter
&converter
,
276 spirv::ClientAPI clientAPI
) {
277 unsigned addressSpace
=
278 storageClassToAddressSpace(clientAPI
, type
.getStorageClass());
279 return LLVM::LLVMPointerType::get(type
.getContext(), addressSpace
);
282 /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
283 /// the bounds, the runtime array is converted to a 0-sized LLVM array. There is
284 /// no modelling of array stride at the moment.
285 static std::optional
<Type
> convertRuntimeArrayType(spirv::RuntimeArrayType type
,
286 TypeConverter
&converter
) {
287 if (type
.getArrayStride() != 0)
289 auto elementType
= converter
.convertType(type
.getElementType());
290 return LLVM::LLVMArrayType::get(elementType
, 0);
293 /// Converts SPIR-V struct to LLVM struct. There is no support of structs with
294 /// member decorations. Also, only natural offset is supported.
295 static Type
convertStructType(spirv::StructType type
,
296 const TypeConverter
&converter
) {
297 SmallVector
<spirv::StructType::MemberDecorationInfo
, 4> memberDecorations
;
298 type
.getMemberDecorations(memberDecorations
);
299 if (!memberDecorations
.empty())
301 if (type
.hasOffset())
302 return convertStructTypeWithOffset(type
, converter
);
303 return convertStructTypePacked(type
, converter
);
306 //===----------------------------------------------------------------------===//
307 // Operation conversion
308 //===----------------------------------------------------------------------===//
312 class AccessChainPattern
: public SPIRVToLLVMConversion
<spirv::AccessChainOp
> {
314 using SPIRVToLLVMConversion
<spirv::AccessChainOp
>::SPIRVToLLVMConversion
;
317 matchAndRewrite(spirv::AccessChainOp op
, OpAdaptor adaptor
,
318 ConversionPatternRewriter
&rewriter
) const override
{
320 getTypeConverter()->convertType(op
.getComponentPtr().getType());
322 return rewriter
.notifyMatchFailure(op
, "type conversion failed");
323 // To use GEP we need to add a first 0 index to go through the pointer.
324 auto indices
= llvm::to_vector
<4>(adaptor
.getIndices());
325 Type indexType
= op
.getIndices().front().getType();
326 auto llvmIndexType
= getTypeConverter()->convertType(indexType
);
328 return rewriter
.notifyMatchFailure(op
, "type conversion failed");
329 Value zero
= rewriter
.create
<LLVM::ConstantOp
>(
330 op
.getLoc(), llvmIndexType
, rewriter
.getIntegerAttr(indexType
, 0));
331 indices
.insert(indices
.begin(), zero
);
333 auto elementType
= getTypeConverter()->convertType(
334 cast
<spirv::PointerType
>(op
.getBasePtr().getType()).getPointeeType());
336 return rewriter
.notifyMatchFailure(op
, "type conversion failed");
337 rewriter
.replaceOpWithNewOp
<LLVM::GEPOp
>(op
, dstType
, elementType
,
338 adaptor
.getBasePtr(), indices
);
343 class AddressOfPattern
: public SPIRVToLLVMConversion
<spirv::AddressOfOp
> {
345 using SPIRVToLLVMConversion
<spirv::AddressOfOp
>::SPIRVToLLVMConversion
;
348 matchAndRewrite(spirv::AddressOfOp op
, OpAdaptor adaptor
,
349 ConversionPatternRewriter
&rewriter
) const override
{
350 auto dstType
= getTypeConverter()->convertType(op
.getPointer().getType());
352 return rewriter
.notifyMatchFailure(op
, "type conversion failed");
353 rewriter
.replaceOpWithNewOp
<LLVM::AddressOfOp
>(op
, dstType
,
359 class BitFieldInsertPattern
360 : public SPIRVToLLVMConversion
<spirv::BitFieldInsertOp
> {
362 using SPIRVToLLVMConversion
<spirv::BitFieldInsertOp
>::SPIRVToLLVMConversion
;
365 matchAndRewrite(spirv::BitFieldInsertOp op
, OpAdaptor adaptor
,
366 ConversionPatternRewriter
&rewriter
) const override
{
367 auto srcType
= op
.getType();
368 auto dstType
= getTypeConverter()->convertType(srcType
);
370 return rewriter
.notifyMatchFailure(op
, "type conversion failed");
371 Location loc
= op
.getLoc();
373 // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
374 Value offset
= processCountOrOffset(loc
, op
.getOffset(), srcType
, dstType
,
375 *getTypeConverter(), rewriter
);
376 Value count
= processCountOrOffset(loc
, op
.getCount(), srcType
, dstType
,
377 *getTypeConverter(), rewriter
);
379 // Create a mask with bits set outside [Offset, Offset + Count - 1].
380 Value minusOne
= createConstantAllBitsSet(loc
, srcType
, dstType
, rewriter
);
381 Value maskShiftedByCount
=
382 rewriter
.create
<LLVM::ShlOp
>(loc
, dstType
, minusOne
, count
);
383 Value negated
= rewriter
.create
<LLVM::XOrOp
>(loc
, dstType
,
384 maskShiftedByCount
, minusOne
);
385 Value maskShiftedByCountAndOffset
=
386 rewriter
.create
<LLVM::ShlOp
>(loc
, dstType
, negated
, offset
);
387 Value mask
= rewriter
.create
<LLVM::XOrOp
>(
388 loc
, dstType
, maskShiftedByCountAndOffset
, minusOne
);
390 // Extract unchanged bits from the `Base` that are outside of
391 // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
393 rewriter
.create
<LLVM::AndOp
>(loc
, dstType
, op
.getBase(), mask
);
394 Value insertShiftedByOffset
=
395 rewriter
.create
<LLVM::ShlOp
>(loc
, dstType
, op
.getInsert(), offset
);
396 rewriter
.replaceOpWithNewOp
<LLVM::OrOp
>(op
, dstType
, baseAndMask
,
397 insertShiftedByOffset
);
402 /// Converts SPIR-V ConstantOp with scalar or vector type.
403 class ConstantScalarAndVectorPattern
404 : public SPIRVToLLVMConversion
<spirv::ConstantOp
> {
406 using SPIRVToLLVMConversion
<spirv::ConstantOp
>::SPIRVToLLVMConversion
;
409 matchAndRewrite(spirv::ConstantOp constOp
, OpAdaptor adaptor
,
410 ConversionPatternRewriter
&rewriter
) const override
{
411 auto srcType
= constOp
.getType();
412 if (!isa
<VectorType
>(srcType
) && !srcType
.isIntOrFloat())
415 auto dstType
= getTypeConverter()->convertType(srcType
);
417 return rewriter
.notifyMatchFailure(constOp
, "type conversion failed");
419 // SPIR-V constant can be a signed/unsigned integer, which has to be
420 // casted to signless integer when converting to LLVM dialect. Removing the
421 // sign bit may have unexpected behaviour. However, it is better to handle
422 // it case-by-case, given that the purpose of the conversion is not to
423 // cover all possible corner cases.
424 if (isSignedIntegerOrVector(srcType
) ||
425 isUnsignedIntegerOrVector(srcType
)) {
426 auto signlessType
= rewriter
.getIntegerType(getBitWidth(srcType
));
428 if (isa
<VectorType
>(srcType
)) {
429 auto dstElementsAttr
= cast
<DenseIntElementsAttr
>(constOp
.getValue());
430 rewriter
.replaceOpWithNewOp
<LLVM::ConstantOp
>(
432 dstElementsAttr
.mapValues(
433 signlessType
, [&](const APInt
&value
) { return value
; }));
436 auto srcAttr
= cast
<IntegerAttr
>(constOp
.getValue());
437 auto dstAttr
= rewriter
.getIntegerAttr(signlessType
, srcAttr
.getValue());
438 rewriter
.replaceOpWithNewOp
<LLVM::ConstantOp
>(constOp
, dstType
, dstAttr
);
441 rewriter
.replaceOpWithNewOp
<LLVM::ConstantOp
>(
442 constOp
, dstType
, adaptor
.getOperands(), constOp
->getAttrs());
447 class BitFieldSExtractPattern
448 : public SPIRVToLLVMConversion
<spirv::BitFieldSExtractOp
> {
450 using SPIRVToLLVMConversion
<spirv::BitFieldSExtractOp
>::SPIRVToLLVMConversion
;
453 matchAndRewrite(spirv::BitFieldSExtractOp op
, OpAdaptor adaptor
,
454 ConversionPatternRewriter
&rewriter
) const override
{
455 auto srcType
= op
.getType();
456 auto dstType
= getTypeConverter()->convertType(srcType
);
458 return rewriter
.notifyMatchFailure(op
, "type conversion failed");
459 Location loc
= op
.getLoc();
461 // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
462 Value offset
= processCountOrOffset(loc
, op
.getOffset(), srcType
, dstType
,
463 *getTypeConverter(), rewriter
);
464 Value count
= processCountOrOffset(loc
, op
.getCount(), srcType
, dstType
,
465 *getTypeConverter(), rewriter
);
467 // Create a constant that holds the size of the `Base`.
468 IntegerType integerType
;
469 if (auto vecType
= dyn_cast
<VectorType
>(srcType
))
470 integerType
= cast
<IntegerType
>(vecType
.getElementType());
472 integerType
= cast
<IntegerType
>(srcType
);
474 auto baseSize
= rewriter
.getIntegerAttr(integerType
, getBitWidth(srcType
));
476 isa
<VectorType
>(srcType
)
477 ? rewriter
.create
<LLVM::ConstantOp
>(
479 SplatElementsAttr::get(cast
<ShapedType
>(srcType
), baseSize
))
480 : rewriter
.create
<LLVM::ConstantOp
>(loc
, dstType
, baseSize
);
482 // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit
483 // at Offset + Count - 1 is the most significant bit now.
484 Value countPlusOffset
=
485 rewriter
.create
<LLVM::AddOp
>(loc
, dstType
, count
, offset
);
486 Value amountToShiftLeft
=
487 rewriter
.create
<LLVM::SubOp
>(loc
, dstType
, size
, countPlusOffset
);
488 Value baseShiftedLeft
= rewriter
.create
<LLVM::ShlOp
>(
489 loc
, dstType
, op
.getBase(), amountToShiftLeft
);
491 // Shift the result right, filling the bits with the sign bit.
492 Value amountToShiftRight
=
493 rewriter
.create
<LLVM::AddOp
>(loc
, dstType
, offset
, amountToShiftLeft
);
494 rewriter
.replaceOpWithNewOp
<LLVM::AShrOp
>(op
, dstType
, baseShiftedLeft
,
500 class BitFieldUExtractPattern
501 : public SPIRVToLLVMConversion
<spirv::BitFieldUExtractOp
> {
503 using SPIRVToLLVMConversion
<spirv::BitFieldUExtractOp
>::SPIRVToLLVMConversion
;
506 matchAndRewrite(spirv::BitFieldUExtractOp op
, OpAdaptor adaptor
,
507 ConversionPatternRewriter
&rewriter
) const override
{
508 auto srcType
= op
.getType();
509 auto dstType
= getTypeConverter()->convertType(srcType
);
511 return rewriter
.notifyMatchFailure(op
, "type conversion failed");
512 Location loc
= op
.getLoc();
514 // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
515 Value offset
= processCountOrOffset(loc
, op
.getOffset(), srcType
, dstType
,
516 *getTypeConverter(), rewriter
);
517 Value count
= processCountOrOffset(loc
, op
.getCount(), srcType
, dstType
,
518 *getTypeConverter(), rewriter
);
520 // Create a mask with bits set at [0, Count - 1].
521 Value minusOne
= createConstantAllBitsSet(loc
, srcType
, dstType
, rewriter
);
522 Value maskShiftedByCount
=
523 rewriter
.create
<LLVM::ShlOp
>(loc
, dstType
, minusOne
, count
);
524 Value mask
= rewriter
.create
<LLVM::XOrOp
>(loc
, dstType
, maskShiftedByCount
,
527 // Shift `Base` by `Offset` and apply the mask on it.
529 rewriter
.create
<LLVM::LShrOp
>(loc
, dstType
, op
.getBase(), offset
);
530 rewriter
.replaceOpWithNewOp
<LLVM::AndOp
>(op
, dstType
, shiftedBase
, mask
);
535 class BranchConversionPattern
: public SPIRVToLLVMConversion
<spirv::BranchOp
> {
537 using SPIRVToLLVMConversion
<spirv::BranchOp
>::SPIRVToLLVMConversion
;
540 matchAndRewrite(spirv::BranchOp branchOp
, OpAdaptor adaptor
,
541 ConversionPatternRewriter
&rewriter
) const override
{
542 rewriter
.replaceOpWithNewOp
<LLVM::BrOp
>(branchOp
, adaptor
.getOperands(),
543 branchOp
.getTarget());
548 class BranchConditionalConversionPattern
549 : public SPIRVToLLVMConversion
<spirv::BranchConditionalOp
> {
551 using SPIRVToLLVMConversion
<
552 spirv::BranchConditionalOp
>::SPIRVToLLVMConversion
;
555 matchAndRewrite(spirv::BranchConditionalOp op
, OpAdaptor adaptor
,
556 ConversionPatternRewriter
&rewriter
) const override
{
557 // If branch weights exist, map them to 32-bit integer vector.
558 DenseI32ArrayAttr branchWeights
= nullptr;
559 if (auto weights
= op
.getBranchWeights()) {
560 SmallVector
<int32_t> weightValues
;
561 for (auto weight
: weights
->getAsRange
<IntegerAttr
>())
562 weightValues
.push_back(weight
.getInt());
563 branchWeights
= DenseI32ArrayAttr::get(getContext(), weightValues
);
566 rewriter
.replaceOpWithNewOp
<LLVM::CondBrOp
>(
567 op
, op
.getCondition(), op
.getTrueBlockArguments(),
568 op
.getFalseBlockArguments(), branchWeights
, op
.getTrueBlock(),
574 /// Converts `spirv.getCompositeExtract` to `llvm.extractvalue` if the container
575 /// type is an aggregate type (struct or array). Otherwise, converts to
576 /// `llvm.extractelement` that operates on vectors.
577 class CompositeExtractPattern
578 : public SPIRVToLLVMConversion
<spirv::CompositeExtractOp
> {
580 using SPIRVToLLVMConversion
<spirv::CompositeExtractOp
>::SPIRVToLLVMConversion
;
583 matchAndRewrite(spirv::CompositeExtractOp op
, OpAdaptor adaptor
,
584 ConversionPatternRewriter
&rewriter
) const override
{
585 auto dstType
= this->getTypeConverter()->convertType(op
.getType());
587 return rewriter
.notifyMatchFailure(op
, "type conversion failed");
589 Type containerType
= op
.getComposite().getType();
590 if (isa
<VectorType
>(containerType
)) {
591 Location loc
= op
.getLoc();
592 IntegerAttr value
= cast
<IntegerAttr
>(op
.getIndices()[0]);
593 Value index
= createI32ConstantOf(loc
, rewriter
, value
.getInt());
594 rewriter
.replaceOpWithNewOp
<LLVM::ExtractElementOp
>(
595 op
, dstType
, adaptor
.getComposite(), index
);
599 rewriter
.replaceOpWithNewOp
<LLVM::ExtractValueOp
>(
600 op
, adaptor
.getComposite(),
601 LLVM::convertArrayToIndices(op
.getIndices()));
606 /// Converts `spirv.getCompositeInsert` to `llvm.insertvalue` if the container
607 /// type is an aggregate type (struct or array). Otherwise, converts to
608 /// `llvm.insertelement` that operates on vectors.
609 class CompositeInsertPattern
610 : public SPIRVToLLVMConversion
<spirv::CompositeInsertOp
> {
612 using SPIRVToLLVMConversion
<spirv::CompositeInsertOp
>::SPIRVToLLVMConversion
;
615 matchAndRewrite(spirv::CompositeInsertOp op
, OpAdaptor adaptor
,
616 ConversionPatternRewriter
&rewriter
) const override
{
617 auto dstType
= this->getTypeConverter()->convertType(op
.getType());
619 return rewriter
.notifyMatchFailure(op
, "type conversion failed");
621 Type containerType
= op
.getComposite().getType();
622 if (isa
<VectorType
>(containerType
)) {
623 Location loc
= op
.getLoc();
624 IntegerAttr value
= cast
<IntegerAttr
>(op
.getIndices()[0]);
625 Value index
= createI32ConstantOf(loc
, rewriter
, value
.getInt());
626 rewriter
.replaceOpWithNewOp
<LLVM::InsertElementOp
>(
627 op
, dstType
, adaptor
.getComposite(), adaptor
.getObject(), index
);
631 rewriter
.replaceOpWithNewOp
<LLVM::InsertValueOp
>(
632 op
, adaptor
.getComposite(), adaptor
.getObject(),
633 LLVM::convertArrayToIndices(op
.getIndices()));
638 /// Converts SPIR-V operations that have straightforward LLVM equivalent
639 /// into LLVM dialect operations.
640 template <typename SPIRVOp
, typename LLVMOp
>
641 class DirectConversionPattern
: public SPIRVToLLVMConversion
<SPIRVOp
> {
643 using SPIRVToLLVMConversion
<SPIRVOp
>::SPIRVToLLVMConversion
;
646 matchAndRewrite(SPIRVOp op
, typename
SPIRVOp::Adaptor adaptor
,
647 ConversionPatternRewriter
&rewriter
) const override
{
648 auto dstType
= this->getTypeConverter()->convertType(op
.getType());
650 return rewriter
.notifyMatchFailure(op
, "type conversion failed");
651 rewriter
.template replaceOpWithNewOp
<LLVMOp
>(
652 op
, dstType
, adaptor
.getOperands(), op
->getAttrs());
657 /// Converts `spirv.ExecutionMode` into a global struct constant that holds
658 /// execution mode information.
659 class ExecutionModePattern
660 : public SPIRVToLLVMConversion
<spirv::ExecutionModeOp
> {
662 using SPIRVToLLVMConversion
<spirv::ExecutionModeOp
>::SPIRVToLLVMConversion
;
665 matchAndRewrite(spirv::ExecutionModeOp op
, OpAdaptor adaptor
,
666 ConversionPatternRewriter
&rewriter
) const override
{
667 // First, create the global struct's name that would be associated with
668 // this entry point's execution mode. We set it to be:
669 // __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode}
670 ModuleOp module
= op
->getParentOfType
<ModuleOp
>();
671 spirv::ExecutionModeAttr executionModeAttr
= op
.getExecutionModeAttr();
672 std::string moduleName
;
673 if (module
.getName().has_value())
674 moduleName
= "_" + module
.getName()->str();
677 std::string executionModeInfoName
= llvm::formatv(
678 "__spv_{0}_{1}_execution_mode_info_{2}", moduleName
, op
.getFn().str(),
679 static_cast<uint32_t>(executionModeAttr
.getValue()));
681 MLIRContext
*context
= rewriter
.getContext();
682 OpBuilder::InsertionGuard
guard(rewriter
);
683 rewriter
.setInsertionPointToStart(module
.getBody());
685 // Create a struct type, corresponding to the C struct below.
687 // int32_t executionMode;
688 // int32_t values[]; // optional values
690 auto llvmI32Type
= IntegerType::get(context
, 32);
691 SmallVector
<Type
, 2> fields
;
692 fields
.push_back(llvmI32Type
);
693 ArrayAttr values
= op
.getValues();
694 if (!values
.empty()) {
695 auto arrayType
= LLVM::LLVMArrayType::get(llvmI32Type
, values
.size());
696 fields
.push_back(arrayType
);
698 auto structType
= LLVM::LLVMStructType::getLiteral(context
, fields
);
700 // Create `llvm.mlir.global` with initializer region containing one block.
701 auto global
= rewriter
.create
<LLVM::GlobalOp
>(
702 UnknownLoc::get(context
), structType
, /*isConstant=*/true,
703 LLVM::Linkage::External
, executionModeInfoName
, Attribute(),
705 Location loc
= global
.getLoc();
706 Region
®ion
= global
.getInitializerRegion();
707 Block
*block
= rewriter
.createBlock(®ion
);
709 // Initialize the struct and set the execution mode value.
710 rewriter
.setInsertionPointToStart(block
);
711 Value structValue
= rewriter
.create
<LLVM::UndefOp
>(loc
, structType
);
712 Value executionMode
= rewriter
.create
<LLVM::ConstantOp
>(
714 rewriter
.getI32IntegerAttr(
715 static_cast<uint32_t>(executionModeAttr
.getValue())));
716 structValue
= rewriter
.create
<LLVM::InsertValueOp
>(loc
, structValue
,
719 // Insert extra operands if they exist into execution mode info struct.
720 for (unsigned i
= 0, e
= values
.size(); i
< e
; ++i
) {
721 auto attr
= values
.getValue()[i
];
722 Value entry
= rewriter
.create
<LLVM::ConstantOp
>(loc
, llvmI32Type
, attr
);
723 structValue
= rewriter
.create
<LLVM::InsertValueOp
>(
724 loc
, structValue
, entry
, ArrayRef
<int64_t>({1, i
}));
726 rewriter
.create
<LLVM::ReturnOp
>(loc
, ArrayRef
<Value
>({structValue
}));
727 rewriter
.eraseOp(op
);
732 /// Converts `spirv.GlobalVariable` to `llvm.mlir.global`. Note that SPIR-V
733 /// global returns a pointer, whereas in LLVM dialect the global holds an actual
734 /// value. This difference is handled by `spirv.mlir.addressof` and
735 /// `llvm.mlir.addressof`ops that both return a pointer.
736 class GlobalVariablePattern
737 : public SPIRVToLLVMConversion
<spirv::GlobalVariableOp
> {
739 template <typename
... Args
>
740 GlobalVariablePattern(spirv::ClientAPI clientAPI
, Args
&&...args
)
741 : SPIRVToLLVMConversion
<spirv::GlobalVariableOp
>(
742 std::forward
<Args
>(args
)...),
743 clientAPI(clientAPI
) {}
746 matchAndRewrite(spirv::GlobalVariableOp op
, OpAdaptor adaptor
,
747 ConversionPatternRewriter
&rewriter
) const override
{
748 // Currently, there is no support of initialization with a constant value in
749 // SPIR-V dialect. Specialization constants are not considered as well.
750 if (op
.getInitializer())
753 auto srcType
= cast
<spirv::PointerType
>(op
.getType());
754 auto dstType
= getTypeConverter()->convertType(srcType
.getPointeeType());
756 return rewriter
.notifyMatchFailure(op
, "type conversion failed");
758 // Limit conversion to the current invocation only or `StorageBuffer`
759 // required by SPIR-V runner.
760 // This is okay because multiple invocations are not supported yet.
761 auto storageClass
= srcType
.getStorageClass();
762 switch (storageClass
) {
763 case spirv::StorageClass::Input
:
764 case spirv::StorageClass::Private
:
765 case spirv::StorageClass::Output
:
766 case spirv::StorageClass::StorageBuffer
:
767 case spirv::StorageClass::UniformConstant
:
773 // LLVM dialect spec: "If the global value is a constant, storing into it is
774 // not allowed.". This corresponds to SPIR-V 'Input' and 'UniformConstant'
775 // storage class that is read-only.
776 bool isConstant
= (storageClass
== spirv::StorageClass::Input
) ||
777 (storageClass
== spirv::StorageClass::UniformConstant
);
778 // SPIR-V spec: "By default, functions and global variables are private to a
779 // module and cannot be accessed by other modules. However, a module may be
780 // written to export or import functions and global (module scope)
781 // variables.". Therefore, map 'Private' storage class to private linkage,
782 // 'Input' and 'Output' to external linkage.
783 auto linkage
= storageClass
== spirv::StorageClass::Private
784 ? LLVM::Linkage::Private
785 : LLVM::Linkage::External
;
786 auto newGlobalOp
= rewriter
.replaceOpWithNewOp
<LLVM::GlobalOp
>(
787 op
, dstType
, isConstant
, linkage
, op
.getSymName(), Attribute(),
788 /*alignment=*/0, storageClassToAddressSpace(clientAPI
, storageClass
));
790 // Attach location attribute if applicable
791 if (op
.getLocationAttr())
792 newGlobalOp
->setAttr(op
.getLocationAttrName(), op
.getLocationAttr());
798 spirv::ClientAPI clientAPI
;
801 /// Converts SPIR-V cast ops that do not have straightforward LLVM
802 /// equivalent in LLVM dialect.
803 template <typename SPIRVOp
, typename LLVMExtOp
, typename LLVMTruncOp
>
804 class IndirectCastPattern
: public SPIRVToLLVMConversion
<SPIRVOp
> {
806 using SPIRVToLLVMConversion
<SPIRVOp
>::SPIRVToLLVMConversion
;
809 matchAndRewrite(SPIRVOp op
, typename
SPIRVOp::Adaptor adaptor
,
810 ConversionPatternRewriter
&rewriter
) const override
{
812 Type fromType
= op
.getOperand().getType();
813 Type toType
= op
.getType();
815 auto dstType
= this->getTypeConverter()->convertType(toType
);
817 return rewriter
.notifyMatchFailure(op
, "type conversion failed");
819 if (getBitWidth(fromType
) < getBitWidth(toType
)) {
820 rewriter
.template replaceOpWithNewOp
<LLVMExtOp
>(op
, dstType
,
821 adaptor
.getOperands());
824 if (getBitWidth(fromType
) > getBitWidth(toType
)) {
825 rewriter
.template replaceOpWithNewOp
<LLVMTruncOp
>(op
, dstType
,
826 adaptor
.getOperands());
833 class FunctionCallPattern
834 : public SPIRVToLLVMConversion
<spirv::FunctionCallOp
> {
836 using SPIRVToLLVMConversion
<spirv::FunctionCallOp
>::SPIRVToLLVMConversion
;
839 matchAndRewrite(spirv::FunctionCallOp callOp
, OpAdaptor adaptor
,
840 ConversionPatternRewriter
&rewriter
) const override
{
841 if (callOp
.getNumResults() == 0) {
842 auto newOp
= rewriter
.replaceOpWithNewOp
<LLVM::CallOp
>(
843 callOp
, std::nullopt
, adaptor
.getOperands(), callOp
->getAttrs());
844 newOp
.getProperties().operandSegmentSizes
= {
845 static_cast<int32_t>(adaptor
.getOperands().size()), 0};
846 newOp
.getProperties().op_bundle_sizes
= rewriter
.getDenseI32ArrayAttr({});
850 // Function returns a single result.
851 auto dstType
= getTypeConverter()->convertType(callOp
.getType(0));
853 return rewriter
.notifyMatchFailure(callOp
, "type conversion failed");
854 auto newOp
= rewriter
.replaceOpWithNewOp
<LLVM::CallOp
>(
855 callOp
, dstType
, adaptor
.getOperands(), callOp
->getAttrs());
856 newOp
.getProperties().operandSegmentSizes
= {
857 static_cast<int32_t>(adaptor
.getOperands().size()), 0};
858 newOp
.getProperties().op_bundle_sizes
= rewriter
.getDenseI32ArrayAttr({});
863 /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate"
864 template <typename SPIRVOp
, LLVM::FCmpPredicate predicate
>
865 class FComparePattern
: public SPIRVToLLVMConversion
<SPIRVOp
> {
867 using SPIRVToLLVMConversion
<SPIRVOp
>::SPIRVToLLVMConversion
;
870 matchAndRewrite(SPIRVOp op
, typename
SPIRVOp::Adaptor adaptor
,
871 ConversionPatternRewriter
&rewriter
) const override
{
873 auto dstType
= this->getTypeConverter()->convertType(op
.getType());
875 return rewriter
.notifyMatchFailure(op
, "type conversion failed");
877 rewriter
.template replaceOpWithNewOp
<LLVM::FCmpOp
>(
878 op
, dstType
, predicate
, op
.getOperand1(), op
.getOperand2());
883 /// Converts SPIR-V integer comparisons to llvm.icmp "predicate"
884 template <typename SPIRVOp
, LLVM::ICmpPredicate predicate
>
885 class IComparePattern
: public SPIRVToLLVMConversion
<SPIRVOp
> {
887 using SPIRVToLLVMConversion
<SPIRVOp
>::SPIRVToLLVMConversion
;
890 matchAndRewrite(SPIRVOp op
, typename
SPIRVOp::Adaptor adaptor
,
891 ConversionPatternRewriter
&rewriter
) const override
{
893 auto dstType
= this->getTypeConverter()->convertType(op
.getType());
895 return rewriter
.notifyMatchFailure(op
, "type conversion failed");
897 rewriter
.template replaceOpWithNewOp
<LLVM::ICmpOp
>(
898 op
, dstType
, predicate
, op
.getOperand1(), op
.getOperand2());
903 class InverseSqrtPattern
904 : public SPIRVToLLVMConversion
<spirv::GLInverseSqrtOp
> {
906 using SPIRVToLLVMConversion
<spirv::GLInverseSqrtOp
>::SPIRVToLLVMConversion
;
909 matchAndRewrite(spirv::GLInverseSqrtOp op
, OpAdaptor adaptor
,
910 ConversionPatternRewriter
&rewriter
) const override
{
911 auto srcType
= op
.getType();
912 auto dstType
= getTypeConverter()->convertType(srcType
);
914 return rewriter
.notifyMatchFailure(op
, "type conversion failed");
916 Location loc
= op
.getLoc();
917 Value one
= createFPConstant(loc
, srcType
, dstType
, rewriter
, 1.0);
918 Value sqrt
= rewriter
.create
<LLVM::SqrtOp
>(loc
, dstType
, op
.getOperand());
919 rewriter
.replaceOpWithNewOp
<LLVM::FDivOp
>(op
, dstType
, one
, sqrt
);
924 /// Converts `spirv.Load` and `spirv.Store` to LLVM dialect.
925 template <typename SPIRVOp
>
926 class LoadStorePattern
: public SPIRVToLLVMConversion
<SPIRVOp
> {
928 using SPIRVToLLVMConversion
<SPIRVOp
>::SPIRVToLLVMConversion
;
931 matchAndRewrite(SPIRVOp op
, typename
SPIRVOp::Adaptor adaptor
,
932 ConversionPatternRewriter
&rewriter
) const override
{
933 if (!op
.getMemoryAccess()) {
934 return replaceWithLoadOrStore(op
, adaptor
.getOperands(), rewriter
,
935 *this->getTypeConverter(), /*alignment=*/0,
936 /*isVolatile=*/false,
937 /*isNonTemporal=*/false);
939 auto memoryAccess
= *op
.getMemoryAccess();
940 switch (memoryAccess
) {
941 case spirv::MemoryAccess::Aligned
:
942 case spirv::MemoryAccess::None
:
943 case spirv::MemoryAccess::Nontemporal
:
944 case spirv::MemoryAccess::Volatile
: {
946 memoryAccess
== spirv::MemoryAccess::Aligned
? *op
.getAlignment() : 0;
947 bool isNonTemporal
= memoryAccess
== spirv::MemoryAccess::Nontemporal
;
948 bool isVolatile
= memoryAccess
== spirv::MemoryAccess::Volatile
;
949 return replaceWithLoadOrStore(op
, adaptor
.getOperands(), rewriter
,
950 *this->getTypeConverter(), alignment
,
951 isVolatile
, isNonTemporal
);
954 // There is no support of other memory access attributes.
960 /// Converts `spirv.Not` and `spirv.LogicalNot` into LLVM dialect.
961 template <typename SPIRVOp
>
962 class NotPattern
: public SPIRVToLLVMConversion
<SPIRVOp
> {
964 using SPIRVToLLVMConversion
<SPIRVOp
>::SPIRVToLLVMConversion
;
967 matchAndRewrite(SPIRVOp notOp
, typename
SPIRVOp::Adaptor adaptor
,
968 ConversionPatternRewriter
&rewriter
) const override
{
969 auto srcType
= notOp
.getType();
970 auto dstType
= this->getTypeConverter()->convertType(srcType
);
972 return rewriter
.notifyMatchFailure(notOp
, "type conversion failed");
974 Location loc
= notOp
.getLoc();
975 IntegerAttr minusOne
= minusOneIntegerAttribute(srcType
, rewriter
);
977 isa
<VectorType
>(srcType
)
978 ? rewriter
.create
<LLVM::ConstantOp
>(
980 SplatElementsAttr::get(cast
<VectorType
>(srcType
), minusOne
))
981 : rewriter
.create
<LLVM::ConstantOp
>(loc
, dstType
, minusOne
);
982 rewriter
.template replaceOpWithNewOp
<LLVM::XOrOp
>(notOp
, dstType
,
983 notOp
.getOperand(), mask
);
988 /// A template pattern that erases the given `SPIRVOp`.
989 template <typename SPIRVOp
>
990 class ErasePattern
: public SPIRVToLLVMConversion
<SPIRVOp
> {
992 using SPIRVToLLVMConversion
<SPIRVOp
>::SPIRVToLLVMConversion
;
995 matchAndRewrite(SPIRVOp op
, typename
SPIRVOp::Adaptor adaptor
,
996 ConversionPatternRewriter
&rewriter
) const override
{
997 rewriter
.eraseOp(op
);
1002 class ReturnPattern
: public SPIRVToLLVMConversion
<spirv::ReturnOp
> {
1004 using SPIRVToLLVMConversion
<spirv::ReturnOp
>::SPIRVToLLVMConversion
;
1007 matchAndRewrite(spirv::ReturnOp returnOp
, OpAdaptor adaptor
,
1008 ConversionPatternRewriter
&rewriter
) const override
{
1009 rewriter
.replaceOpWithNewOp
<LLVM::ReturnOp
>(returnOp
, ArrayRef
<Type
>(),
1015 class ReturnValuePattern
: public SPIRVToLLVMConversion
<spirv::ReturnValueOp
> {
1017 using SPIRVToLLVMConversion
<spirv::ReturnValueOp
>::SPIRVToLLVMConversion
;
1020 matchAndRewrite(spirv::ReturnValueOp returnValueOp
, OpAdaptor adaptor
,
1021 ConversionPatternRewriter
&rewriter
) const override
{
1022 rewriter
.replaceOpWithNewOp
<LLVM::ReturnOp
>(returnValueOp
, ArrayRef
<Type
>(),
1023 adaptor
.getOperands());
1028 static LLVM::LLVMFuncOp
lookupOrCreateSPIRVFn(Operation
*symbolTable
,
1030 ArrayRef
<Type
> paramTypes
,
1032 bool convergent
= true) {
1033 auto func
= dyn_cast_or_null
<LLVM::LLVMFuncOp
>(
1034 SymbolTable::lookupSymbolIn(symbolTable
, name
));
1038 OpBuilder
b(symbolTable
->getRegion(0));
1039 func
= b
.create
<LLVM::LLVMFuncOp
>(
1040 symbolTable
->getLoc(), name
,
1041 LLVM::LLVMFunctionType::get(resultType
, paramTypes
));
1042 func
.setCConv(LLVM::cconv::CConv::SPIR_FUNC
);
1043 func
.setConvergent(convergent
);
1044 func
.setNoUnwind(true);
1045 func
.setWillReturn(true);
1049 static LLVM::CallOp
createSPIRVBuiltinCall(Location loc
, OpBuilder
&builder
,
1050 LLVM::LLVMFuncOp func
,
1052 auto call
= builder
.create
<LLVM::CallOp
>(loc
, func
, args
);
1053 call
.setCConv(func
.getCConv());
1054 call
.setConvergentAttr(func
.getConvergentAttr());
1055 call
.setNoUnwindAttr(func
.getNoUnwindAttr());
1056 call
.setWillReturnAttr(func
.getWillReturnAttr());
1060 template <typename BarrierOpTy
>
1061 class ControlBarrierPattern
: public SPIRVToLLVMConversion
<BarrierOpTy
> {
1063 using OpAdaptor
= typename SPIRVToLLVMConversion
<BarrierOpTy
>::OpAdaptor
;
1065 using SPIRVToLLVMConversion
<BarrierOpTy
>::SPIRVToLLVMConversion
;
1067 static constexpr StringRef
getFuncName();
1070 matchAndRewrite(BarrierOpTy controlBarrierOp
, OpAdaptor adaptor
,
1071 ConversionPatternRewriter
&rewriter
) const override
{
1072 constexpr StringRef funcName
= getFuncName();
1073 Operation
*symbolTable
=
1074 controlBarrierOp
->template getParentWithTrait
<OpTrait::SymbolTable
>();
1076 Type i32
= rewriter
.getI32Type();
1078 Type voidTy
= rewriter
.getType
<LLVM::LLVMVoidType
>();
1079 LLVM::LLVMFuncOp func
=
1080 lookupOrCreateSPIRVFn(symbolTable
, funcName
, {i32
, i32
, i32
}, voidTy
);
1082 Location loc
= controlBarrierOp
->getLoc();
1083 Value execution
= rewriter
.create
<LLVM::ConstantOp
>(
1084 loc
, i32
, static_cast<int32_t>(adaptor
.getExecutionScope()));
1085 Value memory
= rewriter
.create
<LLVM::ConstantOp
>(
1086 loc
, i32
, static_cast<int32_t>(adaptor
.getMemoryScope()));
1087 Value semantics
= rewriter
.create
<LLVM::ConstantOp
>(
1088 loc
, i32
, static_cast<int32_t>(adaptor
.getMemorySemantics()));
1090 auto call
= createSPIRVBuiltinCall(loc
, rewriter
, func
,
1091 {execution
, memory
, semantics
});
1093 rewriter
.replaceOp(controlBarrierOp
, call
);
1100 StringRef
getTypeMangling(Type type
, bool isSigned
) {
1101 return llvm::TypeSwitch
<Type
, StringRef
>(type
)
1102 .Case
<Float16Type
>([](auto) { return "Dh"; })
1103 .Case
<Float32Type
>([](auto) { return "f"; })
1104 .Case
<Float64Type
>([](auto) { return "d"; })
1105 .Case
<IntegerType
>([isSigned
](IntegerType intTy
) {
1106 switch (intTy
.getWidth()) {
1110 return (isSigned
) ? "a" : "c";
1112 return (isSigned
) ? "s" : "t";
1114 return (isSigned
) ? "i" : "j";
1116 return (isSigned
) ? "l" : "m";
1118 llvm_unreachable("Unsupported integer width");
1122 llvm_unreachable("No mangling defined");
1127 template <typename ReduceOp
>
1128 constexpr StringLiteral
getGroupFuncName();
1131 constexpr StringLiteral getGroupFuncName
<spirv::GroupIAddOp
>() {
1132 return "_Z17__spirv_GroupIAddii";
1135 constexpr StringLiteral getGroupFuncName
<spirv::GroupFAddOp
>() {
1136 return "_Z17__spirv_GroupFAddii";
1139 constexpr StringLiteral getGroupFuncName
<spirv::GroupSMinOp
>() {
1140 return "_Z17__spirv_GroupSMinii";
1143 constexpr StringLiteral getGroupFuncName
<spirv::GroupUMinOp
>() {
1144 return "_Z17__spirv_GroupUMinii";
1147 constexpr StringLiteral getGroupFuncName
<spirv::GroupFMinOp
>() {
1148 return "_Z17__spirv_GroupFMinii";
1151 constexpr StringLiteral getGroupFuncName
<spirv::GroupSMaxOp
>() {
1152 return "_Z17__spirv_GroupSMaxii";
1155 constexpr StringLiteral getGroupFuncName
<spirv::GroupUMaxOp
>() {
1156 return "_Z17__spirv_GroupUMaxii";
1159 constexpr StringLiteral getGroupFuncName
<spirv::GroupFMaxOp
>() {
1160 return "_Z17__spirv_GroupFMaxii";
1163 constexpr StringLiteral getGroupFuncName
<spirv::GroupNonUniformIAddOp
>() {
1164 return "_Z27__spirv_GroupNonUniformIAddii";
1167 constexpr StringLiteral getGroupFuncName
<spirv::GroupNonUniformFAddOp
>() {
1168 return "_Z27__spirv_GroupNonUniformFAddii";
1171 constexpr StringLiteral getGroupFuncName
<spirv::GroupNonUniformIMulOp
>() {
1172 return "_Z27__spirv_GroupNonUniformIMulii";
1175 constexpr StringLiteral getGroupFuncName
<spirv::GroupNonUniformFMulOp
>() {
1176 return "_Z27__spirv_GroupNonUniformFMulii";
1179 constexpr StringLiteral getGroupFuncName
<spirv::GroupNonUniformSMinOp
>() {
1180 return "_Z27__spirv_GroupNonUniformSMinii";
1183 constexpr StringLiteral getGroupFuncName
<spirv::GroupNonUniformUMinOp
>() {
1184 return "_Z27__spirv_GroupNonUniformUMinii";
1187 constexpr StringLiteral getGroupFuncName
<spirv::GroupNonUniformFMinOp
>() {
1188 return "_Z27__spirv_GroupNonUniformFMinii";
1191 constexpr StringLiteral getGroupFuncName
<spirv::GroupNonUniformSMaxOp
>() {
1192 return "_Z27__spirv_GroupNonUniformSMaxii";
1195 constexpr StringLiteral getGroupFuncName
<spirv::GroupNonUniformUMaxOp
>() {
1196 return "_Z27__spirv_GroupNonUniformUMaxii";
1199 constexpr StringLiteral getGroupFuncName
<spirv::GroupNonUniformFMaxOp
>() {
1200 return "_Z27__spirv_GroupNonUniformFMaxii";
1203 constexpr StringLiteral getGroupFuncName
<spirv::GroupNonUniformBitwiseAndOp
>() {
1204 return "_Z33__spirv_GroupNonUniformBitwiseAndii";
1207 constexpr StringLiteral getGroupFuncName
<spirv::GroupNonUniformBitwiseOrOp
>() {
1208 return "_Z32__spirv_GroupNonUniformBitwiseOrii";
1211 constexpr StringLiteral getGroupFuncName
<spirv::GroupNonUniformBitwiseXorOp
>() {
1212 return "_Z33__spirv_GroupNonUniformBitwiseXorii";
1215 constexpr StringLiteral getGroupFuncName
<spirv::GroupNonUniformLogicalAndOp
>() {
1216 return "_Z33__spirv_GroupNonUniformLogicalAndii";
1219 constexpr StringLiteral getGroupFuncName
<spirv::GroupNonUniformLogicalOrOp
>() {
1220 return "_Z32__spirv_GroupNonUniformLogicalOrii";
1223 constexpr StringLiteral getGroupFuncName
<spirv::GroupNonUniformLogicalXorOp
>() {
1224 return "_Z33__spirv_GroupNonUniformLogicalXorii";
1228 template <typename ReduceOp
, bool Signed
= false, bool NonUniform
= false>
1229 class GroupReducePattern
: public SPIRVToLLVMConversion
<ReduceOp
> {
1231 using SPIRVToLLVMConversion
<ReduceOp
>::SPIRVToLLVMConversion
;
1234 matchAndRewrite(ReduceOp op
, typename
ReduceOp::Adaptor adaptor
,
1235 ConversionPatternRewriter
&rewriter
) const override
{
1237 Type retTy
= op
.getResult().getType();
1238 if (!retTy
.isIntOrFloat()) {
1241 SmallString
<36> funcName
= getGroupFuncName
<ReduceOp
>();
1242 funcName
+= getTypeMangling(retTy
, false);
1244 Type i32Ty
= rewriter
.getI32Type();
1245 SmallVector
<Type
> paramTypes
{i32Ty
, i32Ty
, retTy
};
1246 if constexpr (NonUniform
) {
1247 if (adaptor
.getClusterSize()) {
1249 paramTypes
.push_back(i32Ty
);
1253 Operation
*symbolTable
=
1254 op
->template getParentWithTrait
<OpTrait::SymbolTable
>();
1256 LLVM::LLVMFuncOp func
= lookupOrCreateSPIRVFn(
1257 symbolTable
, funcName
, paramTypes
, retTy
, !NonUniform
);
1259 Location loc
= op
.getLoc();
1260 Value scope
= rewriter
.create
<LLVM::ConstantOp
>(
1261 loc
, i32Ty
, static_cast<int32_t>(adaptor
.getExecutionScope()));
1262 Value groupOp
= rewriter
.create
<LLVM::ConstantOp
>(
1263 loc
, i32Ty
, static_cast<int32_t>(adaptor
.getGroupOperation()));
1264 SmallVector
<Value
> operands
{scope
, groupOp
};
1265 operands
.append(adaptor
.getOperands().begin(), adaptor
.getOperands().end());
1267 auto call
= createSPIRVBuiltinCall(loc
, rewriter
, func
, operands
);
1268 rewriter
.replaceOp(op
, call
);
1275 ControlBarrierPattern
<spirv::ControlBarrierOp
>::getFuncName() {
1276 return "_Z22__spirv_ControlBarrieriii";
1281 ControlBarrierPattern
<spirv::INTELControlBarrierArriveOp
>::getFuncName() {
1282 return "_Z33__spirv_ControlBarrierArriveINTELiii";
1287 ControlBarrierPattern
<spirv::INTELControlBarrierWaitOp
>::getFuncName() {
1288 return "_Z31__spirv_ControlBarrierWaitINTELiii";
1291 /// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within selection
1292 /// should be reachable for conversion to succeed. The structure of the loop in
1293 /// LLVM dialect will be the following:
1295 /// +------------------------------------+
1296 /// | <code before spirv.mlir.loop> |
1297 /// | llvm.br ^header |
1298 /// +------------------------------------+
1300 /// +----------------+ |
1303 /// | +------------------------------------+
1305 /// | | <header code> |
1306 /// | | llvm.cond_br %cond, ^body, ^exit |
1307 /// | +------------------------------------+
1309 /// | |----------------------+
1312 /// | +------------------------------------+ |
1314 /// | | <body code> | |
1315 /// | | llvm.br ^continue | |
1316 /// | +------------------------------------+ |
1319 /// | +------------------------------------+ |
1320 /// | | ^continue: | |
1321 /// | | <continue code> | |
1322 /// | | llvm.br ^header | |
1323 /// | +------------------------------------+ |
1325 /// +---------------+ +----------------------+
1328 /// +------------------------------------+
1330 /// | llvm.br ^remaining |
1331 /// +------------------------------------+
1334 /// +------------------------------------+
1336 /// | <code after spirv.mlir.loop> |
1337 /// +------------------------------------+
1339 class LoopPattern
: public SPIRVToLLVMConversion
<spirv::LoopOp
> {
1341 using SPIRVToLLVMConversion
<spirv::LoopOp
>::SPIRVToLLVMConversion
;
1344 matchAndRewrite(spirv::LoopOp loopOp
, OpAdaptor adaptor
,
1345 ConversionPatternRewriter
&rewriter
) const override
{
1346 // There is no support of loop control at the moment.
1347 if (loopOp
.getLoopControl() != spirv::LoopControl::None
)
1350 // `spirv.mlir.loop` with empty region is redundant and should be erased.
1351 if (loopOp
.getBody().empty()) {
1352 rewriter
.eraseOp(loopOp
);
1356 Location loc
= loopOp
.getLoc();
1358 // Split the current block after `spirv.mlir.loop`. The remaining ops will
1359 // be used in `endBlock`.
1360 Block
*currentBlock
= rewriter
.getBlock();
1361 auto position
= Block::iterator(loopOp
);
1362 Block
*endBlock
= rewriter
.splitBlock(currentBlock
, position
);
1364 // Remove entry block and create a branch in the current block going to the
1366 Block
*entryBlock
= loopOp
.getEntryBlock();
1367 assert(entryBlock
->getOperations().size() == 1);
1368 auto brOp
= dyn_cast
<spirv::BranchOp
>(entryBlock
->getOperations().front());
1371 Block
*headerBlock
= loopOp
.getHeaderBlock();
1372 rewriter
.setInsertionPointToEnd(currentBlock
);
1373 rewriter
.create
<LLVM::BrOp
>(loc
, brOp
.getBlockArguments(), headerBlock
);
1374 rewriter
.eraseBlock(entryBlock
);
1376 // Branch from merge block to end block.
1377 Block
*mergeBlock
= loopOp
.getMergeBlock();
1378 Operation
*terminator
= mergeBlock
->getTerminator();
1379 ValueRange terminatorOperands
= terminator
->getOperands();
1380 rewriter
.setInsertionPointToEnd(mergeBlock
);
1381 rewriter
.create
<LLVM::BrOp
>(loc
, terminatorOperands
, endBlock
);
1383 rewriter
.inlineRegionBefore(loopOp
.getBody(), endBlock
);
1384 rewriter
.replaceOp(loopOp
, endBlock
->getArguments());
1389 /// Converts `spirv.mlir.selection` with `spirv.BranchConditional` in its header
1390 /// block. All blocks within selection should be reachable for conversion to
1392 class SelectionPattern
: public SPIRVToLLVMConversion
<spirv::SelectionOp
> {
1394 using SPIRVToLLVMConversion
<spirv::SelectionOp
>::SPIRVToLLVMConversion
;
1397 matchAndRewrite(spirv::SelectionOp op
, OpAdaptor adaptor
,
1398 ConversionPatternRewriter
&rewriter
) const override
{
1399 // There is no support for `Flatten` or `DontFlatten` selection control at
1400 // the moment. This are just compiler hints and can be performed during the
1401 // optimization passes.
1402 if (op
.getSelectionControl() != spirv::SelectionControl::None
)
1405 // `spirv.mlir.selection` should have at least two blocks: one selection
1406 // header block and one merge block. If no blocks are present, or control
1407 // flow branches straight to merge block (two blocks are present), the op is
1408 // redundant and it is erased.
1409 if (op
.getBody().getBlocks().size() <= 2) {
1410 rewriter
.eraseOp(op
);
1414 Location loc
= op
.getLoc();
1416 // Split the current block after `spirv.mlir.selection`. The remaining ops
1417 // will be used in `continueBlock`.
1418 auto *currentBlock
= rewriter
.getInsertionBlock();
1419 rewriter
.setInsertionPointAfter(op
);
1420 auto position
= rewriter
.getInsertionPoint();
1421 auto *continueBlock
= rewriter
.splitBlock(currentBlock
, position
);
1423 // Extract conditional branch information from the header block. By SPIR-V
1424 // dialect spec, it should contain `spirv.BranchConditional` or
1425 // `spirv.Switch` op. Note that `spirv.Switch op` is not supported at the
1426 // moment in the SPIR-V dialect. Remove this block when finished.
1427 auto *headerBlock
= op
.getHeaderBlock();
1428 assert(headerBlock
->getOperations().size() == 1);
1429 auto condBrOp
= dyn_cast
<spirv::BranchConditionalOp
>(
1430 headerBlock
->getOperations().front());
1433 rewriter
.eraseBlock(headerBlock
);
1435 // Branch from merge block to continue block.
1436 auto *mergeBlock
= op
.getMergeBlock();
1437 Operation
*terminator
= mergeBlock
->getTerminator();
1438 ValueRange terminatorOperands
= terminator
->getOperands();
1439 rewriter
.setInsertionPointToEnd(mergeBlock
);
1440 rewriter
.create
<LLVM::BrOp
>(loc
, terminatorOperands
, continueBlock
);
1442 // Link current block to `true` and `false` blocks within the selection.
1443 Block
*trueBlock
= condBrOp
.getTrueBlock();
1444 Block
*falseBlock
= condBrOp
.getFalseBlock();
1445 rewriter
.setInsertionPointToEnd(currentBlock
);
1446 rewriter
.create
<LLVM::CondBrOp
>(loc
, condBrOp
.getCondition(), trueBlock
,
1447 condBrOp
.getTrueTargetOperands(),
1449 condBrOp
.getFalseTargetOperands());
1451 rewriter
.inlineRegionBefore(op
.getBody(), continueBlock
);
1452 rewriter
.replaceOp(op
, continueBlock
->getArguments());
1457 /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
1458 /// puts a restriction on `Shift` and `Base` to have the same bit width,
1459 /// `Shift` is zero or sign extended to match this specification. Cases when
1460 /// `Shift` bit width > `Base` bit width are considered to be illegal.
1461 template <typename SPIRVOp
, typename LLVMOp
>
1462 class ShiftPattern
: public SPIRVToLLVMConversion
<SPIRVOp
> {
1464 using SPIRVToLLVMConversion
<SPIRVOp
>::SPIRVToLLVMConversion
;
1467 matchAndRewrite(SPIRVOp op
, typename
SPIRVOp::Adaptor adaptor
,
1468 ConversionPatternRewriter
&rewriter
) const override
{
1470 auto dstType
= this->getTypeConverter()->convertType(op
.getType());
1472 return rewriter
.notifyMatchFailure(op
, "type conversion failed");
1474 Type op1Type
= op
.getOperand1().getType();
1475 Type op2Type
= op
.getOperand2().getType();
1477 if (op1Type
== op2Type
) {
1478 rewriter
.template replaceOpWithNewOp
<LLVMOp
>(op
, dstType
,
1479 adaptor
.getOperands());
1483 std::optional
<uint64_t> dstTypeWidth
=
1484 getIntegerOrVectorElementWidth(dstType
);
1485 std::optional
<uint64_t> op2TypeWidth
=
1486 getIntegerOrVectorElementWidth(op2Type
);
1488 if (!dstTypeWidth
|| !op2TypeWidth
)
1491 Location loc
= op
.getLoc();
1493 if (op2TypeWidth
< dstTypeWidth
) {
1494 if (isUnsignedIntegerOrVector(op2Type
)) {
1495 extended
= rewriter
.template create
<LLVM::ZExtOp
>(
1496 loc
, dstType
, adaptor
.getOperand2());
1498 extended
= rewriter
.template create
<LLVM::SExtOp
>(
1499 loc
, dstType
, adaptor
.getOperand2());
1501 } else if (op2TypeWidth
== dstTypeWidth
) {
1502 extended
= adaptor
.getOperand2();
1507 Value result
= rewriter
.template create
<LLVMOp
>(
1508 loc
, dstType
, adaptor
.getOperand1(), extended
);
1509 rewriter
.replaceOp(op
, result
);
1514 class TanPattern
: public SPIRVToLLVMConversion
<spirv::GLTanOp
> {
1516 using SPIRVToLLVMConversion
<spirv::GLTanOp
>::SPIRVToLLVMConversion
;
1519 matchAndRewrite(spirv::GLTanOp tanOp
, OpAdaptor adaptor
,
1520 ConversionPatternRewriter
&rewriter
) const override
{
1521 auto dstType
= getTypeConverter()->convertType(tanOp
.getType());
1523 return rewriter
.notifyMatchFailure(tanOp
, "type conversion failed");
1525 Location loc
= tanOp
.getLoc();
1526 Value sin
= rewriter
.create
<LLVM::SinOp
>(loc
, dstType
, tanOp
.getOperand());
1527 Value cos
= rewriter
.create
<LLVM::CosOp
>(loc
, dstType
, tanOp
.getOperand());
1528 rewriter
.replaceOpWithNewOp
<LLVM::FDivOp
>(tanOp
, dstType
, sin
, cos
);
1533 /// Convert `spirv.Tanh` to
1539 class TanhPattern
: public SPIRVToLLVMConversion
<spirv::GLTanhOp
> {
1541 using SPIRVToLLVMConversion
<spirv::GLTanhOp
>::SPIRVToLLVMConversion
;
1544 matchAndRewrite(spirv::GLTanhOp tanhOp
, OpAdaptor adaptor
,
1545 ConversionPatternRewriter
&rewriter
) const override
{
1546 auto srcType
= tanhOp
.getType();
1547 auto dstType
= getTypeConverter()->convertType(srcType
);
1549 return rewriter
.notifyMatchFailure(tanhOp
, "type conversion failed");
1551 Location loc
= tanhOp
.getLoc();
1552 Value two
= createFPConstant(loc
, srcType
, dstType
, rewriter
, 2.0);
1554 rewriter
.create
<LLVM::FMulOp
>(loc
, dstType
, two
, tanhOp
.getOperand());
1555 Value exponential
= rewriter
.create
<LLVM::ExpOp
>(loc
, dstType
, multiplied
);
1556 Value one
= createFPConstant(loc
, srcType
, dstType
, rewriter
, 1.0);
1558 rewriter
.create
<LLVM::FSubOp
>(loc
, dstType
, exponential
, one
);
1560 rewriter
.create
<LLVM::FAddOp
>(loc
, dstType
, exponential
, one
);
1561 rewriter
.replaceOpWithNewOp
<LLVM::FDivOp
>(tanhOp
, dstType
, numerator
,
1567 class VariablePattern
: public SPIRVToLLVMConversion
<spirv::VariableOp
> {
1569 using SPIRVToLLVMConversion
<spirv::VariableOp
>::SPIRVToLLVMConversion
;
1572 matchAndRewrite(spirv::VariableOp varOp
, OpAdaptor adaptor
,
1573 ConversionPatternRewriter
&rewriter
) const override
{
1574 auto srcType
= varOp
.getType();
1575 // Initialization is supported for scalars and vectors only.
1576 auto pointerTo
= cast
<spirv::PointerType
>(srcType
).getPointeeType();
1577 auto init
= varOp
.getInitializer();
1578 if (init
&& !pointerTo
.isIntOrFloat() && !isa
<VectorType
>(pointerTo
))
1581 auto dstType
= getTypeConverter()->convertType(srcType
);
1583 return rewriter
.notifyMatchFailure(varOp
, "type conversion failed");
1585 Location loc
= varOp
.getLoc();
1586 Value size
= createI32ConstantOf(loc
, rewriter
, 1);
1588 auto elementType
= getTypeConverter()->convertType(pointerTo
);
1590 return rewriter
.notifyMatchFailure(varOp
, "type conversion failed");
1591 rewriter
.replaceOpWithNewOp
<LLVM::AllocaOp
>(varOp
, dstType
, elementType
,
1595 auto elementType
= getTypeConverter()->convertType(pointerTo
);
1597 return rewriter
.notifyMatchFailure(varOp
, "type conversion failed");
1599 rewriter
.create
<LLVM::AllocaOp
>(loc
, dstType
, elementType
, size
);
1600 rewriter
.create
<LLVM::StoreOp
>(loc
, adaptor
.getInitializer(), allocated
);
1601 rewriter
.replaceOp(varOp
, allocated
);
1606 //===----------------------------------------------------------------------===//
1607 // BitcastOp conversion
1608 //===----------------------------------------------------------------------===//
1610 class BitcastConversionPattern
1611 : public SPIRVToLLVMConversion
<spirv::BitcastOp
> {
1613 using SPIRVToLLVMConversion
<spirv::BitcastOp
>::SPIRVToLLVMConversion
;
1616 matchAndRewrite(spirv::BitcastOp bitcastOp
, OpAdaptor adaptor
,
1617 ConversionPatternRewriter
&rewriter
) const override
{
1618 auto dstType
= getTypeConverter()->convertType(bitcastOp
.getType());
1620 return rewriter
.notifyMatchFailure(bitcastOp
, "type conversion failed");
1622 // LLVM's opaque pointers do not require bitcasts.
1623 if (isa
<LLVM::LLVMPointerType
>(dstType
)) {
1624 rewriter
.replaceOp(bitcastOp
, adaptor
.getOperand());
1628 rewriter
.replaceOpWithNewOp
<LLVM::BitcastOp
>(
1629 bitcastOp
, dstType
, adaptor
.getOperands(), bitcastOp
->getAttrs());
1634 //===----------------------------------------------------------------------===//
1635 // FuncOp conversion
1636 //===----------------------------------------------------------------------===//
1638 class FuncConversionPattern
: public SPIRVToLLVMConversion
<spirv::FuncOp
> {
1640 using SPIRVToLLVMConversion
<spirv::FuncOp
>::SPIRVToLLVMConversion
;
1643 matchAndRewrite(spirv::FuncOp funcOp
, OpAdaptor adaptor
,
1644 ConversionPatternRewriter
&rewriter
) const override
{
1646 // Convert function signature. At the moment LLVMType converter is enough
1647 // for currently supported types.
1648 auto funcType
= funcOp
.getFunctionType();
1649 TypeConverter::SignatureConversion
signatureConverter(
1650 funcType
.getNumInputs());
1651 auto llvmType
= static_cast<const LLVMTypeConverter
*>(getTypeConverter())
1652 ->convertFunctionSignature(
1653 funcType
, /*isVariadic=*/false,
1654 /*useBarePtrCallConv=*/false, signatureConverter
);
1658 // Create a new `LLVMFuncOp`
1659 Location loc
= funcOp
.getLoc();
1660 StringRef name
= funcOp
.getName();
1661 auto newFuncOp
= rewriter
.create
<LLVM::LLVMFuncOp
>(loc
, name
, llvmType
);
1663 // Convert SPIR-V Function Control to equivalent LLVM function attribute
1664 MLIRContext
*context
= funcOp
.getContext();
1665 switch (funcOp
.getFunctionControl()) {
1666 case spirv::FunctionControl::Inline
:
1667 newFuncOp
.setAlwaysInline(true);
1669 case spirv::FunctionControl::DontInline
:
1670 newFuncOp
.setNoInline(true);
1673 #define DISPATCH(functionControl, llvmAttr) \
1674 case functionControl: \
1675 newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
1678 DISPATCH(spirv::FunctionControl::Pure
,
1679 StringAttr::get(context
, "readonly"));
1680 DISPATCH(spirv::FunctionControl::Const
,
1681 StringAttr::get(context
, "readnone"));
1685 // Default: if `spirv::FunctionControl::None`, then no attributes are
1691 rewriter
.inlineRegionBefore(funcOp
.getBody(), newFuncOp
.getBody(),
1693 if (failed(rewriter
.convertRegionTypes(
1694 &newFuncOp
.getBody(), *getTypeConverter(), &signatureConverter
))) {
1697 rewriter
.eraseOp(funcOp
);
1702 //===----------------------------------------------------------------------===//
1703 // ModuleOp conversion
1704 //===----------------------------------------------------------------------===//
1706 class ModuleConversionPattern
: public SPIRVToLLVMConversion
<spirv::ModuleOp
> {
1708 using SPIRVToLLVMConversion
<spirv::ModuleOp
>::SPIRVToLLVMConversion
;
1711 matchAndRewrite(spirv::ModuleOp spvModuleOp
, OpAdaptor adaptor
,
1712 ConversionPatternRewriter
&rewriter
) const override
{
1715 rewriter
.create
<ModuleOp
>(spvModuleOp
.getLoc(), spvModuleOp
.getName());
1716 rewriter
.inlineRegionBefore(spvModuleOp
.getRegion(), newModuleOp
.getBody());
1718 // Remove the terminator block that was automatically added by builder
1719 rewriter
.eraseBlock(&newModuleOp
.getBodyRegion().back());
1720 rewriter
.eraseOp(spvModuleOp
);
1725 //===----------------------------------------------------------------------===//
1726 // VectorShuffleOp conversion
1727 //===----------------------------------------------------------------------===//
1729 class VectorShufflePattern
1730 : public SPIRVToLLVMConversion
<spirv::VectorShuffleOp
> {
1732 using SPIRVToLLVMConversion
<spirv::VectorShuffleOp
>::SPIRVToLLVMConversion
;
1734 matchAndRewrite(spirv::VectorShuffleOp op
, OpAdaptor adaptor
,
1735 ConversionPatternRewriter
&rewriter
) const override
{
1736 Location loc
= op
.getLoc();
1737 auto components
= adaptor
.getComponents();
1738 auto vector1
= adaptor
.getVector1();
1739 auto vector2
= adaptor
.getVector2();
1740 int vector1Size
= cast
<VectorType
>(vector1
.getType()).getNumElements();
1741 int vector2Size
= cast
<VectorType
>(vector2
.getType()).getNumElements();
1742 if (vector1Size
== vector2Size
) {
1743 rewriter
.replaceOpWithNewOp
<LLVM::ShuffleVectorOp
>(
1744 op
, vector1
, vector2
,
1745 LLVM::convertArrayToIndices
<int32_t>(components
));
1749 auto dstType
= getTypeConverter()->convertType(op
.getType());
1751 return rewriter
.notifyMatchFailure(op
, "type conversion failed");
1752 auto scalarType
= cast
<VectorType
>(dstType
).getElementType();
1753 auto componentsArray
= components
.getValue();
1754 auto *context
= rewriter
.getContext();
1755 auto llvmI32Type
= IntegerType::get(context
, 32);
1756 Value targetOp
= rewriter
.create
<LLVM::UndefOp
>(loc
, dstType
);
1757 for (unsigned i
= 0; i
< componentsArray
.size(); i
++) {
1758 if (!isa
<IntegerAttr
>(componentsArray
[i
]))
1759 return op
.emitError("unable to support non-constant component");
1761 int indexVal
= cast
<IntegerAttr
>(componentsArray
[i
]).getInt();
1766 Value baseVector
= vector1
;
1767 if (indexVal
>= vector1Size
) {
1768 offsetVal
= vector1Size
;
1769 baseVector
= vector2
;
1772 Value dstIndex
= rewriter
.create
<LLVM::ConstantOp
>(
1773 loc
, llvmI32Type
, rewriter
.getIntegerAttr(rewriter
.getI32Type(), i
));
1774 Value index
= rewriter
.create
<LLVM::ConstantOp
>(
1776 rewriter
.getIntegerAttr(rewriter
.getI32Type(), indexVal
- offsetVal
));
1778 auto extractOp
= rewriter
.create
<LLVM::ExtractElementOp
>(
1779 loc
, scalarType
, baseVector
, index
);
1780 targetOp
= rewriter
.create
<LLVM::InsertElementOp
>(loc
, dstType
, targetOp
,
1781 extractOp
, dstIndex
);
1783 rewriter
.replaceOp(op
, targetOp
);
1789 //===----------------------------------------------------------------------===//
1790 // Pattern population
1791 //===----------------------------------------------------------------------===//
1793 void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter
&typeConverter
,
1794 spirv::ClientAPI clientAPI
) {
1795 typeConverter
.addConversion([&](spirv::ArrayType type
) {
1796 return convertArrayType(type
, typeConverter
);
1798 typeConverter
.addConversion([&, clientAPI
](spirv::PointerType type
) {
1799 return convertPointerType(type
, typeConverter
, clientAPI
);
1801 typeConverter
.addConversion([&](spirv::RuntimeArrayType type
) {
1802 return convertRuntimeArrayType(type
, typeConverter
);
1804 typeConverter
.addConversion([&](spirv::StructType type
) {
1805 return convertStructType(type
, typeConverter
);
1809 void mlir::populateSPIRVToLLVMConversionPatterns(
1810 const LLVMTypeConverter
&typeConverter
, RewritePatternSet
&patterns
,
1811 spirv::ClientAPI clientAPI
) {
1814 DirectConversionPattern
<spirv::IAddOp
, LLVM::AddOp
>,
1815 DirectConversionPattern
<spirv::IMulOp
, LLVM::MulOp
>,
1816 DirectConversionPattern
<spirv::ISubOp
, LLVM::SubOp
>,
1817 DirectConversionPattern
<spirv::FAddOp
, LLVM::FAddOp
>,
1818 DirectConversionPattern
<spirv::FDivOp
, LLVM::FDivOp
>,
1819 DirectConversionPattern
<spirv::FMulOp
, LLVM::FMulOp
>,
1820 DirectConversionPattern
<spirv::FNegateOp
, LLVM::FNegOp
>,
1821 DirectConversionPattern
<spirv::FRemOp
, LLVM::FRemOp
>,
1822 DirectConversionPattern
<spirv::FSubOp
, LLVM::FSubOp
>,
1823 DirectConversionPattern
<spirv::SDivOp
, LLVM::SDivOp
>,
1824 DirectConversionPattern
<spirv::SRemOp
, LLVM::SRemOp
>,
1825 DirectConversionPattern
<spirv::UDivOp
, LLVM::UDivOp
>,
1826 DirectConversionPattern
<spirv::UModOp
, LLVM::URemOp
>,
1829 BitFieldInsertPattern
, BitFieldUExtractPattern
, BitFieldSExtractPattern
,
1830 DirectConversionPattern
<spirv::BitCountOp
, LLVM::CtPopOp
>,
1831 DirectConversionPattern
<spirv::BitReverseOp
, LLVM::BitReverseOp
>,
1832 DirectConversionPattern
<spirv::BitwiseAndOp
, LLVM::AndOp
>,
1833 DirectConversionPattern
<spirv::BitwiseOrOp
, LLVM::OrOp
>,
1834 DirectConversionPattern
<spirv::BitwiseXorOp
, LLVM::XOrOp
>,
1835 NotPattern
<spirv::NotOp
>,
1838 BitcastConversionPattern
,
1839 DirectConversionPattern
<spirv::ConvertFToSOp
, LLVM::FPToSIOp
>,
1840 DirectConversionPattern
<spirv::ConvertFToUOp
, LLVM::FPToUIOp
>,
1841 DirectConversionPattern
<spirv::ConvertSToFOp
, LLVM::SIToFPOp
>,
1842 DirectConversionPattern
<spirv::ConvertUToFOp
, LLVM::UIToFPOp
>,
1843 IndirectCastPattern
<spirv::FConvertOp
, LLVM::FPExtOp
, LLVM::FPTruncOp
>,
1844 IndirectCastPattern
<spirv::SConvertOp
, LLVM::SExtOp
, LLVM::TruncOp
>,
1845 IndirectCastPattern
<spirv::UConvertOp
, LLVM::ZExtOp
, LLVM::TruncOp
>,
1848 IComparePattern
<spirv::IEqualOp
, LLVM::ICmpPredicate::eq
>,
1849 IComparePattern
<spirv::INotEqualOp
, LLVM::ICmpPredicate::ne
>,
1850 FComparePattern
<spirv::FOrdEqualOp
, LLVM::FCmpPredicate::oeq
>,
1851 FComparePattern
<spirv::FOrdGreaterThanOp
, LLVM::FCmpPredicate::ogt
>,
1852 FComparePattern
<spirv::FOrdGreaterThanEqualOp
, LLVM::FCmpPredicate::oge
>,
1853 FComparePattern
<spirv::FOrdLessThanEqualOp
, LLVM::FCmpPredicate::ole
>,
1854 FComparePattern
<spirv::FOrdLessThanOp
, LLVM::FCmpPredicate::olt
>,
1855 FComparePattern
<spirv::FOrdNotEqualOp
, LLVM::FCmpPredicate::one
>,
1856 FComparePattern
<spirv::FUnordEqualOp
, LLVM::FCmpPredicate::ueq
>,
1857 FComparePattern
<spirv::FUnordGreaterThanOp
, LLVM::FCmpPredicate::ugt
>,
1858 FComparePattern
<spirv::FUnordGreaterThanEqualOp
,
1859 LLVM::FCmpPredicate::uge
>,
1860 FComparePattern
<spirv::FUnordLessThanEqualOp
, LLVM::FCmpPredicate::ule
>,
1861 FComparePattern
<spirv::FUnordLessThanOp
, LLVM::FCmpPredicate::ult
>,
1862 FComparePattern
<spirv::FUnordNotEqualOp
, LLVM::FCmpPredicate::une
>,
1863 IComparePattern
<spirv::SGreaterThanOp
, LLVM::ICmpPredicate::sgt
>,
1864 IComparePattern
<spirv::SGreaterThanEqualOp
, LLVM::ICmpPredicate::sge
>,
1865 IComparePattern
<spirv::SLessThanEqualOp
, LLVM::ICmpPredicate::sle
>,
1866 IComparePattern
<spirv::SLessThanOp
, LLVM::ICmpPredicate::slt
>,
1867 IComparePattern
<spirv::UGreaterThanOp
, LLVM::ICmpPredicate::ugt
>,
1868 IComparePattern
<spirv::UGreaterThanEqualOp
, LLVM::ICmpPredicate::uge
>,
1869 IComparePattern
<spirv::ULessThanEqualOp
, LLVM::ICmpPredicate::ule
>,
1870 IComparePattern
<spirv::ULessThanOp
, LLVM::ICmpPredicate::ult
>,
1873 ConstantScalarAndVectorPattern
,
1876 BranchConversionPattern
, BranchConditionalConversionPattern
,
1877 FunctionCallPattern
, LoopPattern
, SelectionPattern
,
1878 ErasePattern
<spirv::MergeOp
>,
1880 // Entry points and execution mode are handled separately.
1881 ErasePattern
<spirv::EntryPointOp
>, ExecutionModePattern
,
1883 // GLSL extended instruction set ops
1884 DirectConversionPattern
<spirv::GLCeilOp
, LLVM::FCeilOp
>,
1885 DirectConversionPattern
<spirv::GLCosOp
, LLVM::CosOp
>,
1886 DirectConversionPattern
<spirv::GLExpOp
, LLVM::ExpOp
>,
1887 DirectConversionPattern
<spirv::GLFAbsOp
, LLVM::FAbsOp
>,
1888 DirectConversionPattern
<spirv::GLFloorOp
, LLVM::FFloorOp
>,
1889 DirectConversionPattern
<spirv::GLFMaxOp
, LLVM::MaxNumOp
>,
1890 DirectConversionPattern
<spirv::GLFMinOp
, LLVM::MinNumOp
>,
1891 DirectConversionPattern
<spirv::GLLogOp
, LLVM::LogOp
>,
1892 DirectConversionPattern
<spirv::GLSinOp
, LLVM::SinOp
>,
1893 DirectConversionPattern
<spirv::GLSMaxOp
, LLVM::SMaxOp
>,
1894 DirectConversionPattern
<spirv::GLSMinOp
, LLVM::SMinOp
>,
1895 DirectConversionPattern
<spirv::GLSqrtOp
, LLVM::SqrtOp
>,
1896 InverseSqrtPattern
, TanPattern
, TanhPattern
,
1899 DirectConversionPattern
<spirv::LogicalAndOp
, LLVM::AndOp
>,
1900 DirectConversionPattern
<spirv::LogicalOrOp
, LLVM::OrOp
>,
1901 IComparePattern
<spirv::LogicalEqualOp
, LLVM::ICmpPredicate::eq
>,
1902 IComparePattern
<spirv::LogicalNotEqualOp
, LLVM::ICmpPredicate::ne
>,
1903 NotPattern
<spirv::LogicalNotOp
>,
1906 AccessChainPattern
, AddressOfPattern
, LoadStorePattern
<spirv::LoadOp
>,
1907 LoadStorePattern
<spirv::StoreOp
>, VariablePattern
,
1909 // Miscellaneous ops
1910 CompositeExtractPattern
, CompositeInsertPattern
,
1911 DirectConversionPattern
<spirv::SelectOp
, LLVM::SelectOp
>,
1912 DirectConversionPattern
<spirv::UndefOp
, LLVM::UndefOp
>,
1913 VectorShufflePattern
,
1916 ShiftPattern
<spirv::ShiftRightArithmeticOp
, LLVM::AShrOp
>,
1917 ShiftPattern
<spirv::ShiftRightLogicalOp
, LLVM::LShrOp
>,
1918 ShiftPattern
<spirv::ShiftLeftLogicalOp
, LLVM::ShlOp
>,
1921 ReturnPattern
, ReturnValuePattern
,
1924 ControlBarrierPattern
<spirv::ControlBarrierOp
>,
1925 ControlBarrierPattern
<spirv::INTELControlBarrierArriveOp
>,
1926 ControlBarrierPattern
<spirv::INTELControlBarrierWaitOp
>,
1928 // Group reduction operations
1929 GroupReducePattern
<spirv::GroupIAddOp
>,
1930 GroupReducePattern
<spirv::GroupFAddOp
>,
1931 GroupReducePattern
<spirv::GroupFMinOp
>,
1932 GroupReducePattern
<spirv::GroupUMinOp
>,
1933 GroupReducePattern
<spirv::GroupSMinOp
, /*Signed=*/true>,
1934 GroupReducePattern
<spirv::GroupFMaxOp
>,
1935 GroupReducePattern
<spirv::GroupUMaxOp
>,
1936 GroupReducePattern
<spirv::GroupSMaxOp
, /*Signed=*/true>,
1937 GroupReducePattern
<spirv::GroupNonUniformIAddOp
, /*Signed=*/false,
1938 /*NonUniform=*/true>,
1939 GroupReducePattern
<spirv::GroupNonUniformFAddOp
, /*Signed=*/false,
1940 /*NonUniform=*/true>,
1941 GroupReducePattern
<spirv::GroupNonUniformIMulOp
, /*Signed=*/false,
1942 /*NonUniform=*/true>,
1943 GroupReducePattern
<spirv::GroupNonUniformFMulOp
, /*Signed=*/false,
1944 /*NonUniform=*/true>,
1945 GroupReducePattern
<spirv::GroupNonUniformSMinOp
, /*Signed=*/true,
1946 /*NonUniform=*/true>,
1947 GroupReducePattern
<spirv::GroupNonUniformUMinOp
, /*Signed=*/false,
1948 /*NonUniform=*/true>,
1949 GroupReducePattern
<spirv::GroupNonUniformFMinOp
, /*Signed=*/false,
1950 /*NonUniform=*/true>,
1951 GroupReducePattern
<spirv::GroupNonUniformSMaxOp
, /*Signed=*/true,
1952 /*NonUniform=*/true>,
1953 GroupReducePattern
<spirv::GroupNonUniformUMaxOp
, /*Signed=*/false,
1954 /*NonUniform=*/true>,
1955 GroupReducePattern
<spirv::GroupNonUniformFMaxOp
, /*Signed=*/false,
1956 /*NonUniform=*/true>,
1957 GroupReducePattern
<spirv::GroupNonUniformBitwiseAndOp
, /*Signed=*/false,
1958 /*NonUniform=*/true>,
1959 GroupReducePattern
<spirv::GroupNonUniformBitwiseOrOp
, /*Signed=*/false,
1960 /*NonUniform=*/true>,
1961 GroupReducePattern
<spirv::GroupNonUniformBitwiseXorOp
, /*Signed=*/false,
1962 /*NonUniform=*/true>,
1963 GroupReducePattern
<spirv::GroupNonUniformLogicalAndOp
, /*Signed=*/false,
1964 /*NonUniform=*/true>,
1965 GroupReducePattern
<spirv::GroupNonUniformLogicalOrOp
, /*Signed=*/false,
1966 /*NonUniform=*/true>,
1967 GroupReducePattern
<spirv::GroupNonUniformLogicalXorOp
, /*Signed=*/false,
1968 /*NonUniform=*/true>>(patterns
.getContext(),
1971 patterns
.add
<GlobalVariablePattern
>(clientAPI
, patterns
.getContext(),
1975 void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
1976 const LLVMTypeConverter
&typeConverter
, RewritePatternSet
&patterns
) {
1977 patterns
.add
<FuncConversionPattern
>(patterns
.getContext(), typeConverter
);
1980 void mlir::populateSPIRVToLLVMModuleConversionPatterns(
1981 const LLVMTypeConverter
&typeConverter
, RewritePatternSet
&patterns
) {
1982 patterns
.add
<ModuleConversionPattern
>(patterns
.getContext(), typeConverter
);
1985 //===----------------------------------------------------------------------===//
1986 // Pre-conversion hooks
1987 //===----------------------------------------------------------------------===//
1989 /// Hook for descriptor set and binding number encoding.
1990 static constexpr StringRef kBinding
= "binding";
1991 static constexpr StringRef kDescriptorSet
= "descriptor_set";
1992 void mlir::encodeBindAttribute(ModuleOp module
) {
1993 auto spvModules
= module
.getOps
<spirv::ModuleOp
>();
1994 for (auto spvModule
: spvModules
) {
1995 spvModule
.walk([&](spirv::GlobalVariableOp op
) {
1996 IntegerAttr descriptorSet
=
1997 op
->getAttrOfType
<IntegerAttr
>(kDescriptorSet
);
1998 IntegerAttr binding
= op
->getAttrOfType
<IntegerAttr
>(kBinding
);
1999 // For every global variable in the module, get the ones with descriptor
2000 // set and binding numbers.
2001 if (descriptorSet
&& binding
) {
2002 // Encode these numbers into the variable's symbolic name. If the
2003 // SPIR-V module has a name, add it at the beginning.
2004 auto moduleAndName
=
2005 spvModule
.getName().has_value()
2006 ? spvModule
.getName()->str() + "_" + op
.getSymName().str()
2007 : op
.getSymName().str();
2009 llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName
,
2010 std::to_string(descriptorSet
.getInt()),
2011 std::to_string(binding
.getInt()));
2012 auto nameAttr
= StringAttr::get(op
->getContext(), name
);
2014 // Replace all symbol uses and set the new symbol name. Finally, remove
2015 // descriptor set and binding attributes.
2016 if (failed(SymbolTable::replaceAllSymbolUses(op
, nameAttr
, spvModule
)))
2017 op
.emitError("unable to replace all symbol uses for ") << name
;
2018 SymbolTable::setSymbolName(op
, nameAttr
);
2019 op
->removeAttr(kDescriptorSet
);
2020 op
->removeAttr(kBinding
);