Run DCE after a LoopFlatten test to reduce spurious output [nfc]
[llvm-project.git] / llvm / unittests / CodeGen / PassManagerTest.cpp
blob4d2c8b7bdb5f455420f9d4c66bd070f89caa80c1
1 //===- llvm/unittest/CodeGen/PassManager.cpp - PassManager tests ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "llvm/Analysis/CGSCCPassManager.h"
10 #include "llvm/Analysis/LoopAnalysisManager.h"
11 #include "llvm/AsmParser/Parser.h"
12 #include "llvm/CodeGen/MachineFunction.h"
13 #include "llvm/CodeGen/MachineModuleInfo.h"
14 #include "llvm/CodeGen/MachinePassManager.h"
15 #include "llvm/IR/LLVMContext.h"
16 #include "llvm/IR/Module.h"
17 #include "llvm/MC/TargetRegistry.h"
18 #include "llvm/Passes/PassBuilder.h"
19 #include "llvm/Support/SourceMgr.h"
20 #include "llvm/Support/TargetSelect.h"
21 #include "llvm/Target/TargetMachine.h"
22 #include "llvm/TargetParser/Host.h"
23 #include "llvm/TargetParser/Triple.h"
24 #include "gtest/gtest.h"
26 using namespace llvm;
28 namespace {
30 class TestFunctionAnalysis : public AnalysisInfoMixin<TestFunctionAnalysis> {
31 public:
32 struct Result {
33 Result(int Count) : InstructionCount(Count) {}
34 int InstructionCount;
37 /// Run the analysis pass over the function and return a result.
38 Result run(Function &F, FunctionAnalysisManager &AM) {
39 int Count = 0;
40 for (Function::iterator BBI = F.begin(), BBE = F.end(); BBI != BBE; ++BBI)
41 for (BasicBlock::iterator II = BBI->begin(), IE = BBI->end(); II != IE;
42 ++II)
43 ++Count;
44 return Result(Count);
47 private:
48 friend AnalysisInfoMixin<TestFunctionAnalysis>;
49 static AnalysisKey Key;
52 AnalysisKey TestFunctionAnalysis::Key;
54 class TestMachineFunctionAnalysis
55 : public AnalysisInfoMixin<TestMachineFunctionAnalysis> {
56 public:
57 struct Result {
58 Result(int Count) : InstructionCount(Count) {}
59 int InstructionCount;
62 /// Run the analysis pass over the machine function and return a result.
63 Result run(MachineFunction &MF, MachineFunctionAnalysisManager::Base &AM) {
64 auto &MFAM = static_cast<MachineFunctionAnalysisManager &>(AM);
65 // Query function analysis result.
66 TestFunctionAnalysis::Result &FAR =
67 MFAM.getResult<TestFunctionAnalysis>(MF.getFunction());
68 // + 5
69 return FAR.InstructionCount;
72 private:
73 friend AnalysisInfoMixin<TestMachineFunctionAnalysis>;
74 static AnalysisKey Key;
77 AnalysisKey TestMachineFunctionAnalysis::Key;
/// Message of the error returned when doInitialization is forced to fail.
const std::string DoInitErrMsg("doInitialization failed");
/// Message of the error returned when doFinalization is forced to fail.
const std::string DoFinalErrMsg("doFinalization failed");
82 struct TestMachineFunctionPass : public PassInfoMixin<TestMachineFunctionPass> {
83 TestMachineFunctionPass(int &Count, std::vector<int> &BeforeInitialization,
84 std::vector<int> &BeforeFinalization,
85 std::vector<int> &MachineFunctionPassCount)
86 : Count(Count), BeforeInitialization(BeforeInitialization),
87 BeforeFinalization(BeforeFinalization),
88 MachineFunctionPassCount(MachineFunctionPassCount) {}
90 Error doInitialization(Module &M, MachineFunctionAnalysisManager &MFAM) {
91 // Force doInitialization fail by starting with big `Count`.
92 if (Count > 10000)
93 return make_error<StringError>(DoInitErrMsg, inconvertibleErrorCode());
95 // + 1
96 ++Count;
97 BeforeInitialization.push_back(Count);
98 return Error::success();
100 Error doFinalization(Module &M, MachineFunctionAnalysisManager &MFAM) {
101 // Force doFinalization fail by starting with big `Count`.
102 if (Count > 1000)
103 return make_error<StringError>(DoFinalErrMsg, inconvertibleErrorCode());
105 // + 1
106 ++Count;
107 BeforeFinalization.push_back(Count);
108 return Error::success();
111 PreservedAnalyses run(MachineFunction &MF,
112 MachineFunctionAnalysisManager &MFAM) {
113 // Query function analysis result.
114 TestFunctionAnalysis::Result &FAR =
115 MFAM.getResult<TestFunctionAnalysis>(MF.getFunction());
116 // 3 + 1 + 1 = 5
117 Count += FAR.InstructionCount;
119 // Query module analysis result.
120 MachineModuleInfo &MMI =
121 MFAM.getResult<MachineModuleAnalysis>(*MF.getFunction().getParent());
122 // 1 + 1 + 1 = 3
123 Count += (MMI.getModule() == MF.getFunction().getParent());
125 // Query machine function analysis result.
126 TestMachineFunctionAnalysis::Result &MFAR =
127 MFAM.getResult<TestMachineFunctionAnalysis>(MF);
128 // 3 + 1 + 1 = 5
129 Count += MFAR.InstructionCount;
131 MachineFunctionPassCount.push_back(Count);
133 return PreservedAnalyses::none();
136 int &Count;
137 std::vector<int> &BeforeInitialization;
138 std::vector<int> &BeforeFinalization;
139 std::vector<int> &MachineFunctionPassCount;
142 struct TestMachineModulePass : public PassInfoMixin<TestMachineModulePass> {
143 TestMachineModulePass(int &Count, std::vector<int> &MachineModulePassCount)
144 : Count(Count), MachineModulePassCount(MachineModulePassCount) {}
146 Error run(Module &M, MachineFunctionAnalysisManager &MFAM) {
147 MachineModuleInfo &MMI = MFAM.getResult<MachineModuleAnalysis>(M);
148 // + 1
149 Count += (MMI.getModule() == &M);
150 MachineModulePassCount.push_back(Count);
151 return Error::success();
154 PreservedAnalyses run(MachineFunction &MF,
155 MachineFunctionAnalysisManager &AM) {
156 llvm_unreachable(
157 "This should never be reached because this is machine module pass");
160 int &Count;
161 std::vector<int> &MachineModulePassCount;
164 std::unique_ptr<Module> parseIR(LLVMContext &Context, const char *IR) {
165 SMDiagnostic Err;
166 return parseAssemblyString(IR, Err, Context);
169 class PassManagerTest : public ::testing::Test {
170 protected:
171 LLVMContext Context;
172 std::unique_ptr<Module> M;
173 std::unique_ptr<TargetMachine> TM;
175 public:
176 PassManagerTest()
177 : M(parseIR(Context, "define void @f() {\n"
178 "entry:\n"
179 " call void @g()\n"
180 " call void @h()\n"
181 " ret void\n"
182 "}\n"
183 "define void @g() {\n"
184 " ret void\n"
185 "}\n"
186 "define void @h() {\n"
187 " ret void\n"
188 "}\n")) {
189 // MachineModuleAnalysis needs a TargetMachine instance.
190 llvm::InitializeAllTargets();
192 std::string TripleName = Triple::normalize(sys::getDefaultTargetTriple());
193 std::string Error;
194 const Target *TheTarget =
195 TargetRegistry::lookupTarget(TripleName, Error);
196 if (!TheTarget)
197 return;
199 TargetOptions Options;
200 TM.reset(TheTarget->createTargetMachine(TripleName, "", "", Options,
201 std::nullopt));
205 TEST_F(PassManagerTest, Basic) {
206 if (!TM)
207 GTEST_SKIP();
209 LLVMTargetMachine *LLVMTM = static_cast<LLVMTargetMachine *>(TM.get());
210 M->setDataLayout(TM->createDataLayout());
212 LoopAnalysisManager LAM;
213 FunctionAnalysisManager FAM;
214 CGSCCAnalysisManager CGAM;
215 ModuleAnalysisManager MAM;
216 PassBuilder PB(TM.get());
217 PB.registerModuleAnalyses(MAM);
218 PB.registerFunctionAnalyses(FAM);
219 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
221 FAM.registerPass([&] { return TestFunctionAnalysis(); });
222 FAM.registerPass([&] { return PassInstrumentationAnalysis(); });
223 MAM.registerPass([&] { return MachineModuleAnalysis(LLVMTM); });
224 MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
226 MachineFunctionAnalysisManager MFAM;
228 // Test move assignment.
229 MachineFunctionAnalysisManager NestedMFAM(FAM, MAM);
230 NestedMFAM.registerPass([&] { return PassInstrumentationAnalysis(); });
231 NestedMFAM.registerPass([&] { return TestMachineFunctionAnalysis(); });
232 MFAM = std::move(NestedMFAM);
235 int Count = 0;
236 std::vector<int> BeforeInitialization[2];
237 std::vector<int> BeforeFinalization[2];
238 std::vector<int> TestMachineFunctionCount[2];
239 std::vector<int> TestMachineModuleCount[2];
241 MachineFunctionPassManager MFPM;
243 // Test move assignment.
244 MachineFunctionPassManager NestedMFPM;
245 NestedMFPM.addPass(TestMachineModulePass(Count, TestMachineModuleCount[0]));
246 NestedMFPM.addPass(TestMachineFunctionPass(Count, BeforeInitialization[0],
247 BeforeFinalization[0],
248 TestMachineFunctionCount[0]));
249 NestedMFPM.addPass(TestMachineModulePass(Count, TestMachineModuleCount[1]));
250 NestedMFPM.addPass(TestMachineFunctionPass(Count, BeforeInitialization[1],
251 BeforeFinalization[1],
252 TestMachineFunctionCount[1]));
253 MFPM = std::move(NestedMFPM);
256 ASSERT_FALSE(errorToBool(MFPM.run(*M, MFAM)));
258 // Check first machine module pass
259 EXPECT_EQ(1u, TestMachineModuleCount[0].size());
260 EXPECT_EQ(3, TestMachineModuleCount[0][0]);
262 // Check first machine function pass
263 EXPECT_EQ(1u, BeforeInitialization[0].size());
264 EXPECT_EQ(1, BeforeInitialization[0][0]);
265 EXPECT_EQ(3u, TestMachineFunctionCount[0].size());
266 EXPECT_EQ(10, TestMachineFunctionCount[0][0]);
267 EXPECT_EQ(13, TestMachineFunctionCount[0][1]);
268 EXPECT_EQ(16, TestMachineFunctionCount[0][2]);
269 EXPECT_EQ(1u, BeforeFinalization[0].size());
270 EXPECT_EQ(31, BeforeFinalization[0][0]);
272 // Check second machine module pass
273 EXPECT_EQ(1u, TestMachineModuleCount[1].size());
274 EXPECT_EQ(17, TestMachineModuleCount[1][0]);
276 // Check second machine function pass
277 EXPECT_EQ(1u, BeforeInitialization[1].size());
278 EXPECT_EQ(2, BeforeInitialization[1][0]);
279 EXPECT_EQ(3u, TestMachineFunctionCount[1].size());
280 EXPECT_EQ(24, TestMachineFunctionCount[1][0]);
281 EXPECT_EQ(27, TestMachineFunctionCount[1][1]);
282 EXPECT_EQ(30, TestMachineFunctionCount[1][2]);
283 EXPECT_EQ(1u, BeforeFinalization[1].size());
284 EXPECT_EQ(32, BeforeFinalization[1][0]);
286 EXPECT_EQ(32, Count);
288 // doInitialization returns error
289 Count = 10000;
290 MFPM.addPass(TestMachineFunctionPass(Count, BeforeInitialization[1],
291 BeforeFinalization[1],
292 TestMachineFunctionCount[1]));
293 std::string Message;
294 llvm::handleAllErrors(MFPM.run(*M, MFAM), [&](llvm::StringError &Error) {
295 Message = Error.getMessage();
297 EXPECT_EQ(Message, DoInitErrMsg);
299 // doFinalization returns error
300 Count = 1000;
301 MFPM.addPass(TestMachineFunctionPass(Count, BeforeInitialization[1],
302 BeforeFinalization[1],
303 TestMachineFunctionCount[1]));
304 llvm::handleAllErrors(MFPM.run(*M, MFAM), [&](llvm::StringError &Error) {
305 Message = Error.getMessage();
307 EXPECT_EQ(Message, DoFinalErrMsg);
310 } // namespace