//===- PassManagerTest.cpp - PassManager unit tests -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Pass/PassManager.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Pass/Pass.h"
#include "gtest/gtest.h"

#include <memory>
#include <optional>
#include <utility>

using namespace mlir;
using namespace mlir::detail;
23 /// Analysis that operates on any operation.
24 struct GenericAnalysis
{
25 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GenericAnalysis
)
27 GenericAnalysis(Operation
*op
) : isFunc(isa
<func::FuncOp
>(op
)) {}
31 /// Analysis that operates on a specific operation.
32 struct OpSpecificAnalysis
{
33 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpSpecificAnalysis
)
35 OpSpecificAnalysis(func::FuncOp op
) : isSecret(op
.getName() == "secret") {}
39 /// Simple pass to annotate a func::FuncOp with the results of analysis.
40 struct AnnotateFunctionPass
41 : public PassWrapper
<AnnotateFunctionPass
, OperationPass
<func::FuncOp
>> {
42 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AnnotateFunctionPass
)
44 void runOnOperation() override
{
45 func::FuncOp op
= getOperation();
46 Builder
builder(op
->getParentOfType
<ModuleOp
>());
48 auto &ga
= getAnalysis
<GenericAnalysis
>();
49 auto &sa
= getAnalysis
<OpSpecificAnalysis
>();
51 op
->setAttr("isFunc", builder
.getBoolAttr(ga
.isFunc
));
52 op
->setAttr("isSecret", builder
.getBoolAttr(sa
.isSecret
));
56 TEST(PassManagerTest
, OpSpecificAnalysis
) {
58 context
.loadDialect
<func::FuncDialect
>();
59 Builder
builder(&context
);
61 // Create a module with 2 functions.
62 OwningOpRef
<ModuleOp
> module(ModuleOp::create(UnknownLoc::get(&context
)));
63 for (StringRef name
: {"secret", "not_secret"}) {
64 auto func
= func::FuncOp::create(
65 builder
.getUnknownLoc(), name
,
66 builder
.getFunctionType(std::nullopt
, std::nullopt
));
68 module
->push_back(func
);
71 // Instantiate and run our pass.
72 auto pm
= PassManager::on
<ModuleOp
>(&context
);
73 pm
.addNestedPass
<func::FuncOp
>(std::make_unique
<AnnotateFunctionPass
>());
74 LogicalResult result
= pm
.run(module
.get());
75 EXPECT_TRUE(succeeded(result
));
77 // Verify that each function got annotated with expected attributes.
78 for (func::FuncOp func
: module
->getOps
<func::FuncOp
>()) {
79 ASSERT_TRUE(isa
<BoolAttr
>(func
->getDiscardableAttr("isFunc")));
80 EXPECT_TRUE(cast
<BoolAttr
>(func
->getDiscardableAttr("isFunc")).getValue());
82 bool isSecret
= func
.getName() == "secret";
83 ASSERT_TRUE(isa
<BoolAttr
>(func
->getDiscardableAttr("isSecret")));
84 EXPECT_EQ(cast
<BoolAttr
>(func
->getDiscardableAttr("isSecret")).getValue(),
90 struct InvalidPass
: Pass
{
91 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InvalidPass
)
93 InvalidPass() : Pass(TypeID::get
<InvalidPass
>(), StringRef("invalid_op")) {}
94 StringRef
getName() const override
{ return "Invalid Pass"; }
95 void runOnOperation() override
{}
96 bool canScheduleOn(RegisteredOperationName opName
) const override
{
100 /// A clone method to create a copy of this pass.
101 std::unique_ptr
<Pass
> clonePass() const override
{
102 return std::make_unique
<InvalidPass
>(
103 *static_cast<const InvalidPass
*>(this));
108 TEST(PassManagerTest
, InvalidPass
) {
110 context
.allowUnregisteredDialects();
113 OwningOpRef
<ModuleOp
> module(ModuleOp::create(UnknownLoc::get(&context
)));
115 // Add a single "invalid_op" operation
116 OpBuilder
builder(&module
->getBodyRegion());
117 OperationState
state(UnknownLoc::get(&context
), "invalid_op");
118 builder
.insert(Operation::create(state
));
120 // Register a diagnostic handler to capture the diagnostic so that we can
122 std::unique_ptr
<Diagnostic
> diagnostic
;
123 context
.getDiagEngine().registerHandler([&](Diagnostic
&diag
) {
124 diagnostic
= std::make_unique
<Diagnostic
>(std::move(diag
));
127 // Instantiate and run our pass.
128 auto pm
= PassManager::on
<ModuleOp
>(&context
);
129 pm
.nest("invalid_op").addPass(std::make_unique
<InvalidPass
>());
130 LogicalResult result
= pm
.run(module
.get());
131 EXPECT_TRUE(failed(result
));
132 ASSERT_TRUE(diagnostic
.get() != nullptr);
135 "'invalid_op' op trying to schedule a pass on an unregistered operation");
137 // Check that clearing the pass manager effectively removed the pass.
139 result
= pm
.run(module
.get());
140 EXPECT_TRUE(succeeded(result
));
142 // Check that adding the pass at the top-level triggers a fatal error.
143 ASSERT_DEATH(pm
.addPass(std::make_unique
<InvalidPass
>()),
144 "Can't add pass 'Invalid Pass' restricted to 'invalid_op' on a "
145 "PassManager intended to run on 'builtin.module', did you "
149 /// Simple pass to annotate a func::FuncOp with the results of analysis.
150 struct InitializeCheckingPass
151 : public PassWrapper
<InitializeCheckingPass
, OperationPass
<ModuleOp
>> {
152 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InitializeCheckingPass
)
153 LogicalResult
initialize(MLIRContext
*ctx
) final
{
157 bool initialized
= false;
159 void runOnOperation() override
{
161 getOperation()->emitError() << "Pass isn't initialized!";
167 TEST(PassManagerTest
, PassInitialization
) {
169 context
.allowUnregisteredDialects();
172 OwningOpRef
<ModuleOp
> module(ModuleOp::create(UnknownLoc::get(&context
)));
174 // Instantiate and run our pass.
175 auto pm
= PassManager::on
<ModuleOp
>(&context
);
176 pm
.addPass(std::make_unique
<InitializeCheckingPass
>());
177 EXPECT_TRUE(succeeded(pm
.run(module
.get())));
179 // Adding a second copy of the pass, we should also initialize it!
180 pm
.addPass(std::make_unique
<InitializeCheckingPass
>());
181 EXPECT_TRUE(succeeded(pm
.run(module
.get())));