[win/asan] GetInstructionSize: Make `83 EC XX` a generic entry. (#119537)
[llvm-project.git] / mlir / test / lib / Dialect / SPIRV / TestAvailability.cpp
blob2e5e591fe5f911bf9e58174a751ac1aae7c00432
1 //===- TestAvailability.cpp - Pass to test SPIR-V op availability ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Dialect/Func/IR/FuncOps.h"
10 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
11 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
12 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
13 #include "mlir/Pass/Pass.h"
15 using namespace mlir;
17 //===----------------------------------------------------------------------===//
18 // Printing op availability pass
19 //===----------------------------------------------------------------------===//
21 namespace {
22 /// A pass for testing SPIR-V op availability.
23 struct PrintOpAvailability
24 : public PassWrapper<PrintOpAvailability, OperationPass<func::FuncOp>> {
25 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrintOpAvailability)
27 void runOnOperation() override;
28 StringRef getArgument() const final { return "test-spirv-op-availability"; }
29 StringRef getDescription() const final {
30 return "Test SPIR-V op availability";
33 } // namespace
35 void PrintOpAvailability::runOnOperation() {
36 auto f = getOperation();
37 llvm::outs() << f.getName() << "\n";
39 Dialect *spirvDialect = getContext().getLoadedDialect("spirv");
41 f->walk([&](Operation *op) {
42 if (op->getDialect() != spirvDialect)
43 return WalkResult::advance();
45 auto opName = op->getName();
46 auto &os = llvm::outs();
48 if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
49 std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
50 os << opName << " min version: ";
51 if (minVersion)
52 os << spirv::stringifyVersion(*minVersion) << "\n";
53 else
54 os << "None\n";
57 if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
58 std::optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
59 os << opName << " max version: ";
60 if (maxVersion)
61 os << spirv::stringifyVersion(*maxVersion) << "\n";
62 else
63 os << "None\n";
66 if (auto extension = dyn_cast<spirv::QueryExtensionInterface>(op)) {
67 os << opName << " extensions: [";
68 for (const auto &exts : extension.getExtensions()) {
69 os << " [";
70 llvm::interleaveComma(exts, os, [&](spirv::Extension ext) {
71 os << spirv::stringifyExtension(ext);
72 });
73 os << "]";
75 os << " ]\n";
78 if (auto capability = dyn_cast<spirv::QueryCapabilityInterface>(op)) {
79 os << opName << " capabilities: [";
80 for (const auto &caps : capability.getCapabilities()) {
81 os << " [";
82 llvm::interleaveComma(caps, os, [&](spirv::Capability cap) {
83 os << spirv::stringifyCapability(cap);
84 });
85 os << "]";
87 os << " ]\n";
89 os.flush();
91 return WalkResult::advance();
92 });
95 namespace mlir {
96 void registerPrintSpirvAvailabilityPass() {
97 PassRegistration<PrintOpAvailability>();
99 } // namespace mlir
101 //===----------------------------------------------------------------------===//
102 // Converting target environment pass
103 //===----------------------------------------------------------------------===//
105 namespace {
106 /// A pass for testing SPIR-V op availability.
107 struct ConvertToTargetEnv
108 : public PassWrapper<ConvertToTargetEnv, OperationPass<func::FuncOp>> {
109 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertToTargetEnv)
111 StringRef getArgument() const override { return "test-spirv-target-env"; }
112 StringRef getDescription() const override {
113 return "Test SPIR-V target environment";
115 void runOnOperation() override;
118 struct ConvertToAtomCmpExchangeWeak : RewritePattern {
119 ConvertToAtomCmpExchangeWeak(MLIRContext *context)
120 : RewritePattern("test.convert_to_atomic_compare_exchange_weak_op", 1,
121 context, {"spirv.AtomicCompareExchangeWeak"}) {}
123 LogicalResult matchAndRewrite(Operation *op,
124 PatternRewriter &rewriter) const override {
125 Value ptr = op->getOperand(0);
126 Value value = op->getOperand(1);
127 Value comparator = op->getOperand(2);
129 // Create a spirv.AtomicCompareExchangeWeak op with AtomicCounterMemory bits
130 // in memory semantics to additionally require AtomicStorage capability.
131 rewriter.replaceOpWithNewOp<spirv::AtomicCompareExchangeWeakOp>(
132 op, value.getType(), ptr, spirv::Scope::Workgroup,
133 spirv::MemorySemantics::AcquireRelease |
134 spirv::MemorySemantics::AtomicCounterMemory,
135 spirv::MemorySemantics::Acquire, value, comparator);
136 return success();
140 struct ConvertToBitReverse : RewritePattern {
141 ConvertToBitReverse(MLIRContext *context)
142 : RewritePattern("test.convert_to_bit_reverse_op", 1, context,
143 {"spirv.BitReverse"}) {}
145 LogicalResult matchAndRewrite(Operation *op,
146 PatternRewriter &rewriter) const override {
147 Value predicate = op->getOperand(0);
148 rewriter.replaceOpWithNewOp<spirv::BitReverseOp>(
149 op, op->getResult(0).getType(), predicate);
150 return success();
154 struct ConvertToGroupNonUniformBallot : RewritePattern {
155 ConvertToGroupNonUniformBallot(MLIRContext *context)
156 : RewritePattern("test.convert_to_group_non_uniform_ballot_op", 1,
157 context, {"spirv.GroupNonUniformBallot"}) {}
159 LogicalResult matchAndRewrite(Operation *op,
160 PatternRewriter &rewriter) const override {
161 Value predicate = op->getOperand(0);
162 rewriter.replaceOpWithNewOp<spirv::GroupNonUniformBallotOp>(
163 op, op->getResult(0).getType(), spirv::Scope::Workgroup, predicate);
164 return success();
168 struct ConvertToModule : RewritePattern {
169 ConvertToModule(MLIRContext *context)
170 : RewritePattern("test.convert_to_module_op", 1, context,
171 {"spirv.module"}) {}
173 LogicalResult matchAndRewrite(Operation *op,
174 PatternRewriter &rewriter) const override {
175 rewriter.replaceOpWithNewOp<spirv::ModuleOp>(
176 op, spirv::AddressingModel::PhysicalStorageBuffer64,
177 spirv::MemoryModel::Vulkan);
178 return success();
182 struct ConvertToSubgroupBallot : RewritePattern {
183 ConvertToSubgroupBallot(MLIRContext *context)
184 : RewritePattern("test.convert_to_subgroup_ballot_op", 1, context,
185 {"spirv.KHR.SubgroupBallot"}) {}
187 LogicalResult matchAndRewrite(Operation *op,
188 PatternRewriter &rewriter) const override {
189 Value predicate = op->getOperand(0);
190 rewriter.replaceOpWithNewOp<spirv::KHRSubgroupBallotOp>(
191 op, op->getResult(0).getType(), predicate);
192 return success();
196 template <const char *TestOpName, typename SPIRVOp>
197 struct ConvertToIntegerDotProd : RewritePattern {
198 ConvertToIntegerDotProd(MLIRContext *context)
199 : RewritePattern(TestOpName, 1, context, {SPIRVOp::getOperationName()}) {}
201 LogicalResult matchAndRewrite(Operation *op,
202 PatternRewriter &rewriter) const override {
203 rewriter.replaceOpWithNewOp<SPIRVOp>(op, op->getResultTypes(),
204 op->getOperands(), op->getAttrs());
205 return success();
208 } // namespace
210 void ConvertToTargetEnv::runOnOperation() {
211 MLIRContext *context = &getContext();
212 func::FuncOp fn = getOperation();
214 auto targetEnv = dyn_cast_or_null<spirv::TargetEnvAttr>(
215 fn.getOperation()->getDiscardableAttr(spirv::getTargetEnvAttrName()));
216 if (!targetEnv) {
217 fn.emitError("missing 'spirv.target_env' attribute");
218 return signalPassFailure();
221 auto target = SPIRVConversionTarget::get(targetEnv);
223 static constexpr char sDotTestOpName[] = "test.convert_to_sdot_op";
224 static constexpr char suDotTestOpName[] = "test.convert_to_sudot_op";
225 static constexpr char uDotTestOpName[] = "test.convert_to_udot_op";
226 static constexpr char sDotAccSatTestOpName[] =
227 "test.convert_to_sdot_acc_sat_op";
228 static constexpr char suDotAccSatTestOpName[] =
229 "test.convert_to_sudot_acc_sat_op";
230 static constexpr char uDotAccSatTestOpName[] =
231 "test.convert_to_udot_acc_sat_op";
233 RewritePatternSet patterns(context);
234 patterns.add<
235 ConvertToAtomCmpExchangeWeak, ConvertToBitReverse,
236 ConvertToGroupNonUniformBallot, ConvertToModule, ConvertToSubgroupBallot,
237 ConvertToIntegerDotProd<sDotTestOpName, spirv::SDotOp>,
238 ConvertToIntegerDotProd<suDotTestOpName, spirv::SUDotOp>,
239 ConvertToIntegerDotProd<uDotTestOpName, spirv::UDotOp>,
240 ConvertToIntegerDotProd<sDotAccSatTestOpName, spirv::SDotAccSatOp>,
241 ConvertToIntegerDotProd<suDotAccSatTestOpName, spirv::SUDotAccSatOp>,
242 ConvertToIntegerDotProd<uDotAccSatTestOpName, spirv::UDotAccSatOp>>(
243 context);
245 if (failed(applyPartialConversion(fn, *target, std::move(patterns))))
246 return signalPassFailure();
249 namespace mlir {
250 void registerConvertToTargetEnvPass() {
251 PassRegistration<ConvertToTargetEnv>();
253 } // namespace mlir