//===- TestSPIRVVectorUnrolling.cpp - Test SPIR-V vector unrolling --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
22 struct TestSPIRVVectorUnrolling final
23 : PassWrapper
<TestSPIRVVectorUnrolling
, OperationPass
<ModuleOp
>> {
24 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSPIRVVectorUnrolling
)
26 StringRef
getArgument() const final
{ return "test-spirv-vector-unrolling"; }
28 StringRef
getDescription() const final
{
29 return "Test patterns that unroll vectors to types supported by SPIR-V";
32 void getDependentDialects(DialectRegistry
®istry
) const override
{
33 registry
.insert
<spirv::SPIRVDialect
, vector::VectorDialect
>();
36 void runOnOperation() override
{
37 Operation
*op
= getOperation();
38 (void)spirv::unrollVectorsInSignatures(op
);
39 (void)spirv::unrollVectorsInFuncBodies(op
);
46 void registerTestSPIRVVectorUnrolling() {
47 PassRegistration
<TestSPIRVVectorUnrolling
>();