Bump version to 19.1.0-rc3
[llvm-project.git] / llvm / unittests / Transforms / Instrumentation / PGOInstrumentationTest.cpp
blob02c2df2a138b020f1b2f97704ef0914a6e31c2b8
1 //===- PGOInstrumentationTest.cpp - Instrumentation unit tests ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
10 #include "llvm/AsmParser/Parser.h"
11 #include "llvm/IR/Module.h"
12 #include "llvm/Passes/PassBuilder.h"
13 #include "llvm/ProfileData/InstrProf.h"
15 #include "gmock/gmock.h"
16 #include "gtest/gtest.h"
18 #include <tuple>
20 namespace {
22 using namespace llvm;
24 using testing::_;
25 using ::testing::DoDefault;
26 using ::testing::Invoke;
27 using ::testing::IsNull;
28 using ::testing::NotNull;
29 using ::testing::Ref;
30 using ::testing::Return;
31 using ::testing::Sequence;
32 using ::testing::Test;
33 using ::testing::TestParamInfo;
34 using ::testing::Values;
35 using ::testing::WithParamInterface;
37 template <typename Derived> class MockAnalysisHandleBase {
38 public:
39 class Analysis : public AnalysisInfoMixin<Analysis> {
40 public:
41 class Result {
42 public:
43 // Forward invalidation events to the mock handle.
44 bool invalidate(Module &M, const PreservedAnalyses &PA,
45 ModuleAnalysisManager::Invalidator &Inv) {
46 return Handle->invalidate(M, PA, Inv);
49 private:
50 explicit Result(Derived *Handle) : Handle(Handle) {}
52 friend MockAnalysisHandleBase;
53 Derived *Handle;
56 Result run(Module &M, ModuleAnalysisManager &AM) {
57 return Handle->run(M, AM);
60 private:
61 friend AnalysisInfoMixin<Analysis>;
62 friend MockAnalysisHandleBase;
63 static inline AnalysisKey Key;
65 Derived *Handle;
67 explicit Analysis(Derived *Handle) : Handle(Handle) {}
70 Analysis getAnalysis() { return Analysis(static_cast<Derived *>(this)); }
72 typename Analysis::Result getResult() {
73 return typename Analysis::Result(static_cast<Derived *>(this));
76 protected:
77 void setDefaults() {
78 ON_CALL(static_cast<Derived &>(*this), run(_, _))
79 .WillByDefault(Return(this->getResult()));
80 ON_CALL(static_cast<Derived &>(*this), invalidate(_, _, _))
81 .WillByDefault(Invoke([](Module &M, const PreservedAnalyses &PA,
82 ModuleAnalysisManager::Invalidator &) {
83 auto PAC = PA.template getChecker<Analysis>();
84 return !PAC.preserved() &&
85 !PAC.template preservedSet<AllAnalysesOn<Module>>();
86 }));
89 private:
90 friend Derived;
91 MockAnalysisHandleBase() = default;
94 class MockModuleAnalysisHandle
95 : public MockAnalysisHandleBase<MockModuleAnalysisHandle> {
96 public:
97 MockModuleAnalysisHandle() { setDefaults(); }
99 MOCK_METHOD(typename Analysis::Result, run,
100 (Module &, ModuleAnalysisManager &));
102 MOCK_METHOD(bool, invalidate,
103 (Module &, const PreservedAnalyses &,
104 ModuleAnalysisManager::Invalidator &));
107 struct PGOInstrumentationGenTest
108 : public Test,
109 WithParamInterface<std::tuple<StringRef, StringRef>> {
110 LLVMContext Ctx;
111 ModulePassManager MPM;
112 PassBuilder PB;
113 MockModuleAnalysisHandle MMAHandle;
114 LoopAnalysisManager LAM;
115 FunctionAnalysisManager FAM;
116 CGSCCAnalysisManager CGAM;
117 ModuleAnalysisManager MAM;
118 LLVMContext Context;
119 std::unique_ptr<Module> M;
121 PGOInstrumentationGenTest() {
122 MAM.registerPass([&] { return MMAHandle.getAnalysis(); });
123 PB.registerModuleAnalyses(MAM);
124 PB.registerCGSCCAnalyses(CGAM);
125 PB.registerFunctionAnalyses(FAM);
126 PB.registerLoopAnalyses(LAM);
127 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
128 MPM.addPass(
129 RequireAnalysisPass<MockModuleAnalysisHandle::Analysis, Module>());
130 MPM.addPass(PGOInstrumentationGen());
133 void parseAssembly(const StringRef IR) {
134 SMDiagnostic Error;
135 M = parseAssemblyString(IR, Error, Context);
136 std::string ErrMsg;
137 raw_string_ostream OS(ErrMsg);
138 Error.print("", OS);
140 // A failure here means that the test itself is buggy.
141 if (!M)
142 report_fatal_error(OS.str().c_str());
146 static constexpr StringRef CodeWithFuncDefs = R"(
147 define i32 @f(i32 %n) {
148 entry:
149 ret i32 0
150 })";
152 static constexpr StringRef CodeWithFuncDecls = R"(
153 declare i32 @f(i32);
156 static constexpr StringRef CodeWithGlobals = R"(
157 @foo.table = internal unnamed_addr constant [1 x ptr] [ptr @f]
158 declare i32 @f(i32);
161 INSTANTIATE_TEST_SUITE_P(
162 PGOInstrumetationGenTestSuite, PGOInstrumentationGenTest,
163 Values(std::make_tuple(CodeWithFuncDefs, "instrument_function_defs"),
164 std::make_tuple(CodeWithFuncDecls, "instrument_function_decls"),
165 std::make_tuple(CodeWithGlobals, "instrument_globals")),
166 [](const TestParamInfo<PGOInstrumentationGenTest::ParamType> &Info) {
167 return std::get<1>(Info.param).str();
170 TEST_P(PGOInstrumentationGenTest, Instrumented) {
171 const StringRef Code = std::get<0>(GetParam());
172 parseAssembly(Code);
174 ASSERT_THAT(M, NotNull());
176 Sequence PassSequence;
177 EXPECT_CALL(MMAHandle, run(Ref(*M), _))
178 .InSequence(PassSequence)
179 .WillOnce(DoDefault());
180 EXPECT_CALL(MMAHandle, invalidate(Ref(*M), _, _))
181 .InSequence(PassSequence)
182 .WillOnce(DoDefault());
184 MPM.run(*M, MAM);
186 const auto *IRInstrVar =
187 M->getNamedGlobal(INSTR_PROF_QUOTE(INSTR_PROF_RAW_VERSION_VAR));
188 EXPECT_THAT(IRInstrVar, NotNull());
189 EXPECT_FALSE(IRInstrVar->isDeclaration());
192 } // end anonymous namespace