//===- PassManagerTest.cpp - PassManager unit tests -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Pass/PassManager.h"
#include "mlir/Debug/BreakpointManagers/TagBreakpointManager.h"
#include "mlir/Debug/ExecutionContext.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Pass/Pass.h"
#include "gtest/gtest.h"

#include <memory>
#include <optional>
#include <utility>

using namespace mlir;
using namespace mlir::detail;
25 /// Analysis that operates on any operation.
26 struct GenericAnalysis
{
27 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GenericAnalysis
)
29 GenericAnalysis(Operation
*op
) : isFunc(isa
<func::FuncOp
>(op
)) {}
33 /// Analysis that operates on a specific operation.
34 struct OpSpecificAnalysis
{
35 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpSpecificAnalysis
)
37 OpSpecificAnalysis(func::FuncOp op
) : isSecret(op
.getName() == "secret") {}
41 /// Simple pass to annotate a func::FuncOp with the results of analysis.
42 struct AnnotateFunctionPass
43 : public PassWrapper
<AnnotateFunctionPass
, OperationPass
<func::FuncOp
>> {
44 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AnnotateFunctionPass
)
46 void runOnOperation() override
{
47 func::FuncOp op
= getOperation();
48 Builder
builder(op
->getParentOfType
<ModuleOp
>());
50 auto &ga
= getAnalysis
<GenericAnalysis
>();
51 auto &sa
= getAnalysis
<OpSpecificAnalysis
>();
53 op
->setAttr("isFunc", builder
.getBoolAttr(ga
.isFunc
));
54 op
->setAttr("isSecret", builder
.getBoolAttr(sa
.isSecret
));
58 TEST(PassManagerTest
, OpSpecificAnalysis
) {
60 context
.loadDialect
<func::FuncDialect
>();
61 Builder
builder(&context
);
63 // Create a module with 2 functions.
64 OwningOpRef
<ModuleOp
> module(ModuleOp::create(UnknownLoc::get(&context
)));
65 for (StringRef name
: {"secret", "not_secret"}) {
66 auto func
= func::FuncOp::create(
67 builder
.getUnknownLoc(), name
,
68 builder
.getFunctionType(std::nullopt
, std::nullopt
));
70 module
->push_back(func
);
73 // Instantiate and run our pass.
74 auto pm
= PassManager::on
<ModuleOp
>(&context
);
75 pm
.addNestedPass
<func::FuncOp
>(std::make_unique
<AnnotateFunctionPass
>());
76 LogicalResult result
= pm
.run(module
.get());
77 EXPECT_TRUE(succeeded(result
));
79 // Verify that each function got annotated with expected attributes.
80 for (func::FuncOp func
: module
->getOps
<func::FuncOp
>()) {
81 ASSERT_TRUE(isa
<BoolAttr
>(func
->getDiscardableAttr("isFunc")));
82 EXPECT_TRUE(cast
<BoolAttr
>(func
->getDiscardableAttr("isFunc")).getValue());
84 bool isSecret
= func
.getName() == "secret";
85 ASSERT_TRUE(isa
<BoolAttr
>(func
->getDiscardableAttr("isSecret")));
86 EXPECT_EQ(cast
<BoolAttr
>(func
->getDiscardableAttr("isSecret")).getValue(),
91 /// Simple pass to annotate a func::FuncOp with a single attribute `didProcess`.
92 struct AddAttrFunctionPass
93 : public PassWrapper
<AddAttrFunctionPass
, OperationPass
<func::FuncOp
>> {
94 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddAttrFunctionPass
)
96 void runOnOperation() override
{
97 func::FuncOp op
= getOperation();
98 Builder
builder(op
->getParentOfType
<ModuleOp
>());
99 if (op
->hasAttr("didProcess"))
100 op
->setAttr("didProcessAgain", builder
.getUnitAttr());
102 // We always want to set this one.
103 op
->setAttr("didProcess", builder
.getUnitAttr());
107 /// Simple pass to annotate a func::FuncOp with a single attribute
109 struct AddSecondAttrFunctionPass
110 : public PassWrapper
<AddSecondAttrFunctionPass
,
111 OperationPass
<func::FuncOp
>> {
112 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddSecondAttrFunctionPass
)
114 void runOnOperation() override
{
115 func::FuncOp op
= getOperation();
116 Builder
builder(op
->getParentOfType
<ModuleOp
>());
117 op
->setAttr("didProcess2", builder
.getUnitAttr());
121 TEST(PassManagerTest
, ExecutionAction
) {
123 context
.loadDialect
<func::FuncDialect
>();
124 Builder
builder(&context
);
126 // Create a module with 2 functions.
127 OwningOpRef
<ModuleOp
> module(ModuleOp::create(UnknownLoc::get(&context
)));
129 func::FuncOp::create(builder
.getUnknownLoc(), "process_me_once",
130 builder
.getFunctionType(std::nullopt
, std::nullopt
));
132 module
->push_back(f
);
134 // Instantiate our passes.
135 auto pm
= PassManager::on
<ModuleOp
>(&context
);
136 auto pass
= std::make_unique
<AddAttrFunctionPass
>();
137 auto *passPtr
= pass
.get();
138 pm
.addNestedPass
<func::FuncOp
>(std::move(pass
));
139 pm
.addNestedPass
<func::FuncOp
>(std::make_unique
<AddSecondAttrFunctionPass
>());
140 // Duplicate the first pass to ensure that we *only* run the *first* pass, not
141 // all instances of this pass kind. Notice that this pass (and the test as a
142 // whole) are built to ensure that we can run just a single pass out of a
143 // pipeline that may contain duplicates.
144 pm
.addNestedPass
<func::FuncOp
>(std::make_unique
<AddAttrFunctionPass
>());
146 // Use the action manager to only hit the first pass, not the second one.
147 auto onBreakpoint
= [&](const tracing::ActionActiveStack
*backtrace
)
148 -> tracing::ExecutionContext::Control
{
149 // Not a PassExecutionAction, apply the action.
150 auto *passExec
= dyn_cast
<PassExecutionAction
>(&backtrace
->getAction());
152 return tracing::ExecutionContext::Next
;
154 // If this isn't a function, apply the action.
155 if (!isa
<func::FuncOp
>(passExec
->getOp()))
156 return tracing::ExecutionContext::Next
;
158 // Only apply the first function pass. Not all instances of the first pass,
159 // only the first pass.
160 if (passExec
->getPass().getThreadingSiblingOrThis() == passPtr
)
161 return tracing::ExecutionContext::Next
;
163 // Do not apply any other passes in the pass manager.
164 return tracing::ExecutionContext::Skip
;
167 // Set up our breakpoint manager.
168 tracing::TagBreakpointManager simpleManager
;
169 tracing::ExecutionContext
executionCtx(onBreakpoint
);
170 executionCtx
.addBreakpointManager(&simpleManager
);
171 simpleManager
.addBreakpoint(PassExecutionAction::tag
);
173 // Register the execution context in the MLIRContext.
174 context
.registerActionHandler(executionCtx
);
176 // Run the pass manager, expecting our handler to be called.
177 LogicalResult result
= pm
.run(module
.get());
178 EXPECT_TRUE(succeeded(result
));
180 // Verify that each function got annotated with `didProcess` and *not*
182 for (func::FuncOp func
: module
->getOps
<func::FuncOp
>()) {
183 ASSERT_TRUE(func
->getDiscardableAttr("didProcess"));
184 ASSERT_FALSE(func
->getDiscardableAttr("didProcess2"));
185 ASSERT_FALSE(func
->getDiscardableAttr("didProcessAgain"));
190 struct InvalidPass
: Pass
{
191 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InvalidPass
)
193 InvalidPass() : Pass(TypeID::get
<InvalidPass
>(), StringRef("invalid_op")) {}
194 StringRef
getName() const override
{ return "Invalid Pass"; }
195 void runOnOperation() override
{}
196 bool canScheduleOn(RegisteredOperationName opName
) const override
{
200 /// A clone method to create a copy of this pass.
201 std::unique_ptr
<Pass
> clonePass() const override
{
202 return std::make_unique
<InvalidPass
>(
203 *static_cast<const InvalidPass
*>(this));
208 TEST(PassManagerTest
, InvalidPass
) {
210 context
.allowUnregisteredDialects();
213 OwningOpRef
<ModuleOp
> module(ModuleOp::create(UnknownLoc::get(&context
)));
215 // Add a single "invalid_op" operation
216 OpBuilder
builder(&module
->getBodyRegion());
217 OperationState
state(UnknownLoc::get(&context
), "invalid_op");
218 builder
.insert(Operation::create(state
));
220 // Register a diagnostic handler to capture the diagnostic so that we can
222 std::unique_ptr
<Diagnostic
> diagnostic
;
223 context
.getDiagEngine().registerHandler([&](Diagnostic
&diag
) {
224 diagnostic
= std::make_unique
<Diagnostic
>(std::move(diag
));
227 // Instantiate and run our pass.
228 auto pm
= PassManager::on
<ModuleOp
>(&context
);
229 pm
.nest("invalid_op").addPass(std::make_unique
<InvalidPass
>());
230 LogicalResult result
= pm
.run(module
.get());
231 EXPECT_TRUE(failed(result
));
232 ASSERT_TRUE(diagnostic
.get() != nullptr);
235 "'invalid_op' op trying to schedule a pass on an unregistered operation");
237 // Check that clearing the pass manager effectively removed the pass.
239 result
= pm
.run(module
.get());
240 EXPECT_TRUE(succeeded(result
));
242 // Check that adding the pass at the top-level triggers a fatal error.
243 ASSERT_DEATH(pm
.addPass(std::make_unique
<InvalidPass
>()),
244 "Can't add pass 'Invalid Pass' restricted to 'invalid_op' on a "
245 "PassManager intended to run on 'builtin.module', did you "
249 /// Simple pass to annotate a func::FuncOp with the results of analysis.
250 struct InitializeCheckingPass
251 : public PassWrapper
<InitializeCheckingPass
, OperationPass
<ModuleOp
>> {
252 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InitializeCheckingPass
)
253 LogicalResult
initialize(MLIRContext
*ctx
) final
{
257 bool initialized
= false;
259 void runOnOperation() override
{
261 getOperation()->emitError() << "Pass isn't initialized!";
/// Runs a module pipeline containing InitializeCheckingPass twice. The second
/// time, another copy of the pass has been added after the first run. Both
/// runs are expected to succeed, i.e. every pass copy is expected to have
/// been initialized before it runs.
// NOTE(review): this block looks corrupted: each statement carries a stray
// numeric prefix, and no `MLIRContext context;` declaration is visible before
// `context` is used. Confirm against the upstream file and restore both; the
// closing brace is presumably in the following chunk.
267 TEST(PassManagerTest
, PassInitialization
) {
269 context
.allowUnregisteredDialects();
272 OwningOpRef
<ModuleOp
> module(ModuleOp::create(UnknownLoc::get(&context
)));
274 // Instantiate and run our pass.
275 auto pm
= PassManager::on
<ModuleOp
>(&context
);
276 pm
.addPass(std::make_unique
<InitializeCheckingPass
>());
277 EXPECT_TRUE(succeeded(pm
.run(module
.get())));
279 // Adding a second copy of the pass, we should also initialize it!
280 pm
.addPass(std::make_unique
<InitializeCheckingPass
>());
281 EXPECT_TRUE(succeeded(pm
.run(module
.get())));