//===- TestModuleCombiner.cpp - Pass to test SPIR-V module combiner lib ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
9 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
10 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
11 #include "mlir/Dialect/SPIRV/Linking/ModuleCombiner.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/Pass/Pass.h"
19 class TestModuleCombinerPass
20 : public PassWrapper
<TestModuleCombinerPass
,
21 OperationPass
<mlir::ModuleOp
>> {
23 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestModuleCombinerPass
)
25 StringRef
getArgument() const final
{ return "test-spirv-module-combiner"; }
26 StringRef
getDescription() const final
{
27 return "Tests SPIR-V module combiner library";
29 TestModuleCombinerPass() = default;
30 TestModuleCombinerPass(const TestModuleCombinerPass
&) {}
31 void runOnOperation() override
;
35 void TestModuleCombinerPass::runOnOperation() {
36 auto modules
= llvm::to_vector
<4>(getOperation().getOps
<spirv::ModuleOp
>());
40 OpBuilder
combinedModuleBuilder(modules
[0]);
42 auto listener
= [](spirv::ModuleOp originalModule
, StringRef oldSymbol
,
43 StringRef newSymbol
) {
44 llvm::outs() << "[" << originalModule
.getName() << "] " << oldSymbol
45 << " -> " << newSymbol
<< "\n";
48 OwningOpRef
<spirv::ModuleOp
> combinedModule
=
49 spirv::combine(modules
, combinedModuleBuilder
, listener
);
51 for (spirv::ModuleOp module
: modules
)
53 combinedModule
.release();
57 void registerTestSpirvModuleCombinerPass() {
58 PassRegistration
<TestModuleCombinerPass
>();