1 //===- TestModuleCombiner.cpp - Pass to test SPIR-V module combiner lib ---===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
10 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
11 #include "mlir/Dialect/SPIRV/Linking/ModuleCombiner.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/Pass/Pass.h"
/// Test pass that runs the SPIR-V module combiner library (spirv::combine)
/// over the spirv::ModuleOp ops contained in a builtin mlir::ModuleOp.
/// Implemented as a PassWrapper over OperationPass<mlir::ModuleOp>.
19 class TestModuleCombinerPass
20 : public PassWrapper
<TestModuleCombinerPass
,
21 OperationPass
<mlir::ModuleOp
>> {
23 TestModuleCombinerPass() = default;
// The copy constructor copies nothing: a copied pass gets a
// default-constructed combinedModule rather than sharing the original's.
// NOTE(review): presumably because the owning module ref is not copyable
// while the pass infrastructure still needs a copy constructor to clone
// passes — confirm.
24 TestModuleCombinerPass(const TestModuleCombinerPass
&) {}
/// Combines the input's SPIR-V modules into a single module.
25 void runOnOperation() override
;
// Holds the module returned by spirv::combine in runOnOperation().
// NOTE(review): presumably stored as a member so the combined module stays
// alive after runOnOperation() returns — confirm.
28 mlir::spirv::OwningSPIRVModuleRef combinedModule
;
/// Collects the spirv::ModuleOp ops yielded by getOps<spirv::ModuleOp>() on
/// the payload module and combines them with spirv::combine. The result is
/// stored in combinedModule. The trailing loop then iterates over the original
/// (pre-combination) input modules.
32 void TestModuleCombinerPass::runOnOperation() {
33 auto modules
= llvm::to_vector
<4>(getOperation().getOps
<spirv::ModuleOp
>());
// NOTE(review): modules[0] below is indexed without checking that any
// spirv::ModuleOp was found. An input containing no SPIR-V modules would
// index an empty vector. Consider guarding with
// `if (modules.empty()) return;` before this point.
35 OpBuilder
combinedModuleBuilder(modules
[0]);
// The builder is anchored at the first input module.
// NOTE(review): the third argument to spirv::combine is nullptr —
// presumably an optional symbol-renaming listener this test does not
// install; verify against ModuleCombiner.h.
36 combinedModule
= spirv::combine(modules
, combinedModuleBuilder
, nullptr);
// Visit each original input module now that the combined module exists.
38 for (spirv::ModuleOp module
: modules
)
/// Registers TestModuleCombinerPass via PassRegistration under the argument
/// "test-spirv-module-combiner" with the description "Tests SPIR-V module
/// combiner library". NOTE(review): presumably this exposes it as the
/// `-test-spirv-module-combiner` flag of the test driver — confirm.
43 void registerTestSpirvModuleCombinerPass() {
44 PassRegistration
<TestModuleCombinerPass
> registration(
45 "test-spirv-module-combiner", "Tests SPIR-V module combiner library");