[LLVM] Fix Maintainers.md formatting (NFC)
[llvm-project.git] / mlir / unittests / Pass / PassManagerTest.cpp
blob7ceed3bb3bc3bd21fa025f13f977998295bc96fe
1 //===- PassManagerTest.cpp - PassManager unit tests -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Pass/PassManager.h"
10 #include "mlir/Debug/BreakpointManagers/TagBreakpointManager.h"
11 #include "mlir/Debug/ExecutionContext.h"
12 #include "mlir/Dialect/Func/IR/FuncOps.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/BuiltinOps.h"
15 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/Pass/Pass.h"
17 #include "gtest/gtest.h"
19 #include <memory>
21 using namespace mlir;
22 using namespace mlir::detail;
24 namespace {
25 /// Analysis that operates on any operation.
26 struct GenericAnalysis {
27 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GenericAnalysis)
29 GenericAnalysis(Operation *op) : isFunc(isa<func::FuncOp>(op)) {}
30 const bool isFunc;
33 /// Analysis that operates on a specific operation.
34 struct OpSpecificAnalysis {
35 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpSpecificAnalysis)
37 OpSpecificAnalysis(func::FuncOp op) : isSecret(op.getName() == "secret") {}
38 const bool isSecret;
41 /// Simple pass to annotate a func::FuncOp with the results of analysis.
42 struct AnnotateFunctionPass
43 : public PassWrapper<AnnotateFunctionPass, OperationPass<func::FuncOp>> {
44 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AnnotateFunctionPass)
46 void runOnOperation() override {
47 func::FuncOp op = getOperation();
48 Builder builder(op->getParentOfType<ModuleOp>());
50 auto &ga = getAnalysis<GenericAnalysis>();
51 auto &sa = getAnalysis<OpSpecificAnalysis>();
53 op->setAttr("isFunc", builder.getBoolAttr(ga.isFunc));
54 op->setAttr("isSecret", builder.getBoolAttr(sa.isSecret));
58 TEST(PassManagerTest, OpSpecificAnalysis) {
59 MLIRContext context;
60 context.loadDialect<func::FuncDialect>();
61 Builder builder(&context);
63 // Create a module with 2 functions.
64 OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
65 for (StringRef name : {"secret", "not_secret"}) {
66 auto func = func::FuncOp::create(
67 builder.getUnknownLoc(), name,
68 builder.getFunctionType(std::nullopt, std::nullopt));
69 func.setPrivate();
70 module->push_back(func);
73 // Instantiate and run our pass.
74 auto pm = PassManager::on<ModuleOp>(&context);
75 pm.addNestedPass<func::FuncOp>(std::make_unique<AnnotateFunctionPass>());
76 LogicalResult result = pm.run(module.get());
77 EXPECT_TRUE(succeeded(result));
79 // Verify that each function got annotated with expected attributes.
80 for (func::FuncOp func : module->getOps<func::FuncOp>()) {
81 ASSERT_TRUE(isa<BoolAttr>(func->getDiscardableAttr("isFunc")));
82 EXPECT_TRUE(cast<BoolAttr>(func->getDiscardableAttr("isFunc")).getValue());
84 bool isSecret = func.getName() == "secret";
85 ASSERT_TRUE(isa<BoolAttr>(func->getDiscardableAttr("isSecret")));
86 EXPECT_EQ(cast<BoolAttr>(func->getDiscardableAttr("isSecret")).getValue(),
87 isSecret);
91 /// Simple pass to annotate a func::FuncOp with a single attribute `didProcess`.
92 struct AddAttrFunctionPass
93 : public PassWrapper<AddAttrFunctionPass, OperationPass<func::FuncOp>> {
94 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddAttrFunctionPass)
96 void runOnOperation() override {
97 func::FuncOp op = getOperation();
98 Builder builder(op->getParentOfType<ModuleOp>());
99 if (op->hasAttr("didProcess"))
100 op->setAttr("didProcessAgain", builder.getUnitAttr());
102 // We always want to set this one.
103 op->setAttr("didProcess", builder.getUnitAttr());
107 /// Simple pass to annotate a func::FuncOp with a single attribute
108 /// `didProcess2`.
109 struct AddSecondAttrFunctionPass
110 : public PassWrapper<AddSecondAttrFunctionPass,
111 OperationPass<func::FuncOp>> {
112 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddSecondAttrFunctionPass)
114 void runOnOperation() override {
115 func::FuncOp op = getOperation();
116 Builder builder(op->getParentOfType<ModuleOp>());
117 op->setAttr("didProcess2", builder.getUnitAttr());
121 TEST(PassManagerTest, ExecutionAction) {
122 MLIRContext context;
123 context.loadDialect<func::FuncDialect>();
124 Builder builder(&context);
126 // Create a module with 2 functions.
127 OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
128 auto f =
129 func::FuncOp::create(builder.getUnknownLoc(), "process_me_once",
130 builder.getFunctionType(std::nullopt, std::nullopt));
131 f.setPrivate();
132 module->push_back(f);
134 // Instantiate our passes.
135 auto pm = PassManager::on<ModuleOp>(&context);
136 auto pass = std::make_unique<AddAttrFunctionPass>();
137 auto *passPtr = pass.get();
138 pm.addNestedPass<func::FuncOp>(std::move(pass));
139 pm.addNestedPass<func::FuncOp>(std::make_unique<AddSecondAttrFunctionPass>());
140 // Duplicate the first pass to ensure that we *only* run the *first* pass, not
141 // all instances of this pass kind. Notice that this pass (and the test as a
142 // whole) are built to ensure that we can run just a single pass out of a
143 // pipeline that may contain duplicates.
144 pm.addNestedPass<func::FuncOp>(std::make_unique<AddAttrFunctionPass>());
146 // Use the action manager to only hit the first pass, not the second one.
147 auto onBreakpoint = [&](const tracing::ActionActiveStack *backtrace)
148 -> tracing::ExecutionContext::Control {
149 // Not a PassExecutionAction, apply the action.
150 auto *passExec = dyn_cast<PassExecutionAction>(&backtrace->getAction());
151 if (!passExec)
152 return tracing::ExecutionContext::Next;
154 // If this isn't a function, apply the action.
155 if (!isa<func::FuncOp>(passExec->getOp()))
156 return tracing::ExecutionContext::Next;
158 // Only apply the first function pass. Not all instances of the first pass,
159 // only the first pass.
160 if (passExec->getPass().getThreadingSiblingOrThis() == passPtr)
161 return tracing::ExecutionContext::Next;
163 // Do not apply any other passes in the pass manager.
164 return tracing::ExecutionContext::Skip;
167 // Set up our breakpoint manager.
168 tracing::TagBreakpointManager simpleManager;
169 tracing::ExecutionContext executionCtx(onBreakpoint);
170 executionCtx.addBreakpointManager(&simpleManager);
171 simpleManager.addBreakpoint(PassExecutionAction::tag);
173 // Register the execution context in the MLIRContext.
174 context.registerActionHandler(executionCtx);
176 // Run the pass manager, expecting our handler to be called.
177 LogicalResult result = pm.run(module.get());
178 EXPECT_TRUE(succeeded(result));
180 // Verify that each function got annotated with `didProcess` and *not*
181 // `didProcess2`.
182 for (func::FuncOp func : module->getOps<func::FuncOp>()) {
183 ASSERT_TRUE(func->getDiscardableAttr("didProcess"));
184 ASSERT_FALSE(func->getDiscardableAttr("didProcess2"));
185 ASSERT_FALSE(func->getDiscardableAttr("didProcessAgain"));
189 namespace {
190 struct InvalidPass : Pass {
191 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InvalidPass)
193 InvalidPass() : Pass(TypeID::get<InvalidPass>(), StringRef("invalid_op")) {}
194 StringRef getName() const override { return "Invalid Pass"; }
195 void runOnOperation() override {}
196 bool canScheduleOn(RegisteredOperationName opName) const override {
197 return true;
200 /// A clone method to create a copy of this pass.
201 std::unique_ptr<Pass> clonePass() const override {
202 return std::make_unique<InvalidPass>(
203 *static_cast<const InvalidPass *>(this));
206 } // namespace
208 TEST(PassManagerTest, InvalidPass) {
209 MLIRContext context;
210 context.allowUnregisteredDialects();
212 // Create a module
213 OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
215 // Add a single "invalid_op" operation
216 OpBuilder builder(&module->getBodyRegion());
217 OperationState state(UnknownLoc::get(&context), "invalid_op");
218 builder.insert(Operation::create(state));
220 // Register a diagnostic handler to capture the diagnostic so that we can
221 // check it later.
222 std::unique_ptr<Diagnostic> diagnostic;
223 context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
224 diagnostic = std::make_unique<Diagnostic>(std::move(diag));
227 // Instantiate and run our pass.
228 auto pm = PassManager::on<ModuleOp>(&context);
229 pm.nest("invalid_op").addPass(std::make_unique<InvalidPass>());
230 LogicalResult result = pm.run(module.get());
231 EXPECT_TRUE(failed(result));
232 ASSERT_TRUE(diagnostic.get() != nullptr);
233 EXPECT_EQ(
234 diagnostic->str(),
235 "'invalid_op' op trying to schedule a pass on an unregistered operation");
237 // Check that clearing the pass manager effectively removed the pass.
238 pm.clear();
239 result = pm.run(module.get());
240 EXPECT_TRUE(succeeded(result));
242 // Check that adding the pass at the top-level triggers a fatal error.
243 ASSERT_DEATH(pm.addPass(std::make_unique<InvalidPass>()),
244 "Can't add pass 'Invalid Pass' restricted to 'invalid_op' on a "
245 "PassManager intended to run on 'builtin.module', did you "
246 "intend to nest?");
249 /// Simple pass to annotate a func::FuncOp with the results of analysis.
250 struct InitializeCheckingPass
251 : public PassWrapper<InitializeCheckingPass, OperationPass<ModuleOp>> {
252 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InitializeCheckingPass)
253 LogicalResult initialize(MLIRContext *ctx) final {
254 initialized = true;
255 return success();
257 bool initialized = false;
259 void runOnOperation() override {
260 if (!initialized) {
261 getOperation()->emitError() << "Pass isn't initialized!";
262 signalPassFailure();
267 TEST(PassManagerTest, PassInitialization) {
268 MLIRContext context;
269 context.allowUnregisteredDialects();
271 // Create a module
272 OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
274 // Instantiate and run our pass.
275 auto pm = PassManager::on<ModuleOp>(&context);
276 pm.addPass(std::make_unique<InitializeCheckingPass>());
277 EXPECT_TRUE(succeeded(pm.run(module.get())));
279 // Adding a second copy of the pass, we should also initialize it!
280 pm.addPass(std::make_unique<InitializeCheckingPass>());
281 EXPECT_TRUE(succeeded(pm.run(module.get())));
284 } // namespace