//===- MapMemRefStorageClassPass.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to map numeric MemRef memory spaces to
// symbolic ones defined in the SPIR-V specification.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"

#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"
#include <cstdint>
#include <optional>
33 #define GEN_PASS_DEF_MAPMEMREFSTORAGECLASS
34 #include "mlir/Conversion/Passes.h.inc"
37 #define DEBUG_TYPE "mlir-map-memref-storage-class"
//===----------------------------------------------------------------------===//
// Mappings
//===----------------------------------------------------------------------===//
/// Mapping between SPIR-V storage classes and memref memory spaces, used for
/// Vulkan (Shader) targets. Each `MAP_FN(storageClass, memorySpace)` entry
/// pairs a SPIR-V storage class with the numeric memref memory space it
/// corresponds to.
///
/// Note: memref does not have a defined semantics for each memory space; it
/// depends on the context where it is used. There are no particular reasons
/// behind the number assignments; we try to follow NVVM conventions and largely
/// give common storage classes a smaller number.
#define VULKAN_STORAGE_SPACE_MAP_LIST(MAP_FN)                                  \
  MAP_FN(spirv::StorageClass::StorageBuffer, 0)                                \
  MAP_FN(spirv::StorageClass::Generic, 1)                                      \
  MAP_FN(spirv::StorageClass::Workgroup, 3)                                    \
  MAP_FN(spirv::StorageClass::Uniform, 4)                                      \
  MAP_FN(spirv::StorageClass::Private, 5)                                      \
  MAP_FN(spirv::StorageClass::Function, 6)                                     \
  MAP_FN(spirv::StorageClass::PushConstant, 7)                                 \
  MAP_FN(spirv::StorageClass::UniformConstant, 8)                              \
  MAP_FN(spirv::StorageClass::Input, 9)                                        \
  MAP_FN(spirv::StorageClass::Output, 10)                                      \
  MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 11)
