Fix test failures introduced by PR #113697 (#116941)
[llvm-project.git] / llvm / unittests / Analysis / CtxProfAnalysisTest.cpp
blob3fba07ddd0f248ad9e3c413e9cfe2801d2e39544
1 //===--- CtxProfAnalysisTest.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "llvm/Analysis/CtxProfAnalysis.h"
10 #include "llvm/Analysis/BlockFrequencyInfo.h"
11 #include "llvm/Analysis/BranchProbabilityInfo.h"
12 #include "llvm/Analysis/CGSCCPassManager.h"
13 #include "llvm/Analysis/LoopAnalysisManager.h"
14 #include "llvm/AsmParser/Parser.h"
15 #include "llvm/IR/Analysis.h"
16 #include "llvm/IR/Module.h"
17 #include "llvm/IR/PassInstrumentation.h"
18 #include "llvm/IR/PassManager.h"
19 #include "llvm/Passes/PassBuilder.h"
20 #include "llvm/Support/SourceMgr.h"
21 #include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
22 #include "gmock/gmock.h"
23 #include "gtest/gtest.h"
25 using namespace llvm;
27 namespace {
29 class CtxProfAnalysisTest : public testing::Test {
30 static constexpr auto *IR = R"IR(
31 declare void @bar()
33 define private void @foo(i32 %a, ptr %fct) #0 !guid !0 {
34 %t = icmp eq i32 %a, 0
35 br i1 %t, label %yes, label %no
36 yes:
37 call void %fct(i32 %a)
38 br label %exit
39 no:
40 call void @bar()
41 br label %exit
42 exit:
43 ret void
46 define void @an_entrypoint(i32 %a) {
47 %t = icmp eq i32 %a, 0
48 br i1 %t, label %yes, label %no
50 yes:
51 call void @foo(i32 1, ptr null)
52 ret void
53 no:
54 ret void
57 define void @another_entrypoint_no_callees(i32 %a) {
58 %t = icmp eq i32 %a, 0
59 br i1 %t, label %yes, label %no
61 yes:
62 ret void
63 no:
64 ret void
67 define void @inlineasm() {
68 call void asm "nop", ""()
69 ret void
72 attributes #0 = { noinline }
73 !0 = !{ i64 11872291593386833696 }
74 )IR";
76 protected:
77 LLVMContext C;
78 PassBuilder PB;
79 ModuleAnalysisManager MAM;
80 FunctionAnalysisManager FAM;
81 CGSCCAnalysisManager CGAM;
82 LoopAnalysisManager LAM;
83 std::unique_ptr<Module> M;
85 void SetUp() override {
86 SMDiagnostic Err;
87 M = parseAssemblyString(IR, Err, C);
88 ASSERT_TRUE(!!M);
91 public:
92 CtxProfAnalysisTest() {
93 PB.registerModuleAnalyses(MAM);
94 PB.registerCGSCCAnalyses(CGAM);
95 PB.registerFunctionAnalyses(FAM);
96 PB.registerLoopAnalyses(LAM);
97 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
101 TEST_F(CtxProfAnalysisTest, GetCallsiteIDTest) {
102 ModulePassManager MPM;
103 MPM.addPass(PGOInstrumentationGen(PGOInstrumentationType::CTXPROF));
104 EXPECT_FALSE(MPM.run(*M, MAM).areAllPreserved());
105 auto *F = M->getFunction("foo");
106 ASSERT_NE(F, nullptr);
107 std::vector<uint32_t> InsValues;
109 for (auto &BB : *F)
110 for (auto &I : BB)
111 if (auto *CB = dyn_cast<CallBase>(&I)) {
112 // Skip instrumentation inserted intrinsics.
113 if (CB->getCalledFunction() && CB->getCalledFunction()->isIntrinsic())
114 continue;
115 auto *Ins = CtxProfAnalysis::getCallsiteInstrumentation(*CB);
116 ASSERT_NE(Ins, nullptr);
117 InsValues.push_back(Ins->getIndex()->getZExtValue());
120 EXPECT_THAT(InsValues, testing::ElementsAre(0, 1));
123 TEST_F(CtxProfAnalysisTest, GetCallsiteIDInlineAsmTest) {
124 ModulePassManager MPM;
125 MPM.addPass(PGOInstrumentationGen(PGOInstrumentationType::CTXPROF));
126 EXPECT_FALSE(MPM.run(*M, MAM).areAllPreserved());
127 auto *F = M->getFunction("inlineasm");
128 ASSERT_NE(F, nullptr);
129 std::vector<const Instruction *> InsValues;
131 for (auto &BB : *F)
132 for (auto &I : BB)
133 if (auto *CB = dyn_cast<CallBase>(&I)) {
134 // Skip instrumentation inserted intrinsics.
135 if (CB->getCalledFunction() && CB->getCalledFunction()->isIntrinsic())
136 continue;
137 auto *Ins = CtxProfAnalysis::getCallsiteInstrumentation(*CB);
138 InsValues.push_back(Ins);
141 EXPECT_THAT(InsValues, testing::ElementsAre(nullptr));
144 TEST_F(CtxProfAnalysisTest, GetCallsiteIDNegativeTest) {
145 auto *F = M->getFunction("foo");
146 ASSERT_NE(F, nullptr);
147 CallBase *FirstCall = nullptr;
148 for (auto &BB : *F)
149 for (auto &I : BB)
150 if (auto *CB = dyn_cast<CallBase>(&I)) {
151 if (CB->isIndirectCall() || !CB->getCalledFunction()->isIntrinsic()) {
152 FirstCall = CB;
153 break;
156 ASSERT_NE(FirstCall, nullptr);
157 auto *IndIns = CtxProfAnalysis::getCallsiteInstrumentation(*FirstCall);
158 EXPECT_EQ(IndIns, nullptr);
161 TEST_F(CtxProfAnalysisTest, GetBBIDTest) {
162 ModulePassManager MPM;
163 MPM.addPass(PGOInstrumentationGen(PGOInstrumentationType::CTXPROF));
164 EXPECT_FALSE(MPM.run(*M, MAM).areAllPreserved());
165 auto *F = M->getFunction("foo");
166 ASSERT_NE(F, nullptr);
167 std::map<std::string, int> BBNameAndID;
169 for (auto &BB : *F) {
170 auto *Ins = CtxProfAnalysis::getBBInstrumentation(BB);
171 if (Ins)
172 BBNameAndID[BB.getName().str()] =
173 static_cast<int>(Ins->getIndex()->getZExtValue());
174 else
175 BBNameAndID[BB.getName().str()] = -1;
178 EXPECT_THAT(BBNameAndID,
179 testing::UnorderedElementsAre(
180 testing::Pair("", 0), testing::Pair("yes", 1),
181 testing::Pair("no", -1), testing::Pair("exit", -1)));
183 } // namespace