[RISCV][NFC] precommit for D159399
[llvm-project.git] / mlir / unittests / Pass / PassManagerTest.cpp
blob9a30f64eaabc2930f94853b74a7a1db613b4a672
1 //===- PassManagerTest.cpp - PassManager unit tests -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Pass/PassManager.h"
10 #include "mlir/Dialect/Func/IR/FuncOps.h"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/IR/Diagnostics.h"
14 #include "mlir/Pass/Pass.h"
15 #include "gtest/gtest.h"
17 #include <memory>
19 using namespace mlir;
20 using namespace mlir::detail;
22 namespace {
23 /// Analysis that operates on any operation.
24 struct GenericAnalysis {
25 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GenericAnalysis)
27 GenericAnalysis(Operation *op) : isFunc(isa<func::FuncOp>(op)) {}
28 const bool isFunc;
31 /// Analysis that operates on a specific operation.
32 struct OpSpecificAnalysis {
33 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpSpecificAnalysis)
35 OpSpecificAnalysis(func::FuncOp op) : isSecret(op.getName() == "secret") {}
36 const bool isSecret;
39 /// Simple pass to annotate a func::FuncOp with the results of analysis.
40 struct AnnotateFunctionPass
41 : public PassWrapper<AnnotateFunctionPass, OperationPass<func::FuncOp>> {
42 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AnnotateFunctionPass)
44 void runOnOperation() override {
45 func::FuncOp op = getOperation();
46 Builder builder(op->getParentOfType<ModuleOp>());
48 auto &ga = getAnalysis<GenericAnalysis>();
49 auto &sa = getAnalysis<OpSpecificAnalysis>();
51 op->setAttr("isFunc", builder.getBoolAttr(ga.isFunc));
52 op->setAttr("isSecret", builder.getBoolAttr(sa.isSecret));
56 TEST(PassManagerTest, OpSpecificAnalysis) {
57 MLIRContext context;
58 context.loadDialect<func::FuncDialect>();
59 Builder builder(&context);
61 // Create a module with 2 functions.
62 OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
63 for (StringRef name : {"secret", "not_secret"}) {
64 auto func = func::FuncOp::create(
65 builder.getUnknownLoc(), name,
66 builder.getFunctionType(std::nullopt, std::nullopt));
67 func.setPrivate();
68 module->push_back(func);
71 // Instantiate and run our pass.
72 auto pm = PassManager::on<ModuleOp>(&context);
73 pm.addNestedPass<func::FuncOp>(std::make_unique<AnnotateFunctionPass>());
74 LogicalResult result = pm.run(module.get());
75 EXPECT_TRUE(succeeded(result));
77 // Verify that each function got annotated with expected attributes.
78 for (func::FuncOp func : module->getOps<func::FuncOp>()) {
79 ASSERT_TRUE(isa<BoolAttr>(func->getDiscardableAttr("isFunc")));
80 EXPECT_TRUE(cast<BoolAttr>(func->getDiscardableAttr("isFunc")).getValue());
82 bool isSecret = func.getName() == "secret";
83 ASSERT_TRUE(isa<BoolAttr>(func->getDiscardableAttr("isSecret")));
84 EXPECT_EQ(cast<BoolAttr>(func->getDiscardableAttr("isSecret")).getValue(),
85 isSecret);
89 namespace {
90 struct InvalidPass : Pass {
91 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InvalidPass)
93 InvalidPass() : Pass(TypeID::get<InvalidPass>(), StringRef("invalid_op")) {}
94 StringRef getName() const override { return "Invalid Pass"; }
95 void runOnOperation() override {}
96 bool canScheduleOn(RegisteredOperationName opName) const override {
97 return true;
100 /// A clone method to create a copy of this pass.
101 std::unique_ptr<Pass> clonePass() const override {
102 return std::make_unique<InvalidPass>(
103 *static_cast<const InvalidPass *>(this));
106 } // namespace
108 TEST(PassManagerTest, InvalidPass) {
109 MLIRContext context;
110 context.allowUnregisteredDialects();
112 // Create a module
113 OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
115 // Add a single "invalid_op" operation
116 OpBuilder builder(&module->getBodyRegion());
117 OperationState state(UnknownLoc::get(&context), "invalid_op");
118 builder.insert(Operation::create(state));
120 // Register a diagnostic handler to capture the diagnostic so that we can
121 // check it later.
122 std::unique_ptr<Diagnostic> diagnostic;
123 context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
124 diagnostic = std::make_unique<Diagnostic>(std::move(diag));
127 // Instantiate and run our pass.
128 auto pm = PassManager::on<ModuleOp>(&context);
129 pm.nest("invalid_op").addPass(std::make_unique<InvalidPass>());
130 LogicalResult result = pm.run(module.get());
131 EXPECT_TRUE(failed(result));
132 ASSERT_TRUE(diagnostic.get() != nullptr);
133 EXPECT_EQ(
134 diagnostic->str(),
135 "'invalid_op' op trying to schedule a pass on an unregistered operation");
137 // Check that clearing the pass manager effectively removed the pass.
138 pm.clear();
139 result = pm.run(module.get());
140 EXPECT_TRUE(succeeded(result));
142 // Check that adding the pass at the top-level triggers a fatal error.
143 ASSERT_DEATH(pm.addPass(std::make_unique<InvalidPass>()),
144 "Can't add pass 'Invalid Pass' restricted to 'invalid_op' on a "
145 "PassManager intended to run on 'builtin.module', did you "
146 "intend to nest?");
149 /// Simple pass to annotate a func::FuncOp with the results of analysis.
150 struct InitializeCheckingPass
151 : public PassWrapper<InitializeCheckingPass, OperationPass<ModuleOp>> {
152 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InitializeCheckingPass)
153 LogicalResult initialize(MLIRContext *ctx) final {
154 initialized = true;
155 return success();
157 bool initialized = false;
159 void runOnOperation() override {
160 if (!initialized) {
161 getOperation()->emitError() << "Pass isn't initialized!";
162 signalPassFailure();
167 TEST(PassManagerTest, PassInitialization) {
168 MLIRContext context;
169 context.allowUnregisteredDialects();
171 // Create a module
172 OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
174 // Instantiate and run our pass.
175 auto pm = PassManager::on<ModuleOp>(&context);
176 pm.addPass(std::make_unique<InitializeCheckingPass>());
177 EXPECT_TRUE(succeeded(pm.run(module.get())));
179 // Adding a second copy of the pass, we should also initialize it!
180 pm.addPass(std::make_unique<InitializeCheckingPass>());
181 EXPECT_TRUE(succeeded(pm.run(module.get())));
184 } // namespace