1 //===- PGOInstrumentationTest.cpp - Instrumentation unit tests ------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
10 #include "llvm/AsmParser/Parser.h"
11 #include "llvm/IR/Module.h"
12 #include "llvm/Passes/PassBuilder.h"
13 #include "llvm/ProfileData/InstrProf.h"
15 #include "gmock/gmock.h"
16 #include "gtest/gtest.h"
25 using ::testing::DoDefault
;
26 using ::testing::Invoke
;
27 using ::testing::IsNull
;
28 using ::testing::NotNull
;
30 using ::testing::Return
;
31 using ::testing::Sequence
;
32 using ::testing::Test
;
33 using ::testing::TestParamInfo
;
34 using ::testing::Values
;
35 using ::testing::WithParamInterface
;
// CRTP base for a module analysis whose behavior is driven by a gMock object.
// `Derived` is the concrete mock class. The nested Analysis and Result objects
// are constructed with a `Derived *` back-pointer and forward run() and
// invalidate() to it. This lets a test place expectations on when the pass
// manager computes and invalidates this analysis.
// NOTE(review): several structural lines of this class are not visible in this
// view (access specifiers, the Result class header, the `Handle` data members,
// closing braces, the setDefaults() header). The comments below describe only
// the visible code.
37 template <typename Derived
> class MockAnalysisHandleBase
{
// The analysis type registered with the ModuleAnalysisManager. AnalysisInfoMixin
// supplies the analysis ID machinery (see the `Key` member below).
39 class Analysis
: public AnalysisInfoMixin
<Analysis
> {
43 // Forward invalidation events to the mock handle.
44 bool invalidate(Module
&M
, const PreservedAnalyses
&PA
,
45 ModuleAnalysisManager::Invalidator
&Inv
) {
46 return Handle
->invalidate(M
, PA
, Inv
);
// Result constructor: binds the result to the mock handle. getResult() below
// is the visible place that constructs it.
50 explicit Result(Derived
*Handle
) : Handle(Handle
) {}
52 friend MockAnalysisHandleBase
;
// Computing the analysis is delegated to the mock's run(). The mock's return
// value becomes the cached result.
56 Result
run(Module
&M
, ModuleAnalysisManager
&AM
) {
57 return Handle
->run(M
, AM
);
// AnalysisInfoMixin reads `Key` to obtain this analysis's unique ID. The
// friend declarations give the mixin and the base class access to these
// members.
61 friend AnalysisInfoMixin
<Analysis
>;
62 friend MockAnalysisHandleBase
;
63 static inline AnalysisKey Key
;
// Analysis constructor: binds the analysis object to the mock handle.
67 explicit Analysis(Derived
*Handle
) : Handle(Handle
) {}
// Returns an Analysis object that forwards to this handle. It is intended to
// be passed to ModuleAnalysisManager::registerPass.
70 Analysis
getAnalysis() { return Analysis(static_cast<Derived
*>(this)); }
// Returns a Result object that forwards invalidation queries to this handle.
72 typename
Analysis::Result
getResult() {
73 return typename
Analysis::Result(static_cast<Derived
*>(this));
// Default action for run(): return a Result bound to this handle. The
// Return() action captures the getResult() value at the time ON_CALL runs.
78 ON_CALL(static_cast<Derived
&>(*this), run(_
, _
))
79 .WillByDefault(Return(this->getResult()));
// Default action for invalidate(): report the result as invalidated (true)
// unless PA preserves this specific analysis or all analyses on Module.
80 ON_CALL(static_cast<Derived
&>(*this), invalidate(_
, _
, _
))
81 .WillByDefault(Invoke([](Module
&M
, const PreservedAnalyses
&PA
,
82 ModuleAnalysisManager::Invalidator
&) {
83 auto PAC
= PA
.template getChecker
<Analysis
>();
84 return !PAC
.preserved() &&
85 !PAC
.template preservedSet
<AllAnalysesOn
<Module
>>();
// Default constructor. Its access level is not visible in this view.
// NOTE(review): it is presumably restricted to the CRTP derived class;
// confirm.
91 MockAnalysisHandleBase() = default;
// Concrete gMock handle for the mock module analysis. It supplies the mocked
// run()/invalidate() that the base class's Analysis and Result objects forward
// to.
94 class MockModuleAnalysisHandle
95 : public MockAnalysisHandleBase
<MockModuleAnalysisHandle
> {
// The constructor calls the inherited setDefaults(), so default actions are
// installed before any test adds EXPECT_CALLs.
97 MockModuleAnalysisHandle() { setDefaults(); }
// Mocked analysis computation. The signature matches Analysis::run.
99 MOCK_METHOD(typename
Analysis::Result
, run
,
100 (Module
&, ModuleAnalysisManager
&));
// Mocked invalidation query. The signature matches Result::invalidate.
// Returning true means the cached result is invalidated.
102 MOCK_METHOD(bool, invalidate
,
103 (Module
&, const PreservedAnalyses
&,
104 ModuleAnalysisManager::Invalidator
&));
// Parameterized fixture for PGOInstrumentationGen. The parameter tuple is
// (IR source text, test-name suffix). TEST_P reads element 0 as the IR, and
// the INSTANTIATE_TEST_SUITE_P name generator reads element 1.
// NOTE(review): the base-class line (presumably `: public Test,`) and the
// declarations of `Context` and `PB` are not visible in this view; confirm.
107 struct PGOInstrumentationGenTest
109 WithParamInterface
<std::tuple
<StringRef
, StringRef
>> {
// Pipeline under test. It is populated in the constructor.
111 ModulePassManager MPM
;
// Declared before MAM, so it is destroyed after MAM. The analysis manager may
// still hold Analysis/Result objects that point back into this handle.
113 MockModuleAnalysisHandle MMAHandle
;
114 LoopAnalysisManager LAM
;
115 FunctionAnalysisManager FAM
;
116 CGSCCAnalysisManager CGAM
;
117 ModuleAnalysisManager MAM
;
// Module parsed from the test parameter by parseAssembly().
119 std::unique_ptr
<Module
> M
;
// The constructor performs these steps in order:
// 1. Registers the mock analysis plus the standard analyses.
// 2. Cross-registers the analysis-manager proxies.
// 3. Builds the pipeline: require the mock analysis, then run
//    PGOInstrumentationGen.
// As a result, the mock's run() is called before the instrumentation pass's
// PreservedAnalyses are checked through invalidate().
121 PGOInstrumentationGenTest() {
122 MAM
.registerPass([&] { return MMAHandle
.getAnalysis(); });
123 PB
.registerModuleAnalyses(MAM
);
124 PB
.registerCGSCCAnalyses(CGAM
);
125 PB
.registerFunctionAnalyses(FAM
);
126 PB
.registerLoopAnalyses(LAM
);
127 PB
.crossRegisterProxies(LAM
, FAM
, CGAM
, MAM
);
// NOTE(review): the opening `MPM.addPass(` for the RequireAnalysisPass below
// is not visible in this view; confirm it is present.
129 RequireAnalysisPass
<MockModuleAnalysisHandle::Analysis
, Module
>());
130 MPM
.addPass(PGOInstrumentationGen());
// Parses textual IR into M using this fixture's Context. On a parse failure
// (presumably detected as a null M, but the check is not visible here), it
// renders the diagnostic into ErrMsg and aborts via report_fatal_error.
133 void parseAssembly(const StringRef IR
) {
135 M
= parseAssemblyString(IR
, Error
, Context
);
137 raw_string_ostream
OS(ErrMsg
);
140 // A failure here means that the test itself is buggy.
142 report_fatal_error(OS
.str().c_str());
// IR inputs for the parameterized test:
// - CodeWithFuncDefs: a module with a function definition (@f).
// - CodeWithFuncDecls: presumably a module with only function declarations
//   (its content is not visible here).
// - CodeWithGlobals: a module with an internal global (@foo.table) that
//   references @f.
// The Instrumented test expects the raw profile version variable to be defined
// after instrumentation for every one of these inputs.
// NOTE(review): the bodies and `)"` terminators of these raw string literals
// are not visible in this view.
146 static constexpr StringRef CodeWithFuncDefs
= R
"(
147 define i32 @f(i32 %n) {
152 static constexpr StringRef CodeWithFuncDecls
= R
"(
156 static constexpr StringRef CodeWithGlobals
= R
"(
157 @foo.table = internal unnamed_addr constant [1 x ptr] [ptr @f]
// Instantiates every TEST_P of PGOInstrumentationGenTest once per IR snippet.
// The second tuple element becomes the per-instance test-name suffix.
// NOTE(review): the suite prefix "PGOInstrumetationGenTestSuite" misspells
// "Instrumentation". Renaming it changes the reported test names (and any
// --gtest_filter patterns matching them), so it is left unchanged here.
161 INSTANTIATE_TEST_SUITE_P(
162 PGOInstrumetationGenTestSuite
, PGOInstrumentationGenTest
,
163 Values(std::make_tuple(CodeWithFuncDefs
, "instrument_function_defs"),
164 std::make_tuple(CodeWithFuncDecls
, "instrument_function_decls"),
165 std::make_tuple(CodeWithGlobals
, "instrument_globals")),
// Name generator. It returns the human-readable suffix from the parameter
// tuple. GoogleTest requires these names to be unique and to contain only
// alphanumerics/underscores, which the suffixes above satisfy.
166 [](const TestParamInfo
<PGOInstrumentationGenTest::ParamType
> &Info
) {
167 return std::get
<1>(Info
.param
).str();
// Checks two things for each IR snippet:
// - The fixture's pipeline computes the mock module analysis and then queries
//   it for invalidation, in that order.
// - PGOInstrumentationGen leaves the module with a defined (non-declaration)
//   raw profile version global.
// NOTE(review): `Code` is not consumed and no pass-manager run is visible in
// this view. Confirm that parseAssembly(Code) and MPM.run(*M, MAM) are present
// between the statements below. Otherwise M stays null and the EXPECT_CALLs
// are never satisfied.
170 TEST_P(PGOInstrumentationGenTest
, Instrumented
) {
171 const StringRef Code
= std::get
<0>(GetParam());
174 ASSERT_THAT(M
, NotNull());
// run() (triggered by RequireAnalysisPass) must happen before invalidate()
// (triggered when the pass manager applies the instrumentation pass's
// PreservedAnalyses). Both calls must be made on this module.
176 Sequence PassSequence
;
177 EXPECT_CALL(MMAHandle
, run(Ref(*M
), _
))
178 .InSequence(PassSequence
)
179 .WillOnce(DoDefault());
180 EXPECT_CALL(MMAHandle
, invalidate(Ref(*M
), _
, _
))
181 .InSequence(PassSequence
)
182 .WillOnce(DoDefault());
186 const auto *IRInstrVar
=
187 M
->getNamedGlobal(INSTR_PROF_QUOTE(INSTR_PROF_RAW_VERSION_VAR
));
// This check must be fatal because the next statement dereferences
// IRInstrVar. A non-fatal EXPECT_THAT would fall through to a null-pointer
// dereference and crash the test binary instead of reporting a clean failure.
188 ASSERT_THAT(IRInstrVar
, NotNull());
189 EXPECT_FALSE(IRInstrVar
->isDeclaration());
192 } // end anonymous namespace