Fix GCC build problem with 288f05f related to SmallVector. (#116958)
[llvm-project.git] / mlir / lib / Conversion / MemRefToSPIRV / MapMemRefStorageClassPass.cpp
blob4cbc3dfdae223cecd26896ce7d73d6b13dcbcb5f
//===- MapMemRefStorageClassPass.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to map numeric MemRef memory spaces to
10 // symbolic ones defined in the SPIR-V specification.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
16 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
19 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
20 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
21 #include "mlir/IR/Attributes.h"
22 #include "mlir/IR/BuiltinAttributes.h"
23 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/Operation.h"
25 #include "mlir/IR/Visitors.h"
26 #include "mlir/Interfaces/FunctionInterfaces.h"
27 #include "llvm/ADT/SmallVectorExtras.h"
28 #include "llvm/ADT/StringExtras.h"
29 #include "llvm/Support/Debug.h"
30 #include <optional>
32 namespace mlir {
33 #define GEN_PASS_DEF_MAPMEMREFSTORAGECLASS
34 #include "mlir/Conversion/Passes.h.inc"
35 } // namespace mlir
37 #define DEBUG_TYPE "mlir-map-memref-storage-class"
39 using namespace mlir;
41 //===----------------------------------------------------------------------===//
42 // Mappings
43 //===----------------------------------------------------------------------===//
/// X-macro table pairing each SPIR-V storage class with the integer memref
/// memory space that encodes it when targeting Vulkan. Expand it with a
/// two-argument macro `ENTRY_FN(storageClass, memorySpace)`; both Vulkan
/// mapping directions are generated from this single list so they cannot
/// drift apart.
///
/// Memref assigns no fixed meaning to memory-space numbers; the meaning comes
/// from the surrounding context. The numbers below are only a convention: they
/// loosely follow NVVM and give frequently used storage classes small values.
#define VULKAN_STORAGE_SPACE_MAP_LIST(ENTRY_FN)                                \
  ENTRY_FN(spirv::StorageClass::StorageBuffer, 0)                              \
  ENTRY_FN(spirv::StorageClass::Generic, 1)                                    \
  ENTRY_FN(spirv::StorageClass::Workgroup, 3)                                  \
  ENTRY_FN(spirv::StorageClass::Uniform, 4)                                    \
  ENTRY_FN(spirv::StorageClass::Private, 5)                                    \
  ENTRY_FN(spirv::StorageClass::Function, 6)                                   \
  ENTRY_FN(spirv::StorageClass::PushConstant, 7)                               \
  ENTRY_FN(spirv::StorageClass::UniformConstant, 8)                            \
  ENTRY_FN(spirv::StorageClass::Input, 9)                                      \
  ENTRY_FN(spirv::StorageClass::Output, 10)                                    \
  ENTRY_FN(spirv::StorageClass::PhysicalStorageBuffer, 11)
64 std::optional<spirv::StorageClass>
65 spirv::mapMemorySpaceToVulkanStorageClass(Attribute memorySpaceAttr) {
66 // Handle null memory space attribute specially.
67 if (!memorySpaceAttr)
68 return spirv::StorageClass::StorageBuffer;
70 // Unknown dialect custom attributes are not supported by default.
71 // Downstream callers should plug in more specialized ones.
72 auto intAttr = dyn_cast<IntegerAttr>(memorySpaceAttr);
73 if (!intAttr)
74 return std::nullopt;
75 unsigned memorySpace = intAttr.getInt();
77 #define STORAGE_SPACE_MAP_FN(storage, space) \
78 case space: \
79 return storage;
81 switch (memorySpace) {
82 VULKAN_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
83 default:
84 break;
86 return std::nullopt;
88 #undef STORAGE_SPACE_MAP_FN
91 std::optional<unsigned>
92 spirv::mapVulkanStorageClassToMemorySpace(spirv::StorageClass storageClass) {
93 #define STORAGE_SPACE_MAP_FN(storage, space) \
94 case storage: \
95 return space;
97 switch (storageClass) {
98 VULKAN_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
99 default:
100 break;
102 return std::nullopt;
104 #undef STORAGE_SPACE_MAP_FN
107 #undef VULKAN_STORAGE_SPACE_MAP_LIST
/// X-macro table pairing each SPIR-V storage class with the integer memref
/// memory space that encodes it when targeting OpenCL. Expand it with a
/// two-argument macro `ENTRY_FN(storageClass, memorySpace)`; both OpenCL
/// mapping directions are generated from this single list so they stay
/// consistent.
#define OPENCL_STORAGE_SPACE_MAP_LIST(ENTRY_FN)                                \
  ENTRY_FN(spirv::StorageClass::CrossWorkgroup, 0)                             \
  ENTRY_FN(spirv::StorageClass::Generic, 1)                                    \
  ENTRY_FN(spirv::StorageClass::Workgroup, 3)                                  \
  ENTRY_FN(spirv::StorageClass::UniformConstant, 4)                            \
  ENTRY_FN(spirv::StorageClass::Private, 5)                                    \
  ENTRY_FN(spirv::StorageClass::Function, 6)                                   \
  ENTRY_FN(spirv::StorageClass::Image, 7)
118 std::optional<spirv::StorageClass>
119 spirv::mapMemorySpaceToOpenCLStorageClass(Attribute memorySpaceAttr) {
120 // Handle null memory space attribute specially.
121 if (!memorySpaceAttr)
122 return spirv::StorageClass::CrossWorkgroup;
124 // Unknown dialect custom attributes are not supported by default.
125 // Downstream callers should plug in more specialized ones.
126 auto intAttr = dyn_cast<IntegerAttr>(memorySpaceAttr);
127 if (!intAttr)
128 return std::nullopt;
129 unsigned memorySpace = intAttr.getInt();
131 #define STORAGE_SPACE_MAP_FN(storage, space) \
132 case space: \
133 return storage;
135 switch (memorySpace) {
136 OPENCL_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
137 default:
138 break;
140 return std::nullopt;
142 #undef STORAGE_SPACE_MAP_FN
145 std::optional<unsigned>
146 spirv::mapOpenCLStorageClassToMemorySpace(spirv::StorageClass storageClass) {
147 #define STORAGE_SPACE_MAP_FN(storage, space) \
148 case storage: \
149 return space;
151 switch (storageClass) {
152 OPENCL_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
153 default:
154 break;
156 return std::nullopt;
158 #undef STORAGE_SPACE_MAP_FN
161 #undef OPENCL_STORAGE_SPACE_MAP_LIST
163 //===----------------------------------------------------------------------===//
164 // Type Converter
165 //===----------------------------------------------------------------------===//
167 spirv::MemorySpaceToStorageClassConverter::MemorySpaceToStorageClassConverter(
168 const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)
169 : memorySpaceMap(memorySpaceMap) {
170 // Pass through for all other types.
171 addConversion([](Type type) { return type; });
173 addConversion([this](BaseMemRefType memRefType) -> std::optional<Type> {
174 std::optional<spirv::StorageClass> storage =
175 this->memorySpaceMap(memRefType.getMemorySpace());
176 if (!storage) {
177 LLVM_DEBUG(llvm::dbgs()
178 << "cannot convert " << memRefType
179 << " due to being unable to find memory space in map\n");
180 return std::nullopt;
183 auto storageAttr =
184 spirv::StorageClassAttr::get(memRefType.getContext(), *storage);
185 if (auto rankedType = dyn_cast<MemRefType>(memRefType)) {
186 return MemRefType::get(memRefType.getShape(), memRefType.getElementType(),
187 rankedType.getLayout(), storageAttr);
189 return UnrankedMemRefType::get(memRefType.getElementType(), storageAttr);
192 addConversion([this](FunctionType type) {
193 auto inputs = llvm::map_to_vector(
194 type.getInputs(), [this](Type ty) { return convertType(ty); });
195 auto results = llvm::map_to_vector(
196 type.getResults(), [this](Type ty) { return convertType(ty); });
197 return FunctionType::get(type.getContext(), inputs, results);
201 //===----------------------------------------------------------------------===//
202 // Conversion Target
203 //===----------------------------------------------------------------------===//
205 /// Returns true if the given `type` is considered as legal for SPIR-V
206 /// conversion.
207 static bool isLegalType(Type type) {
208 if (auto memRefType = dyn_cast<BaseMemRefType>(type)) {
209 Attribute spaceAttr = memRefType.getMemorySpace();
210 return isa_and_nonnull<spirv::StorageClassAttr>(spaceAttr);
212 return true;
215 /// Returns true if the given `attr` is considered as legal for SPIR-V
216 /// conversion.
217 static bool isLegalAttr(Attribute attr) {
218 if (auto typeAttr = dyn_cast<TypeAttr>(attr))
219 return isLegalType(typeAttr.getValue());
220 return true;
223 /// Returns true if the given `op` is considered as legal for SPIR-V conversion.
224 static bool isLegalOp(Operation *op) {
225 if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
226 return llvm::all_of(funcOp.getArgumentTypes(), isLegalType) &&
227 llvm::all_of(funcOp.getResultTypes(), isLegalType) &&
228 llvm::all_of(funcOp.getFunctionBody().getArgumentTypes(),
229 isLegalType);
232 auto attrs = llvm::map_range(op->getAttrs(), [](const NamedAttribute &attr) {
233 return attr.getValue();
236 return llvm::all_of(op->getOperandTypes(), isLegalType) &&
237 llvm::all_of(op->getResultTypes(), isLegalType) &&
238 llvm::all_of(attrs, isLegalAttr);
241 std::unique_ptr<ConversionTarget>
242 spirv::getMemorySpaceToStorageClassTarget(MLIRContext &context) {
243 auto target = std::make_unique<ConversionTarget>(context);
244 target->markUnknownOpDynamicallyLegal(isLegalOp);
245 return target;
248 void spirv::convertMemRefTypesAndAttrs(
249 Operation *op, MemorySpaceToStorageClassConverter &typeConverter) {
250 AttrTypeReplacer replacer;
251 replacer.addReplacement([&typeConverter](BaseMemRefType origType)
252 -> std::optional<BaseMemRefType> {
253 return typeConverter.convertType<BaseMemRefType>(origType);
256 replacer.recursivelyReplaceElementsIn(op, /*replaceAttrs=*/true,
257 /*replaceLocs=*/false,
258 /*replaceTypes=*/true);
261 //===----------------------------------------------------------------------===//
262 // Conversion Pass
263 //===----------------------------------------------------------------------===//
265 namespace {
266 class MapMemRefStorageClassPass final
267 : public impl::MapMemRefStorageClassBase<MapMemRefStorageClassPass> {
268 public:
269 MapMemRefStorageClassPass() = default;
271 explicit MapMemRefStorageClassPass(
272 const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)
273 : memorySpaceMap(memorySpaceMap) {}
275 LogicalResult initializeOptions(
276 StringRef options,
277 function_ref<LogicalResult(const Twine &)> errorHandler) override {
278 if (failed(Pass::initializeOptions(options, errorHandler)))
279 return failure();
281 if (clientAPI == "opencl")
282 memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
283 else if (clientAPI != "vulkan")
284 return errorHandler(llvm::Twine("Invalid clienAPI: ") + clientAPI);
286 return success();
289 void runOnOperation() override {
290 MLIRContext *context = &getContext();
291 Operation *op = getOperation();
293 spirv::MemorySpaceToStorageClassMap spaceToStorage = memorySpaceMap;
294 if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op)) {
295 spirv::TargetEnv targetEnv(attr);
296 if (targetEnv.allows(spirv::Capability::Kernel)) {
297 spaceToStorage = spirv::mapMemorySpaceToOpenCLStorageClass;
298 } else if (targetEnv.allows(spirv::Capability::Shader)) {
299 spaceToStorage = spirv::mapMemorySpaceToVulkanStorageClass;
303 spirv::MemorySpaceToStorageClassConverter converter(spaceToStorage);
304 // Perform the replacement.
305 spirv::convertMemRefTypesAndAttrs(op, converter);
307 // Check if there are any illegal ops remaining.
308 std::unique_ptr<ConversionTarget> target =
309 spirv::getMemorySpaceToStorageClassTarget(*context);
310 op->walk([&target, this](Operation *childOp) {
311 if (target->isIllegal(childOp)) {
312 childOp->emitOpError("failed to legalize memory space");
313 signalPassFailure();
314 return WalkResult::interrupt();
316 return WalkResult::advance();
320 private:
321 spirv::MemorySpaceToStorageClassMap memorySpaceMap =
322 spirv::mapMemorySpaceToVulkanStorageClass;
324 } // namespace
326 std::unique_ptr<OperationPass<>> mlir::createMapMemRefStorageClassPass() {
327 return std::make_unique<MapMemRefStorageClassPass>();