64 std::optional
<spirv::StorageClass
>
65 spirv::mapMemorySpaceToVulkanStorageClass(Attribute memorySpaceAttr
) {
66 // Handle null memory space attribute specially.
68 return spirv::StorageClass::StorageBuffer
;
70 // Unknown dialect custom attributes are not supported by default.
71 // Downstream callers should plug in more specialized ones.
72 auto intAttr
= dyn_cast
<IntegerAttr
>(memorySpaceAttr
);
75 unsigned memorySpace
= intAttr
.getInt();
77 #define STORAGE_SPACE_MAP_FN(storage, space) \
81 switch (memorySpace
) {
82 VULKAN_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN
)
88 #undef STORAGE_SPACE_MAP_FN
91 std::optional
<unsigned>
92 spirv::mapVulkanStorageClassToMemorySpace(spirv::StorageClass storageClass
) {
93 #define STORAGE_SPACE_MAP_FN(storage, space) \
97 switch (storageClass
) {
98 VULKAN_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN
)
104 #undef STORAGE_SPACE_MAP_FN
107 #undef VULKAN_STORAGE_SPACE_MAP_LIST
/// Mapping between SPIR-V storage classes and memref memory spaces, used for
/// OpenCL (Kernel) targets. Each `MAP_FN(storageClass, memorySpace)` entry
/// pairs a SPIR-V storage class with the numeric memref memory space it
/// corresponds to.
#define OPENCL_STORAGE_SPACE_MAP_LIST(MAP_FN)                                  \
  MAP_FN(spirv::StorageClass::CrossWorkgroup, 0)                               \
  MAP_FN(spirv::StorageClass::Generic, 1)                                      \
  MAP_FN(spirv::StorageClass::Workgroup, 3)                                    \
  MAP_FN(spirv::StorageClass::UniformConstant, 4)                              \
  MAP_FN(spirv::StorageClass::Private, 5)                                      \
  MAP_FN(spirv::StorageClass::Function, 6)                                     \
  MAP_FN(spirv::StorageClass::Image, 7)
118 std::optional
<spirv::StorageClass
>
119 spirv::mapMemorySpaceToOpenCLStorageClass(Attribute memorySpaceAttr
) {
120 // Handle null memory space attribute specially.
121 if (!memorySpaceAttr
)
122 return spirv::StorageClass::CrossWorkgroup
;
124 // Unknown dialect custom attributes are not supported by default.
125 // Downstream callers should plug in more specialized ones.
126 auto intAttr
= dyn_cast
<IntegerAttr
>(memorySpaceAttr
);
129 unsigned memorySpace
= intAttr
.getInt();
131 #define STORAGE_SPACE_MAP_FN(storage, space) \
135 switch (memorySpace
) {
136 OPENCL_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN
)
142 #undef STORAGE_SPACE_MAP_FN
145 std::optional
<unsigned>
146 spirv::mapOpenCLStorageClassToMemorySpace(spirv::StorageClass storageClass
) {
147 #define STORAGE_SPACE_MAP_FN(storage, space) \
151 switch (storageClass
) {
152 OPENCL_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN
)
158 #undef STORAGE_SPACE_MAP_FN
161 #undef OPENCL_STORAGE_SPACE_MAP_LIST
//===----------------------------------------------------------------------===//
// Type Converter
//===----------------------------------------------------------------------===//
167 spirv::MemorySpaceToStorageClassConverter::MemorySpaceToStorageClassConverter(
168 const spirv::MemorySpaceToStorageClassMap
&memorySpaceMap
)
169 : memorySpaceMap(memorySpaceMap
) {
170 // Pass through for all other types.
171 addConversion([](Type type
) { return type
; });
173 addConversion([this](BaseMemRefType memRefType
) -> std::optional
<Type
> {
174 std::optional
<spirv::StorageClass
> storage
=
175 this->memorySpaceMap(memRefType
.getMemorySpace());
177 LLVM_DEBUG(llvm::dbgs()
178 << "cannot convert " << memRefType
179 << " due to being unable to find memory space in map\n");
184 spirv::StorageClassAttr::get(memRefType
.getContext(), *storage
);
185 if (auto rankedType
= dyn_cast
<MemRefType
>(memRefType
)) {
186 return MemRefType::get(memRefType
.getShape(), memRefType
.getElementType(),
187 rankedType
.getLayout(), storageAttr
);
189 return UnrankedMemRefType::get(memRefType
.getElementType(), storageAttr
);
192 addConversion([this](FunctionType type
) {
193 auto inputs
= llvm::map_to_vector(
194 type
.getInputs(), [this](Type ty
) { return convertType(ty
); });
195 auto results
= llvm::map_to_vector(
196 type
.getResults(), [this](Type ty
) { return convertType(ty
); });
197 return FunctionType::get(type
.getContext(), inputs
, results
);
//===----------------------------------------------------------------------===//
// Conversion Target
//===----------------------------------------------------------------------===//
205 /// Returns true if the given `type` is considered as legal for SPIR-V
207 static bool isLegalType(Type type
) {
208 if (auto memRefType
= dyn_cast
<BaseMemRefType
>(type
)) {
209 Attribute spaceAttr
= memRefType
.getMemorySpace();
210 return isa_and_nonnull
<spirv::StorageClassAttr
>(spaceAttr
);
215 /// Returns true if the given `attr` is considered as legal for SPIR-V
217 static bool isLegalAttr(Attribute attr
) {
218 if (auto typeAttr
= dyn_cast
<TypeAttr
>(attr
))
219 return isLegalType(typeAttr
.getValue());
223 /// Returns true if the given `op` is considered as legal for SPIR-V conversion.
224 static bool isLegalOp(Operation
*op
) {
225 if (auto funcOp
= dyn_cast
<FunctionOpInterface
>(op
)) {
226 return llvm::all_of(funcOp
.getArgumentTypes(), isLegalType
) &&
227 llvm::all_of(funcOp
.getResultTypes(), isLegalType
) &&
228 llvm::all_of(funcOp
.getFunctionBody().getArgumentTypes(),
232 auto attrs
= llvm::map_range(op
->getAttrs(), [](const NamedAttribute
&attr
) {
233 return attr
.getValue();
236 return llvm::all_of(op
->getOperandTypes(), isLegalType
) &&
237 llvm::all_of(op
->getResultTypes(), isLegalType
) &&
238 llvm::all_of(attrs
, isLegalAttr
);
241 std::unique_ptr
<ConversionTarget
>
242 spirv::getMemorySpaceToStorageClassTarget(MLIRContext
&context
) {
243 auto target
= std::make_unique
<ConversionTarget
>(context
);
244 target
->markUnknownOpDynamicallyLegal(isLegalOp
);
248 void spirv::convertMemRefTypesAndAttrs(
249 Operation
*op
, MemorySpaceToStorageClassConverter
&typeConverter
) {
250 AttrTypeReplacer replacer
;
251 replacer
.addReplacement([&typeConverter
](BaseMemRefType origType
)
252 -> std::optional
<BaseMemRefType
> {
253 return typeConverter
.convertType
<BaseMemRefType
>(origType
);
256 replacer
.recursivelyReplaceElementsIn(op
, /*replaceAttrs=*/true,
257 /*replaceLocs=*/false,
258 /*replaceTypes=*/true);
//===----------------------------------------------------------------------===//
// Conversion Pass
//===----------------------------------------------------------------------===//
266 class MapMemRefStorageClassPass final
267 : public impl::MapMemRefStorageClassBase
<MapMemRefStorageClassPass
> {
269 MapMemRefStorageClassPass() = default;
271 explicit MapMemRefStorageClassPass(
272 const spirv::MemorySpaceToStorageClassMap
&memorySpaceMap
)
273 : memorySpaceMap(memorySpaceMap
) {}
275 LogicalResult
initializeOptions(
277 function_ref
<LogicalResult(const Twine
&)> errorHandler
) override
{
278 if (failed(Pass::initializeOptions(options
, errorHandler
)))
281 if (clientAPI
== "opencl")
282 memorySpaceMap
= spirv::mapMemorySpaceToOpenCLStorageClass
;
283 else if (clientAPI
!= "vulkan")
284 return errorHandler(llvm::Twine("Invalid clienAPI: ") + clientAPI
);
289 void runOnOperation() override
{
290 MLIRContext
*context
= &getContext();
291 Operation
*op
= getOperation();
293 spirv::MemorySpaceToStorageClassMap spaceToStorage
= memorySpaceMap
;
294 if (spirv::TargetEnvAttr attr
= spirv::lookupTargetEnv(op
)) {
295 spirv::TargetEnv
targetEnv(attr
);
296 if (targetEnv
.allows(spirv::Capability::Kernel
)) {
297 spaceToStorage
= spirv::mapMemorySpaceToOpenCLStorageClass
;
298 } else if (targetEnv
.allows(spirv::Capability::Shader
)) {
299 spaceToStorage
= spirv::mapMemorySpaceToVulkanStorageClass
;
303 spirv::MemorySpaceToStorageClassConverter
converter(spaceToStorage
);
304 // Perform the replacement.
305 spirv::convertMemRefTypesAndAttrs(op
, converter
);
307 // Check if there are any illegal ops remaining.
308 std::unique_ptr
<ConversionTarget
> target
=
309 spirv::getMemorySpaceToStorageClassTarget(*context
);
310 op
->walk([&target
, this](Operation
*childOp
) {
311 if (target
->isIllegal(childOp
)) {
312 childOp
->emitOpError("failed to legalize memory space");
314 return WalkResult::interrupt();
316 return WalkResult::advance();
321 spirv::MemorySpaceToStorageClassMap memorySpaceMap
=
322 spirv::mapMemorySpaceToVulkanStorageClass
;
326 std::unique_ptr
<OperationPass
<>> mlir::createMapMemRefStorageClassPass() {
327 return std::make_unique
<MapMemRefStorageClassPass
>();