1 //===- ConvertToSPIRVPass.cpp - MLIR SPIR-V Conversion --------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h"
10 #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
11 #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
12 #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
13 #include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
14 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
15 #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
16 #include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
17 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
18 #include "mlir/Dialect/Arith/Transforms/Passes.h"
19 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
20 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
21 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
22 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
23 #include "mlir/Dialect/Vector/IR/VectorOps.h"
24 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
25 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
26 #include "mlir/IR/PatternMatch.h"
27 #include "mlir/Pass/Pass.h"
28 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
29 #include "mlir/Transforms/DialectConversion.h"
30 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
33 #define DEBUG_TYPE "convert-to-spirv"
36 #define GEN_PASS_DEF_CONVERTTOSPIRVPASS
37 #include "mlir/Conversion/Passes.h.inc"
44 /// Map memRef memory space to SPIR-V storage class.
45 void mapToMemRef(Operation
*op
, spirv::TargetEnvAttr
&targetAttr
) {
46 spirv::TargetEnv
targetEnv(targetAttr
);
47 bool targetEnvSupportsKernelCapability
=
48 targetEnv
.allows(spirv::Capability::Kernel
);
49 spirv::MemorySpaceToStorageClassMap memorySpaceMap
=
50 targetEnvSupportsKernelCapability
51 ? spirv::mapMemorySpaceToOpenCLStorageClass
52 : spirv::mapMemorySpaceToVulkanStorageClass
;
53 spirv::MemorySpaceToStorageClassConverter
converter(memorySpaceMap
);
54 spirv::convertMemRefTypesAndAttrs(op
, converter
);
57 /// Populate patterns for each dialect.
58 void populateConvertToSPIRVPatterns(const SPIRVTypeConverter
&typeConverter
,
59 ScfToSPIRVContext
&scfToSPIRVContext
,
60 RewritePatternSet
&patterns
) {
61 arith::populateCeilFloorDivExpandOpsPatterns(patterns
);
62 arith::populateArithToSPIRVPatterns(typeConverter
, patterns
);
63 populateBuiltinFuncToSPIRVPatterns(typeConverter
, patterns
);
64 populateFuncToSPIRVPatterns(typeConverter
, patterns
);
65 populateGPUToSPIRVPatterns(typeConverter
, patterns
);
66 index::populateIndexToSPIRVPatterns(typeConverter
, patterns
);
67 populateMemRefToSPIRVPatterns(typeConverter
, patterns
);
68 populateVectorToSPIRVPatterns(typeConverter
, patterns
);
69 populateSCFToSPIRVPatterns(typeConverter
, scfToSPIRVContext
, patterns
);
70 ub::populateUBToSPIRVConversionPatterns(typeConverter
, patterns
);
73 /// A pass to perform the SPIR-V conversion.
74 struct ConvertToSPIRVPass final
75 : impl::ConvertToSPIRVPassBase
<ConvertToSPIRVPass
> {
76 using ConvertToSPIRVPassBase::ConvertToSPIRVPassBase
;
78 void runOnOperation() override
{
79 Operation
*op
= getOperation();
80 MLIRContext
*context
= &getContext();
82 // Unroll vectors in function signatures to native size.
83 if (runSignatureConversion
&& failed(spirv::unrollVectorsInSignatures(op
)))
84 return signalPassFailure();
86 // Unroll vectors in function bodies to native size.
87 if (runVectorUnrolling
&& failed(spirv::unrollVectorsInFuncBodies(op
)))
88 return signalPassFailure();
90 // Generic conversion.
91 if (!convertGPUModules
) {
92 spirv::TargetEnvAttr targetAttr
= spirv::lookupTargetEnvOrDefault(op
);
93 std::unique_ptr
<ConversionTarget
> target
=
94 SPIRVConversionTarget::get(targetAttr
);
95 SPIRVTypeConverter
typeConverter(targetAttr
);
96 RewritePatternSet
patterns(context
);
97 ScfToSPIRVContext scfToSPIRVContext
;
98 mapToMemRef(op
, targetAttr
);
99 populateConvertToSPIRVPatterns(typeConverter
, scfToSPIRVContext
,
101 if (failed(applyPartialConversion(op
, *target
, std::move(patterns
))))
102 return signalPassFailure();
106 // Clone each GPU kernel module for conversion, given that the GPU
107 // launch op still needs the original GPU kernel module.
108 SmallVector
<Operation
*, 1> gpuModules
;
109 OpBuilder
builder(context
);
110 op
->walk([&](gpu::GPUModuleOp gpuModule
) {
111 builder
.setInsertionPoint(gpuModule
);
112 gpuModules
.push_back(builder
.clone(*gpuModule
));
114 // Run conversion for each module independently as they can have
115 // different TargetEnv attributes.
116 for (Operation
*gpuModule
: gpuModules
) {
117 spirv::TargetEnvAttr targetAttr
=
118 spirv::lookupTargetEnvOrDefault(gpuModule
);
119 std::unique_ptr
<ConversionTarget
> target
=
120 SPIRVConversionTarget::get(targetAttr
);
121 SPIRVTypeConverter
typeConverter(targetAttr
);
122 RewritePatternSet
patterns(context
);
123 ScfToSPIRVContext scfToSPIRVContext
;
124 mapToMemRef(gpuModule
, targetAttr
);
125 populateConvertToSPIRVPatterns(typeConverter
, scfToSPIRVContext
,
127 if (failed(applyFullConversion(gpuModule
, *target
, std::move(patterns
))))
128 return signalPassFailure();