//===- TestAvailability.cpp - Pass to test SPIR-V op availability ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Pass/Pass.h"

#include <optional>

using namespace mlir;
//===----------------------------------------------------------------------===//
// Printing op availability pass
//===----------------------------------------------------------------------===//
22 /// A pass for testing SPIR-V op availability.
23 struct PrintOpAvailability
24 : public PassWrapper
<PrintOpAvailability
, OperationPass
<func::FuncOp
>> {
25 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrintOpAvailability
)
27 void runOnOperation() override
;
28 StringRef
getArgument() const final
{ return "test-spirv-op-availability"; }
29 StringRef
getDescription() const final
{
30 return "Test SPIR-V op availability";
35 void PrintOpAvailability::runOnOperation() {
36 auto f
= getOperation();
37 llvm::outs() << f
.getName() << "\n";
39 Dialect
*spirvDialect
= getContext().getLoadedDialect("spirv");
41 f
->walk([&](Operation
*op
) {
42 if (op
->getDialect() != spirvDialect
)
43 return WalkResult::advance();
45 auto opName
= op
->getName();
46 auto &os
= llvm::outs();
48 if (auto minVersionIfx
= dyn_cast
<spirv::QueryMinVersionInterface
>(op
)) {
49 std::optional
<spirv::Version
> minVersion
= minVersionIfx
.getMinVersion();
50 os
<< opName
<< " min version: ";
52 os
<< spirv::stringifyVersion(*minVersion
) << "\n";
57 if (auto maxVersionIfx
= dyn_cast
<spirv::QueryMaxVersionInterface
>(op
)) {
58 std::optional
<spirv::Version
> maxVersion
= maxVersionIfx
.getMaxVersion();
59 os
<< opName
<< " max version: ";
61 os
<< spirv::stringifyVersion(*maxVersion
) << "\n";
66 if (auto extension
= dyn_cast
<spirv::QueryExtensionInterface
>(op
)) {
67 os
<< opName
<< " extensions: [";
68 for (const auto &exts
: extension
.getExtensions()) {
70 llvm::interleaveComma(exts
, os
, [&](spirv::Extension ext
) {
71 os
<< spirv::stringifyExtension(ext
);
78 if (auto capability
= dyn_cast
<spirv::QueryCapabilityInterface
>(op
)) {
79 os
<< opName
<< " capabilities: [";
80 for (const auto &caps
: capability
.getCapabilities()) {
82 llvm::interleaveComma(caps
, os
, [&](spirv::Capability cap
) {
83 os
<< spirv::stringifyCapability(cap
);
91 return WalkResult::advance();
96 void registerPrintSpirvAvailabilityPass() {
97 PassRegistration
<PrintOpAvailability
>();
//===----------------------------------------------------------------------===//
// Converting target environment pass
//===----------------------------------------------------------------------===//
106 /// A pass for testing SPIR-V op availability.
107 struct ConvertToTargetEnv
108 : public PassWrapper
<ConvertToTargetEnv
, OperationPass
<func::FuncOp
>> {
109 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertToTargetEnv
)
111 StringRef
getArgument() const override
{ return "test-spirv-target-env"; }
112 StringRef
getDescription() const override
{
113 return "Test SPIR-V target environment";
115 void runOnOperation() override
;
118 struct ConvertToAtomCmpExchangeWeak
: RewritePattern
{
119 ConvertToAtomCmpExchangeWeak(MLIRContext
*context
)
120 : RewritePattern("test.convert_to_atomic_compare_exchange_weak_op", 1,
121 context
, {"spirv.AtomicCompareExchangeWeak"}) {}
123 LogicalResult
matchAndRewrite(Operation
*op
,
124 PatternRewriter
&rewriter
) const override
{
125 Value ptr
= op
->getOperand(0);
126 Value value
= op
->getOperand(1);
127 Value comparator
= op
->getOperand(2);
129 // Create a spirv.AtomicCompareExchangeWeak op with AtomicCounterMemory bits
130 // in memory semantics to additionally require AtomicStorage capability.
131 rewriter
.replaceOpWithNewOp
<spirv::AtomicCompareExchangeWeakOp
>(
132 op
, value
.getType(), ptr
, spirv::Scope::Workgroup
,
133 spirv::MemorySemantics::AcquireRelease
|
134 spirv::MemorySemantics::AtomicCounterMemory
,
135 spirv::MemorySemantics::Acquire
, value
, comparator
);
140 struct ConvertToBitReverse
: RewritePattern
{
141 ConvertToBitReverse(MLIRContext
*context
)
142 : RewritePattern("test.convert_to_bit_reverse_op", 1, context
,
143 {"spirv.BitReverse"}) {}
145 LogicalResult
matchAndRewrite(Operation
*op
,
146 PatternRewriter
&rewriter
) const override
{
147 Value predicate
= op
->getOperand(0);
148 rewriter
.replaceOpWithNewOp
<spirv::BitReverseOp
>(
149 op
, op
->getResult(0).getType(), predicate
);
154 struct ConvertToGroupNonUniformBallot
: RewritePattern
{
155 ConvertToGroupNonUniformBallot(MLIRContext
*context
)
156 : RewritePattern("test.convert_to_group_non_uniform_ballot_op", 1,
157 context
, {"spirv.GroupNonUniformBallot"}) {}
159 LogicalResult
matchAndRewrite(Operation
*op
,
160 PatternRewriter
&rewriter
) const override
{
161 Value predicate
= op
->getOperand(0);
162 rewriter
.replaceOpWithNewOp
<spirv::GroupNonUniformBallotOp
>(
163 op
, op
->getResult(0).getType(), spirv::Scope::Workgroup
, predicate
);
168 struct ConvertToModule
: RewritePattern
{
169 ConvertToModule(MLIRContext
*context
)
170 : RewritePattern("test.convert_to_module_op", 1, context
,
173 LogicalResult
matchAndRewrite(Operation
*op
,
174 PatternRewriter
&rewriter
) const override
{
175 rewriter
.replaceOpWithNewOp
<spirv::ModuleOp
>(
176 op
, spirv::AddressingModel::PhysicalStorageBuffer64
,
177 spirv::MemoryModel::Vulkan
);
182 struct ConvertToSubgroupBallot
: RewritePattern
{
183 ConvertToSubgroupBallot(MLIRContext
*context
)
184 : RewritePattern("test.convert_to_subgroup_ballot_op", 1, context
,
185 {"spirv.KHR.SubgroupBallot"}) {}
187 LogicalResult
matchAndRewrite(Operation
*op
,
188 PatternRewriter
&rewriter
) const override
{
189 Value predicate
= op
->getOperand(0);
190 rewriter
.replaceOpWithNewOp
<spirv::KHRSubgroupBallotOp
>(
191 op
, op
->getResult(0).getType(), predicate
);
196 template <const char *TestOpName
, typename SPIRVOp
>
197 struct ConvertToIntegerDotProd
: RewritePattern
{
198 ConvertToIntegerDotProd(MLIRContext
*context
)
199 : RewritePattern(TestOpName
, 1, context
, {SPIRVOp::getOperationName()}) {}
201 LogicalResult
matchAndRewrite(Operation
*op
,
202 PatternRewriter
&rewriter
) const override
{
203 rewriter
.replaceOpWithNewOp
<SPIRVOp
>(op
, op
->getResultTypes(),
204 op
->getOperands(), op
->getAttrs());
210 void ConvertToTargetEnv::runOnOperation() {
211 MLIRContext
*context
= &getContext();
212 func::FuncOp fn
= getOperation();
214 auto targetEnv
= dyn_cast_or_null
<spirv::TargetEnvAttr
>(
215 fn
.getOperation()->getDiscardableAttr(spirv::getTargetEnvAttrName()));
217 fn
.emitError("missing 'spirv.target_env' attribute");
218 return signalPassFailure();
221 auto target
= SPIRVConversionTarget::get(targetEnv
);
223 static constexpr char sDotTestOpName
[] = "test.convert_to_sdot_op";
224 static constexpr char suDotTestOpName
[] = "test.convert_to_sudot_op";
225 static constexpr char uDotTestOpName
[] = "test.convert_to_udot_op";
226 static constexpr char sDotAccSatTestOpName
[] =
227 "test.convert_to_sdot_acc_sat_op";
228 static constexpr char suDotAccSatTestOpName
[] =
229 "test.convert_to_sudot_acc_sat_op";
230 static constexpr char uDotAccSatTestOpName
[] =
231 "test.convert_to_udot_acc_sat_op";
233 RewritePatternSet
patterns(context
);
235 ConvertToAtomCmpExchangeWeak
, ConvertToBitReverse
,
236 ConvertToGroupNonUniformBallot
, ConvertToModule
, ConvertToSubgroupBallot
,
237 ConvertToIntegerDotProd
<sDotTestOpName
, spirv::SDotOp
>,
238 ConvertToIntegerDotProd
<suDotTestOpName
, spirv::SUDotOp
>,
239 ConvertToIntegerDotProd
<uDotTestOpName
, spirv::UDotOp
>,
240 ConvertToIntegerDotProd
<sDotAccSatTestOpName
, spirv::SDotAccSatOp
>,
241 ConvertToIntegerDotProd
<suDotAccSatTestOpName
, spirv::SUDotAccSatOp
>,
242 ConvertToIntegerDotProd
<uDotAccSatTestOpName
, spirv::UDotAccSatOp
>>(
245 if (failed(applyPartialConversion(fn
, *target
, std::move(patterns
))))
246 return signalPassFailure();
250 void registerConvertToTargetEnvPass() {
251 PassRegistration
<ConvertToTargetEnv
>();