[MemProf] Templatize CallStackRadixTreeBuilder (NFC) (#117014)
[llvm-project.git] / llvm / unittests / CodeGen / PassManagerTest.cpp
blob5e20c8db56fa8934fb6dd039c0ac619896fec2c2
//===- llvm/unittests/CodeGen/PassManagerTest.cpp - PassManager tests -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // Test that the various MachineFunction pass managers, adaptors, analyses, and
9 // analysis managers work.
10 //===----------------------------------------------------------------------===//
12 #include "llvm/IR/PassManager.h"
13 #include "llvm/Analysis/CGSCCPassManager.h"
14 #include "llvm/Analysis/LoopAnalysisManager.h"
15 #include "llvm/AsmParser/Parser.h"
16 #include "llvm/CodeGen/MachineFunction.h"
17 #include "llvm/CodeGen/MachineModuleInfo.h"
18 #include "llvm/CodeGen/MachinePassManager.h"
19 #include "llvm/IR/Analysis.h"
20 #include "llvm/IR/LLVMContext.h"
21 #include "llvm/IR/Module.h"
22 #include "llvm/MC/TargetRegistry.h"
23 #include "llvm/Passes/PassBuilder.h"
24 #include "llvm/Support/SourceMgr.h"
25 #include "llvm/Support/TargetSelect.h"
26 #include "llvm/Target/TargetMachine.h"
27 #include "llvm/TargetParser/Host.h"
28 #include "llvm/TargetParser/Triple.h"
29 #include "gtest/gtest.h"
31 using namespace llvm;
33 namespace {
35 class TestFunctionAnalysis : public AnalysisInfoMixin<TestFunctionAnalysis> {
36 public:
37 struct Result {
38 Result(int Count) : InstructionCount(Count) {}
39 int InstructionCount;
42 /// The number of instructions in the Function.
43 Result run(Function &F, FunctionAnalysisManager &AM) {
44 return Result(F.getInstructionCount());
47 private:
48 friend AnalysisInfoMixin<TestFunctionAnalysis>;
49 static AnalysisKey Key;
52 AnalysisKey TestFunctionAnalysis::Key;
54 class TestMachineFunctionAnalysis
55 : public AnalysisInfoMixin<TestMachineFunctionAnalysis> {
56 public:
57 struct Result {
58 Result(int Count) : InstructionCount(Count) {}
59 int InstructionCount;
62 Result run(MachineFunction &MF, MachineFunctionAnalysisManager &AM) {
63 FunctionAnalysisManager &FAM =
64 AM.getResult<FunctionAnalysisManagerMachineFunctionProxy>(MF)
65 .getManager();
66 TestFunctionAnalysis::Result &FAR =
67 FAM.getResult<TestFunctionAnalysis>(MF.getFunction());
68 return FAR.InstructionCount;
71 private:
72 friend AnalysisInfoMixin<TestMachineFunctionAnalysis>;
73 static AnalysisKey Key;
76 AnalysisKey TestMachineFunctionAnalysis::Key;
78 struct TestMachineFunctionPass : public PassInfoMixin<TestMachineFunctionPass> {
79 TestMachineFunctionPass(int &Count, std::vector<int> &Counts)
80 : Count(Count), Counts(Counts) {}
82 PreservedAnalyses run(MachineFunction &MF,
83 MachineFunctionAnalysisManager &MFAM) {
84 FunctionAnalysisManager &FAM =
85 MFAM.getResult<FunctionAnalysisManagerMachineFunctionProxy>(MF)
86 .getManager();
87 TestFunctionAnalysis::Result &FAR =
88 FAM.getResult<TestFunctionAnalysis>(MF.getFunction());
89 Count += FAR.InstructionCount;
91 TestMachineFunctionAnalysis::Result &MFAR =
92 MFAM.getResult<TestMachineFunctionAnalysis>(MF);
93 Count += MFAR.InstructionCount;
95 Counts.push_back(Count);
97 return PreservedAnalyses::none();
100 int &Count;
101 std::vector<int> &Counts;
104 struct TestMachineModulePass : public PassInfoMixin<TestMachineModulePass> {
105 TestMachineModulePass(int &Count, std::vector<int> &Counts)
106 : Count(Count), Counts(Counts) {}
108 PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM) {
109 MachineModuleInfo &MMI = MAM.getResult<MachineModuleAnalysis>(M).getMMI();
110 FunctionAnalysisManager &FAM =
111 MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
112 MachineFunctionAnalysisManager &MFAM =
113 MAM.getResult<MachineFunctionAnalysisManagerModuleProxy>(M)
114 .getManager();
115 for (Function &F : M) {
116 MachineFunction &MF = MMI.getOrCreateMachineFunction(F);
117 Count += FAM.getResult<TestFunctionAnalysis>(F).InstructionCount;
118 Count += MFAM.getResult<TestMachineFunctionAnalysis>(MF).InstructionCount;
120 Counts.push_back(Count);
121 return PreservedAnalyses::all();
124 int &Count;
125 std::vector<int> &Counts;
128 struct ReportWarningPass : public PassInfoMixin<ReportWarningPass> {
129 PreservedAnalyses run(MachineFunction &MF,
130 MachineFunctionAnalysisManager &MFAM) {
131 auto &Ctx = MF.getContext();
132 Ctx.reportWarning(SMLoc(), "Test warning message.");
133 return PreservedAnalyses::all();
137 std::unique_ptr<Module> parseIR(LLVMContext &Context, const char *IR) {
138 SMDiagnostic Err;
139 return parseAssemblyString(IR, Err, Context);
142 class PassManagerTest : public ::testing::Test {
143 protected:
144 LLVMContext Context;
145 std::unique_ptr<Module> M;
146 std::unique_ptr<TargetMachine> TM;
148 public:
149 PassManagerTest()
150 : M(parseIR(Context, "define void @f() {\n"
151 "entry:\n"
152 " call void @g()\n"
153 " call void @h()\n"
154 " ret void\n"
155 "}\n"
156 "define void @g() {\n"
157 " ret void\n"
158 "}\n"
159 "define void @h() {\n"
160 " ret void\n"
161 "}\n")) {
162 // MachineModuleAnalysis needs a TargetMachine instance.
163 llvm::InitializeAllTargets();
165 std::string TripleName = Triple::normalize(sys::getDefaultTargetTriple());
166 std::string Error;
167 const Target *TheTarget =
168 TargetRegistry::lookupTarget(TripleName, Error);
169 if (!TheTarget)
170 return;
172 TargetOptions Options;
173 TM.reset(TheTarget->createTargetMachine(TripleName, "", "", Options,
174 std::nullopt));
178 TEST_F(PassManagerTest, Basic) {
179 if (!TM)
180 GTEST_SKIP();
182 M->setDataLayout(TM->createDataLayout());
184 MachineModuleInfo MMI(TM.get());
186 MachineFunctionAnalysisManager MFAM;
187 LoopAnalysisManager LAM;
188 FunctionAnalysisManager FAM;
189 CGSCCAnalysisManager CGAM;
190 ModuleAnalysisManager MAM;
191 PassBuilder PB(TM.get());
192 PB.registerModuleAnalyses(MAM);
193 PB.registerCGSCCAnalyses(CGAM);
194 PB.registerFunctionAnalyses(FAM);
195 PB.registerLoopAnalyses(LAM);
196 PB.registerMachineFunctionAnalyses(MFAM);
197 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM, &MFAM);
199 FAM.registerPass([&] { return TestFunctionAnalysis(); });
200 MAM.registerPass([&] { return MachineModuleAnalysis(MMI); });
201 MFAM.registerPass([&] { return TestMachineFunctionAnalysis(); });
203 int Count = 0;
204 std::vector<int> Counts;
206 ModulePassManager MPM;
207 FunctionPassManager FPM;
208 MachineFunctionPassManager MFPM;
209 MPM.addPass(TestMachineModulePass(Count, Counts));
210 FPM.addPass(createFunctionToMachineFunctionPassAdaptor(
211 TestMachineFunctionPass(Count, Counts)));
212 MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
213 MPM.addPass(TestMachineModulePass(Count, Counts));
214 MFPM.addPass(TestMachineFunctionPass(Count, Counts));
215 FPM = FunctionPassManager();
216 FPM.addPass(createFunctionToMachineFunctionPassAdaptor(std::move(MFPM)));
217 MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
219 testing::internal::CaptureStderr();
220 MPM.run(*M, MAM);
221 std::string Output = testing::internal::GetCapturedStderr();
223 EXPECT_EQ((std::vector<int>{10, 16, 18, 20, 30, 36, 38, 40}), Counts);
224 EXPECT_EQ(40, Count);
227 TEST_F(PassManagerTest, DiagnosticHandler) {
228 if (!TM)
229 GTEST_SKIP();
231 M->setDataLayout(TM->createDataLayout());
233 MachineModuleInfo MMI(TM.get());
235 LoopAnalysisManager LAM;
236 MachineFunctionAnalysisManager MFAM;
237 FunctionAnalysisManager FAM;
238 CGSCCAnalysisManager CGAM;
239 ModuleAnalysisManager MAM;
240 PassBuilder PB(TM.get());
241 PB.registerModuleAnalyses(MAM);
242 PB.registerCGSCCAnalyses(CGAM);
243 PB.registerFunctionAnalyses(FAM);
244 PB.registerLoopAnalyses(LAM);
245 PB.registerMachineFunctionAnalyses(MFAM);
246 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM, &MFAM);
248 MAM.registerPass([&] { return MachineModuleAnalysis(MMI); });
250 ModulePassManager MPM;
251 FunctionPassManager FPM;
252 MachineFunctionPassManager MFPM;
253 MPM.addPass(RequireAnalysisPass<MachineModuleAnalysis, Module>());
254 MFPM.addPass(ReportWarningPass());
255 FPM.addPass(createFunctionToMachineFunctionPassAdaptor(std::move(MFPM)));
256 MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
257 testing::internal::CaptureStderr();
258 MPM.run(*M, MAM);
259 std::string Output = testing::internal::GetCapturedStderr();
261 EXPECT_TRUE(Output.find("warning: <unknown>:0: Test warning message.") !=
262 std::string::npos);
265 } // namespace