[clangd] Fix warnings
[llvm-project.git] / llvm / unittests / Transforms / Instrumentation / PGOInstrumentationTest.cpp
blob9ccb13934cbd3811ff72df46a391531442c5361a
1 //===- PGOInstrumentationTest.cpp - Instrumentation unit tests ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
10 #include "llvm/AsmParser/Parser.h"
11 #include "llvm/IR/Module.h"
12 #include "llvm/Passes/PassBuilder.h"
13 #include "llvm/ProfileData/InstrProf.h"
15 #include "gmock/gmock.h"
16 #include "gtest/gtest.h"
18 #include <tuple>
20 namespace {
22 using namespace llvm;
24 using testing::_;
25 using ::testing::DoDefault;
26 using ::testing::Invoke;
27 using ::testing::NotNull;
28 using ::testing::Ref;
29 using ::testing::Return;
30 using ::testing::Sequence;
31 using ::testing::Test;
32 using ::testing::TestParamInfo;
33 using ::testing::Values;
34 using ::testing::WithParamInterface;
36 template <typename Derived> class MockAnalysisHandleBase {
37 public:
38 class Analysis : public AnalysisInfoMixin<Analysis> {
39 public:
40 class Result {
41 public:
42 // Forward invalidation events to the mock handle.
43 bool invalidate(Module &M, const PreservedAnalyses &PA,
44 ModuleAnalysisManager::Invalidator &Inv) {
45 return Handle->invalidate(M, PA, Inv);
48 private:
49 explicit Result(Derived *Handle) : Handle(Handle) {}
51 friend MockAnalysisHandleBase;
52 Derived *Handle;
55 Result run(Module &M, ModuleAnalysisManager &AM) {
56 return Handle->run(M, AM);
59 private:
60 friend AnalysisInfoMixin<Analysis>;
61 friend MockAnalysisHandleBase;
62 static inline AnalysisKey Key;
64 Derived *Handle;
66 explicit Analysis(Derived *Handle) : Handle(Handle) {}
69 Analysis getAnalysis() { return Analysis(static_cast<Derived *>(this)); }
71 typename Analysis::Result getResult() {
72 return typename Analysis::Result(static_cast<Derived *>(this));
75 protected:
76 void setDefaults() {
77 ON_CALL(static_cast<Derived &>(*this), run(_, _))
78 .WillByDefault(Return(this->getResult()));
79 ON_CALL(static_cast<Derived &>(*this), invalidate(_, _, _))
80 .WillByDefault(Invoke([](Module &M, const PreservedAnalyses &PA,
81 ModuleAnalysisManager::Invalidator &) {
82 auto PAC = PA.template getChecker<Analysis>();
83 return !PAC.preserved() &&
84 !PAC.template preservedSet<AllAnalysesOn<Module>>();
85 }));
88 private:
89 friend Derived;
90 MockAnalysisHandleBase() = default;
93 class MockModuleAnalysisHandle
94 : public MockAnalysisHandleBase<MockModuleAnalysisHandle> {
95 public:
96 MockModuleAnalysisHandle() { setDefaults(); }
98 MOCK_METHOD(typename Analysis::Result, run,
99 (Module &, ModuleAnalysisManager &));
101 MOCK_METHOD(bool, invalidate,
102 (Module &, const PreservedAnalyses &,
103 ModuleAnalysisManager::Invalidator &));
106 struct PGOInstrumentationGenTest
107 : public Test,
108 WithParamInterface<std::tuple<StringRef, StringRef>> {
109 ModulePassManager MPM;
110 PassBuilder PB;
111 MockModuleAnalysisHandle MMAHandle;
112 LoopAnalysisManager LAM;
113 FunctionAnalysisManager FAM;
114 CGSCCAnalysisManager CGAM;
115 ModuleAnalysisManager MAM;
116 LLVMContext Context;
117 std::unique_ptr<Module> M;
119 PGOInstrumentationGenTest() {
120 MAM.registerPass([&] { return MMAHandle.getAnalysis(); });
121 PB.registerModuleAnalyses(MAM);
122 PB.registerCGSCCAnalyses(CGAM);
123 PB.registerFunctionAnalyses(FAM);
124 PB.registerLoopAnalyses(LAM);
125 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
126 MPM.addPass(
127 RequireAnalysisPass<MockModuleAnalysisHandle::Analysis, Module>());
128 MPM.addPass(PGOInstrumentationGen());
131 void parseAssembly(const StringRef IR) {
132 SMDiagnostic Error;
133 M = parseAssemblyString(IR, Error, Context);
134 std::string ErrMsg;
135 raw_string_ostream OS(ErrMsg);
136 Error.print("", OS);
138 // A failure here means that the test itself is buggy.
139 if (!M)
140 report_fatal_error(ErrMsg.c_str());
144 static constexpr StringRef CodeWithFuncDefs = R"(
145 define i32 @f(i32 %n) {
146 entry:
147 ret i32 0
148 })";
150 static constexpr StringRef CodeWithFuncDecls = R"(
151 declare i32 @f(i32);
154 static constexpr StringRef CodeWithGlobals = R"(
155 @foo.table = internal unnamed_addr constant [1 x ptr] [ptr @f]
156 declare i32 @f(i32);
159 INSTANTIATE_TEST_SUITE_P(
160 PGOInstrumetationGenTestSuite, PGOInstrumentationGenTest,
161 Values(std::make_tuple(CodeWithFuncDefs, "instrument_function_defs"),
162 std::make_tuple(CodeWithFuncDecls, "instrument_function_decls"),
163 std::make_tuple(CodeWithGlobals, "instrument_globals")),
164 [](const TestParamInfo<PGOInstrumentationGenTest::ParamType> &Info) {
165 return std::get<1>(Info.param).str();
168 TEST_P(PGOInstrumentationGenTest, Instrumented) {
169 const StringRef Code = std::get<0>(GetParam());
170 parseAssembly(Code);
172 ASSERT_THAT(M, NotNull());
174 Sequence PassSequence;
175 EXPECT_CALL(MMAHandle, run(Ref(*M), _))
176 .InSequence(PassSequence)
177 .WillOnce(DoDefault());
178 EXPECT_CALL(MMAHandle, invalidate(Ref(*M), _, _))
179 .InSequence(PassSequence)
180 .WillOnce(DoDefault());
182 MPM.run(*M, MAM);
184 const auto *IRInstrVar =
185 M->getNamedGlobal(INSTR_PROF_QUOTE(INSTR_PROF_RAW_VERSION_VAR));
186 EXPECT_THAT(IRInstrVar, NotNull());
187 EXPECT_FALSE(IRInstrVar->isDeclaration());
190 } // end anonymous namespace