1 //===- PGOInstrumentationTest.cpp - Instrumentation unit tests ------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
10 #include "llvm/AsmParser/Parser.h"
11 #include "llvm/IR/Module.h"
12 #include "llvm/Passes/PassBuilder.h"
13 #include "llvm/ProfileData/InstrProf.h"
15 #include "gmock/gmock.h"
16 #include "gtest/gtest.h"
25 using ::testing::DoDefault
;
26 using ::testing::Invoke
;
27 using ::testing::NotNull
;
29 using ::testing::Return
;
30 using ::testing::Sequence
;
31 using ::testing::Test
;
32 using ::testing::TestParamInfo
;
33 using ::testing::Values
;
34 using ::testing::WithParamInterface
;
// CRTP base for a mock analysis handle. The nested Analysis type derives from
// AnalysisInfoMixin<Analysis>; both it and its Result hold a `Derived *` and
// forward run()/invalidate() to it, so Derived must provide
// run(Module &, ModuleAnalysisManager &) and
// invalidate(Module &, const PreservedAnalyses &,
//            ModuleAnalysisManager::Invalidator &).
36 template <typename Derived
> class MockAnalysisHandleBase
{
38 class Analysis
: public AnalysisInfoMixin
<Analysis
> {
42 // Forward invalidation events to the mock handle.
43 bool invalidate(Module
&M
, const PreservedAnalyses
&PA
,
44 ModuleAnalysisManager::Invalidator
&Inv
) {
45 return Handle
->invalidate(M
, PA
, Inv
);
// A Result is constructed from the handle pointer it keeps in Handle.
49 explicit Result(Derived
*Handle
) : Handle(Handle
) {}
51 friend MockAnalysisHandleBase
;
// Running the analysis delegates to the handle's run() and returns its value.
55 Result
run(Module
&M
, ModuleAnalysisManager
&AM
) {
56 return Handle
->run(M
, AM
);
60 friend AnalysisInfoMixin
<Analysis
>;
61 friend MockAnalysisHandleBase
;
62 static inline AnalysisKey Key
;
// An Analysis is constructed from the handle pointer it keeps in Handle.
66 explicit Analysis(Derived
*Handle
) : Handle(Handle
) {}
// Builds an Analysis bound to this object, viewed as Derived.
69 Analysis
getAnalysis() { return Analysis(static_cast<Derived
*>(this)); }
// Builds an Analysis::Result bound to this object, viewed as Derived.
71 typename
Analysis::Result
getResult() {
72 return typename
Analysis::Result(static_cast<Derived
*>(this));
// Default mock actions: run() returns getResult(); invalidate() returns true
// only when the PreservedAnalyses checker for Analysis reports that neither
// the analysis itself nor the AllAnalysesOn<Module> set is preserved.
77 ON_CALL(static_cast<Derived
&>(*this), run(_
, _
))
78 .WillByDefault(Return(this->getResult()));
79 ON_CALL(static_cast<Derived
&>(*this), invalidate(_
, _
, _
))
80 .WillByDefault(Invoke([](Module
&M
, const PreservedAnalyses
&PA
,
81 ModuleAnalysisManager::Invalidator
&) {
82 auto PAC
= PA
.template getChecker
<Analysis
>();
83 return !PAC
.preserved() &&
84 !PAC
.template preservedSet
<AllAnalysesOn
<Module
>>();
90 MockAnalysisHandleBase() = default;
// Concrete mock handle: derives from MockAnalysisHandleBase instantiated with
// itself, calls setDefaults() in its constructor, and declares the gMock
// methods run() and invalidate() that the base forwards to.
93 class MockModuleAnalysisHandle
94 : public MockAnalysisHandleBase
<MockModuleAnalysisHandle
> {
96 MockModuleAnalysisHandle() { setDefaults(); }
// Mocked run(): takes (Module &, ModuleAnalysisManager &) and returns
// Analysis::Result.
98 MOCK_METHOD(typename
Analysis::Result
, run
,
99 (Module
&, ModuleAnalysisManager
&));
// Mocked invalidate(): takes (Module &, const PreservedAnalyses &,
// ModuleAnalysisManager::Invalidator &) and returns bool.
101 MOCK_METHOD(bool, invalidate
,
102 (Module
&, const PreservedAnalyses
&,
103 ModuleAnalysisManager::Invalidator
&));
// Test fixture parameterized on a (StringRef, StringRef) tuple. Element 0 is
// read by the test body as the IR text to parse; element 1 is read by the
// suite's name generator. Holds the pass manager, the analysis managers, the
// mock analysis handle, and the parsed module M.
106 struct PGOInstrumentationGenTest
108 WithParamInterface
<std::tuple
<StringRef
, StringRef
>> {
109 ModulePassManager MPM
;
111 MockModuleAnalysisHandle MMAHandle
;
112 LoopAnalysisManager LAM
;
113 FunctionAnalysisManager FAM
;
114 CGSCCAnalysisManager CGAM
;
115 ModuleAnalysisManager MAM
;
117 std::unique_ptr
<Module
> M
;
// Registers the mock analysis with MAM, has PB register the module, CGSCC,
// function and loop analyses and cross-register the proxies, then adds
// PGOInstrumentationGen to MPM.
119 PGOInstrumentationGenTest() {
120 MAM
.registerPass([&] { return MMAHandle
.getAnalysis(); });
121 PB
.registerModuleAnalyses(MAM
);
122 PB
.registerCGSCCAnalyses(CGAM
);
123 PB
.registerFunctionAnalyses(FAM
);
124 PB
.registerLoopAnalyses(LAM
);
125 PB
.crossRegisterProxies(LAM
, FAM
, CGAM
, MAM
);
127 RequireAnalysisPass
<MockModuleAnalysisHandle::Analysis
, Module
>());
128 MPM
.addPass(PGOInstrumentationGen());
// Parses the textual IR into M with parseAssemblyString, using Error and
// Context. report_fatal_error(ErrMsg.c_str()) is the failure path.
131 void parseAssembly(const StringRef IR
) {
133 M
= parseAssemblyString(IR
, Error
, Context
);
135 raw_string_ostream
OS(ErrMsg
);
138 // A failure here means that the test itself is buggy.
140 report_fatal_error(ErrMsg
.c_str());
// Textual LLVM IR inputs for the parameterized suite. Each constant is paired
// with a test-name string in the INSTANTIATE_TEST_SUITE_P call and reaches the
// test body as tuple element 0.
144 static constexpr StringRef CodeWithFuncDefs
= R
"(
145 define i32 @f(i32 %n) {
150 static constexpr StringRef CodeWithFuncDecls
= R
"(
154 static constexpr StringRef CodeWithGlobals
= R
"(
155 @foo.table = internal unnamed_addr constant [1 x ptr] [ptr @f]
// Instantiates PGOInstrumentationGenTest once per (IR, name) tuple. The
// trailing lambda is the name generator: it returns tuple element 1 as a
// std::string.
// NOTE(review): the instantiation prefix is spelled
// "PGOInstrumetationGenTestSuite" (missing the 'n' in "Instrumentation"),
// which looks like a typo. Renaming it would change the generated test names,
// so confirm nothing filters on the current spelling first.
159 INSTANTIATE_TEST_SUITE_P(
160 PGOInstrumetationGenTestSuite
, PGOInstrumentationGenTest
,
161 Values(std::make_tuple(CodeWithFuncDefs
, "instrument_function_defs"),
162 std::make_tuple(CodeWithFuncDecls
, "instrument_function_decls"),
163 std::make_tuple(CodeWithGlobals
, "instrument_globals")),
164 [](const TestParamInfo
<PGOInstrumentationGenTest::ParamType
> &Info
) {
165 return std::get
<1>(Info
.param
).str();
// Takes the IR text from tuple element 0, requires M to be non-null, and sets
// ordered expectations on the mock analysis: run() on *M first, then
// invalidate() on *M, each once with its default action. Finally checks that
// M has a global named by INSTR_PROF_RAW_VERSION_VAR that is not a
// declaration.
168 TEST_P(PGOInstrumentationGenTest
, Instrumented
) {
169 const StringRef Code
= std::get
<0>(GetParam());
172 ASSERT_THAT(M
, NotNull());
// Both expectations share PassSequence, so run() must be observed before
// invalidate().
174 Sequence PassSequence
;
175 EXPECT_CALL(MMAHandle
, run(Ref(*M
), _
))
176 .InSequence(PassSequence
)
177 .WillOnce(DoDefault());
178 EXPECT_CALL(MMAHandle
, invalidate(Ref(*M
), _
, _
))
179 .InSequence(PassSequence
)
180 .WillOnce(DoDefault());
// NOTE(review): EXPECT_THAT does not abort the test on failure, so a null
// IRInstrVar would still be dereferenced by the isDeclaration() check below.
// ASSERT_THAT(IRInstrVar, NotNull()) would stop before the dereference.
184 const auto *IRInstrVar
=
185 M
->getNamedGlobal(INSTR_PROF_QUOTE(INSTR_PROF_RAW_VERSION_VAR
));
186 EXPECT_THAT(IRInstrVar
, NotNull());
187 EXPECT_FALSE(IRInstrVar
->isDeclaration());
190 } // end anonymous namespace