1 //===- llvm/unittest/CodeGen/PassManager.cpp - PassManager tests ----------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "llvm/Analysis/CGSCCPassManager.h"
10 #include "llvm/Analysis/LoopAnalysisManager.h"
11 #include "llvm/AsmParser/Parser.h"
12 #include "llvm/CodeGen/MachineFunction.h"
13 #include "llvm/CodeGen/MachineModuleInfo.h"
14 #include "llvm/CodeGen/MachinePassManager.h"
15 #include "llvm/IR/LLVMContext.h"
16 #include "llvm/IR/Module.h"
17 #include "llvm/MC/TargetRegistry.h"
18 #include "llvm/Passes/PassBuilder.h"
19 #include "llvm/Support/SourceMgr.h"
20 #include "llvm/Support/TargetSelect.h"
21 #include "llvm/Target/TargetMachine.h"
22 #include "llvm/TargetParser/Host.h"
23 #include "llvm/TargetParser/Triple.h"
24 #include "gtest/gtest.h"
30 class TestFunctionAnalysis
: public AnalysisInfoMixin
<TestFunctionAnalysis
> {
33 Result(int Count
) : InstructionCount(Count
) {}
37 /// Run the analysis pass over the function and return a result.
38 Result
run(Function
&F
, FunctionAnalysisManager
&AM
) {
40 for (Function::iterator BBI
= F
.begin(), BBE
= F
.end(); BBI
!= BBE
; ++BBI
)
41 for (BasicBlock::iterator II
= BBI
->begin(), IE
= BBI
->end(); II
!= IE
;
48 friend AnalysisInfoMixin
<TestFunctionAnalysis
>;
49 static AnalysisKey Key
;
52 AnalysisKey
TestFunctionAnalysis::Key
;
54 class TestMachineFunctionAnalysis
55 : public AnalysisInfoMixin
<TestMachineFunctionAnalysis
> {
58 Result(int Count
) : InstructionCount(Count
) {}
62 /// Run the analysis pass over the machine function and return a result.
63 Result
run(MachineFunction
&MF
, MachineFunctionAnalysisManager::Base
&AM
) {
64 auto &MFAM
= static_cast<MachineFunctionAnalysisManager
&>(AM
);
65 // Query function analysis result.
66 TestFunctionAnalysis::Result
&FAR
=
67 MFAM
.getResult
<TestFunctionAnalysis
>(MF
.getFunction());
69 return FAR
.InstructionCount
;
73 friend AnalysisInfoMixin
<TestMachineFunctionAnalysis
>;
74 static AnalysisKey Key
;
77 AnalysisKey
TestMachineFunctionAnalysis::Key
;
79 const std::string DoInitErrMsg
= "doInitialization failed";
80 const std::string DoFinalErrMsg
= "doFinalization failed";
82 struct TestMachineFunctionPass
: public PassInfoMixin
<TestMachineFunctionPass
> {
83 TestMachineFunctionPass(int &Count
, std::vector
<int> &BeforeInitialization
,
84 std::vector
<int> &BeforeFinalization
,
85 std::vector
<int> &MachineFunctionPassCount
)
86 : Count(Count
), BeforeInitialization(BeforeInitialization
),
87 BeforeFinalization(BeforeFinalization
),
88 MachineFunctionPassCount(MachineFunctionPassCount
) {}
90 Error
doInitialization(Module
&M
, MachineFunctionAnalysisManager
&MFAM
) {
91 // Force doInitialization fail by starting with big `Count`.
93 return make_error
<StringError
>(DoInitErrMsg
, inconvertibleErrorCode());
97 BeforeInitialization
.push_back(Count
);
98 return Error::success();
100 Error
doFinalization(Module
&M
, MachineFunctionAnalysisManager
&MFAM
) {
101 // Force doFinalization fail by starting with big `Count`.
103 return make_error
<StringError
>(DoFinalErrMsg
, inconvertibleErrorCode());
107 BeforeFinalization
.push_back(Count
);
108 return Error::success();
111 PreservedAnalyses
run(MachineFunction
&MF
,
112 MachineFunctionAnalysisManager
&MFAM
) {
113 // Query function analysis result.
114 TestFunctionAnalysis::Result
&FAR
=
115 MFAM
.getResult
<TestFunctionAnalysis
>(MF
.getFunction());
117 Count
+= FAR
.InstructionCount
;
119 // Query module analysis result.
120 MachineModuleInfo
&MMI
=
121 MFAM
.getResult
<MachineModuleAnalysis
>(*MF
.getFunction().getParent());
123 Count
+= (MMI
.getModule() == MF
.getFunction().getParent());
125 // Query machine function analysis result.
126 TestMachineFunctionAnalysis::Result
&MFAR
=
127 MFAM
.getResult
<TestMachineFunctionAnalysis
>(MF
);
129 Count
+= MFAR
.InstructionCount
;
131 MachineFunctionPassCount
.push_back(Count
);
133 return PreservedAnalyses::none();
137 std::vector
<int> &BeforeInitialization
;
138 std::vector
<int> &BeforeFinalization
;
139 std::vector
<int> &MachineFunctionPassCount
;
142 struct TestMachineModulePass
: public PassInfoMixin
<TestMachineModulePass
> {
143 TestMachineModulePass(int &Count
, std::vector
<int> &MachineModulePassCount
)
144 : Count(Count
), MachineModulePassCount(MachineModulePassCount
) {}
146 Error
run(Module
&M
, MachineFunctionAnalysisManager
&MFAM
) {
147 MachineModuleInfo
&MMI
= MFAM
.getResult
<MachineModuleAnalysis
>(M
);
149 Count
+= (MMI
.getModule() == &M
);
150 MachineModulePassCount
.push_back(Count
);
151 return Error::success();
154 PreservedAnalyses
run(MachineFunction
&MF
,
155 MachineFunctionAnalysisManager
&AM
) {
157 "This should never be reached because this is machine module pass");
161 std::vector
<int> &MachineModulePassCount
;
164 std::unique_ptr
<Module
> parseIR(LLVMContext
&Context
, const char *IR
) {
166 return parseAssemblyString(IR
, Err
, Context
);
169 class PassManagerTest
: public ::testing::Test
{
172 std::unique_ptr
<Module
> M
;
173 std::unique_ptr
<TargetMachine
> TM
;
177 : M(parseIR(Context
, "define void @f() {\n"
183 "define void @g() {\n"
186 "define void @h() {\n"
189 // MachineModuleAnalysis needs a TargetMachine instance.
190 llvm::InitializeAllTargets();
192 std::string TripleName
= Triple::normalize(sys::getDefaultTargetTriple());
194 const Target
*TheTarget
=
195 TargetRegistry::lookupTarget(TripleName
, Error
);
199 TargetOptions Options
;
200 TM
.reset(TheTarget
->createTargetMachine(TripleName
, "", "", Options
,
205 TEST_F(PassManagerTest
, Basic
) {
209 LLVMTargetMachine
*LLVMTM
= static_cast<LLVMTargetMachine
*>(TM
.get());
210 M
->setDataLayout(TM
->createDataLayout());
212 LoopAnalysisManager LAM
;
213 FunctionAnalysisManager FAM
;
214 CGSCCAnalysisManager CGAM
;
215 ModuleAnalysisManager MAM
;
216 PassBuilder
PB(TM
.get());
217 PB
.registerModuleAnalyses(MAM
);
218 PB
.registerFunctionAnalyses(FAM
);
219 PB
.crossRegisterProxies(LAM
, FAM
, CGAM
, MAM
);
221 FAM
.registerPass([&] { return TestFunctionAnalysis(); });
222 FAM
.registerPass([&] { return PassInstrumentationAnalysis(); });
223 MAM
.registerPass([&] { return MachineModuleAnalysis(LLVMTM
); });
224 MAM
.registerPass([&] { return PassInstrumentationAnalysis(); });
226 MachineFunctionAnalysisManager MFAM
;
228 // Test move assignment.
229 MachineFunctionAnalysisManager
NestedMFAM(FAM
, MAM
);
230 NestedMFAM
.registerPass([&] { return PassInstrumentationAnalysis(); });
231 NestedMFAM
.registerPass([&] { return TestMachineFunctionAnalysis(); });
232 MFAM
= std::move(NestedMFAM
);
236 std::vector
<int> BeforeInitialization
[2];
237 std::vector
<int> BeforeFinalization
[2];
238 std::vector
<int> TestMachineFunctionCount
[2];
239 std::vector
<int> TestMachineModuleCount
[2];
241 MachineFunctionPassManager MFPM
;
243 // Test move assignment.
244 MachineFunctionPassManager NestedMFPM
;
245 NestedMFPM
.addPass(TestMachineModulePass(Count
, TestMachineModuleCount
[0]));
246 NestedMFPM
.addPass(TestMachineFunctionPass(Count
, BeforeInitialization
[0],
247 BeforeFinalization
[0],
248 TestMachineFunctionCount
[0]));
249 NestedMFPM
.addPass(TestMachineModulePass(Count
, TestMachineModuleCount
[1]));
250 NestedMFPM
.addPass(TestMachineFunctionPass(Count
, BeforeInitialization
[1],
251 BeforeFinalization
[1],
252 TestMachineFunctionCount
[1]));
253 MFPM
= std::move(NestedMFPM
);
256 ASSERT_FALSE(errorToBool(MFPM
.run(*M
, MFAM
)));
258 // Check first machine module pass
259 EXPECT_EQ(1u, TestMachineModuleCount
[0].size());
260 EXPECT_EQ(3, TestMachineModuleCount
[0][0]);
262 // Check first machine function pass
263 EXPECT_EQ(1u, BeforeInitialization
[0].size());
264 EXPECT_EQ(1, BeforeInitialization
[0][0]);
265 EXPECT_EQ(3u, TestMachineFunctionCount
[0].size());
266 EXPECT_EQ(10, TestMachineFunctionCount
[0][0]);
267 EXPECT_EQ(13, TestMachineFunctionCount
[0][1]);
268 EXPECT_EQ(16, TestMachineFunctionCount
[0][2]);
269 EXPECT_EQ(1u, BeforeFinalization
[0].size());
270 EXPECT_EQ(31, BeforeFinalization
[0][0]);
272 // Check second machine module pass
273 EXPECT_EQ(1u, TestMachineModuleCount
[1].size());
274 EXPECT_EQ(17, TestMachineModuleCount
[1][0]);
276 // Check second machine function pass
277 EXPECT_EQ(1u, BeforeInitialization
[1].size());
278 EXPECT_EQ(2, BeforeInitialization
[1][0]);
279 EXPECT_EQ(3u, TestMachineFunctionCount
[1].size());
280 EXPECT_EQ(24, TestMachineFunctionCount
[1][0]);
281 EXPECT_EQ(27, TestMachineFunctionCount
[1][1]);
282 EXPECT_EQ(30, TestMachineFunctionCount
[1][2]);
283 EXPECT_EQ(1u, BeforeFinalization
[1].size());
284 EXPECT_EQ(32, BeforeFinalization
[1][0]);
286 EXPECT_EQ(32, Count
);
288 // doInitialization returns error
290 MFPM
.addPass(TestMachineFunctionPass(Count
, BeforeInitialization
[1],
291 BeforeFinalization
[1],
292 TestMachineFunctionCount
[1]));
294 llvm::handleAllErrors(MFPM
.run(*M
, MFAM
), [&](llvm::StringError
&Error
) {
295 Message
= Error
.getMessage();
297 EXPECT_EQ(Message
, DoInitErrMsg
);
299 // doFinalization returns error
301 MFPM
.addPass(TestMachineFunctionPass(Count
, BeforeInitialization
[1],
302 BeforeFinalization
[1],
303 TestMachineFunctionCount
[1]));
304 llvm::handleAllErrors(MFPM
.run(*M
, MFAM
), [&](llvm::StringError
&Error
) {
305 Message
= Error
.getMessage();
307 EXPECT_EQ(Message
, DoFinalErrMsg
);