Fix GCC build problem with 288f05f related to SmallVector. (#116958)
[llvm-project.git] / mlir / lib / Conversion / ConvertToSPIRV / ConvertToSPIRVPass.cpp
blob 4b7f7ff114deeb300d4011195dea0887aea85413
1 //===- ConvertToSPIRVPass.cpp - MLIR SPIR-V Conversion --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h"
10 #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
11 #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
12 #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
13 #include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
14 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
15 #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
16 #include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
17 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
18 #include "mlir/Dialect/Arith/Transforms/Passes.h"
19 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
20 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
21 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
22 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
23 #include "mlir/Dialect/Vector/IR/VectorOps.h"
24 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
25 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
26 #include "mlir/IR/PatternMatch.h"
27 #include "mlir/Pass/Pass.h"
28 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
29 #include "mlir/Transforms/DialectConversion.h"
30 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31 #include <memory>
33 #define DEBUG_TYPE "convert-to-spirv"
35 namespace mlir {
36 #define GEN_PASS_DEF_CONVERTTOSPIRVPASS
37 #include "mlir/Conversion/Passes.h.inc"
38 } // namespace mlir
40 using namespace mlir;
42 namespace {
44 /// Map memRef memory space to SPIR-V storage class.
45 void mapToMemRef(Operation *op, spirv::TargetEnvAttr &targetAttr) {
46 spirv::TargetEnv targetEnv(targetAttr);
47 bool targetEnvSupportsKernelCapability =
48 targetEnv.allows(spirv::Capability::Kernel);
49 spirv::MemorySpaceToStorageClassMap memorySpaceMap =
50 targetEnvSupportsKernelCapability
51 ? spirv::mapMemorySpaceToOpenCLStorageClass
52 : spirv::mapMemorySpaceToVulkanStorageClass;
53 spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
54 spirv::convertMemRefTypesAndAttrs(op, converter);
/// Populate `patterns` with the patterns of every dialect this pass lowers to
/// SPIR-V: arith (plus the arith ceil/floor division expansion patterns),
/// builtin func, func, GPU, index, memref, vector, SCF and UB.
///
/// All conversion patterns share `typeConverter`; the SCF patterns
/// additionally take `scfToSPIRVContext`, which the caller must keep alive for
/// as long as the patterns are in use. The calls below append to `patterns`
/// in order, so reordering them changes the order of the resulting set.
void populateConvertToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
                                    ScfToSPIRVContext &scfToSPIRVContext,
                                    RewritePatternSet &patterns) {
  // Arith ceil/floor division expansion patterns (from Arith/Transforms);
  // these do not take the type converter.
  arith::populateCeilFloorDivExpandOpsPatterns(patterns);
  arith::populateArithToSPIRVPatterns(typeConverter, patterns);
  populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
  populateFuncToSPIRVPatterns(typeConverter, patterns);
  populateGPUToSPIRVPatterns(typeConverter, patterns);
  index::populateIndexToSPIRVPatterns(typeConverter, patterns);
  populateMemRefToSPIRVPatterns(typeConverter, patterns);
  populateVectorToSPIRVPatterns(typeConverter, patterns);
  // The SCF patterns are the only ones that need the shared SCF context.
  populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
  ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
}
73 /// A pass to perform the SPIR-V conversion.
74 struct ConvertToSPIRVPass final
75 : impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
76 using ConvertToSPIRVPassBase::ConvertToSPIRVPassBase;
78 void runOnOperation() override {
79 Operation *op = getOperation();
80 MLIRContext *context = &getContext();
82 // Unroll vectors in function signatures to native size.
83 if (runSignatureConversion && failed(spirv::unrollVectorsInSignatures(op)))
84 return signalPassFailure();
86 // Unroll vectors in function bodies to native size.
87 if (runVectorUnrolling && failed(spirv::unrollVectorsInFuncBodies(op)))
88 return signalPassFailure();
90 // Generic conversion.
91 if (!convertGPUModules) {
92 spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
93 std::unique_ptr<ConversionTarget> target =
94 SPIRVConversionTarget::get(targetAttr);
95 SPIRVTypeConverter typeConverter(targetAttr);
96 RewritePatternSet patterns(context);
97 ScfToSPIRVContext scfToSPIRVContext;
98 mapToMemRef(op, targetAttr);
99 populateConvertToSPIRVPatterns(typeConverter, scfToSPIRVContext,
100 patterns);
101 if (failed(applyPartialConversion(op, *target, std::move(patterns))))
102 return signalPassFailure();
103 return;
106 // Clone each GPU kernel module for conversion, given that the GPU
107 // launch op still needs the original GPU kernel module.
108 SmallVector<Operation *, 1> gpuModules;
109 OpBuilder builder(context);
110 op->walk([&](gpu::GPUModuleOp gpuModule) {
111 builder.setInsertionPoint(gpuModule);
112 gpuModules.push_back(builder.clone(*gpuModule));
114 // Run conversion for each module independently as they can have
115 // different TargetEnv attributes.
116 for (Operation *gpuModule : gpuModules) {
117 spirv::TargetEnvAttr targetAttr =
118 spirv::lookupTargetEnvOrDefault(gpuModule);
119 std::unique_ptr<ConversionTarget> target =
120 SPIRVConversionTarget::get(targetAttr);
121 SPIRVTypeConverter typeConverter(targetAttr);
122 RewritePatternSet patterns(context);
123 ScfToSPIRVContext scfToSPIRVContext;
124 mapToMemRef(gpuModule, targetAttr);
125 populateConvertToSPIRVPatterns(typeConverter, scfToSPIRVContext,
126 patterns);
127 if (failed(applyFullConversion(gpuModule, *target, std::move(patterns))))
128 return signalPassFailure();
133 } // namespace