Fix test failures introduced by PR #113697 (#116941)
[llvm-project.git] / llvm / unittests / Analysis / InlineCostTest.cpp
blob 78e2aee95f82acadb08d4f0033eacc314cce2c79
1 //===- InlineCostTest.cpp - test for InlineCost ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "llvm/Analysis/InlineCost.h"
10 #include "llvm/Analysis/AssumptionCache.h"
11 #include "llvm/Analysis/InlineModelFeatureMaps.h"
12 #include "llvm/Analysis/TargetTransformInfo.h"
13 #include "llvm/AsmParser/Parser.h"
14 #include "llvm/IR/InstIterator.h"
15 #include "llvm/IR/Instructions.h"
16 #include "llvm/IR/LLVMContext.h"
17 #include "llvm/IR/Module.h"
18 #include "llvm/IR/PassInstrumentation.h"
19 #include "llvm/Support/SourceMgr.h"
20 #include "gtest/gtest.h"
22 namespace {
24 using namespace llvm;
26 CallBase *getCallInFunction(Function *F) {
27 for (auto &I : instructions(F)) {
28 if (auto *CB = dyn_cast<llvm::CallBase>(&I))
29 return CB;
31 return nullptr;
34 std::optional<InlineCostFeatures> getInliningCostFeaturesForCall(CallBase &CB) {
35 ModuleAnalysisManager MAM;
36 FunctionAnalysisManager FAM;
37 FAM.registerPass([&] { return TargetIRAnalysis(); });
38 FAM.registerPass([&] { return ModuleAnalysisManagerFunctionProxy(MAM); });
39 FAM.registerPass([&] { return AssumptionAnalysis(); });
40 MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(FAM); });
42 MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
43 FAM.registerPass([&] { return PassInstrumentationAnalysis(); });
45 ModulePassManager MPM;
46 MPM.run(*CB.getModule(), MAM);
48 auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & {
49 return FAM.getResult<AssumptionAnalysis>(F);
51 auto &TIR = FAM.getResult<TargetIRAnalysis>(*CB.getFunction());
53 return getInliningCostFeatures(CB, TIR, GetAssumptionCache);
56 // Tests that we can retrieve the CostFeatures without an error
57 TEST(InlineCostTest, CostFeatures) {
58 const auto *const IR = R"IR(
59 define i32 @f(i32) {
60 ret i32 4
63 define i32 @g(i32) {
64 %2 = call i32 @f(i32 0)
65 ret i32 %2
67 )IR";
69 LLVMContext C;
70 SMDiagnostic Err;
71 std::unique_ptr<Module> M = parseAssemblyString(IR, Err, C);
72 ASSERT_TRUE(M);
74 auto *G = M->getFunction("g");
75 ASSERT_TRUE(G);
77 // find the call to f in g
78 CallBase *CB = getCallInFunction(G);
79 ASSERT_TRUE(CB);
81 const auto Features = getInliningCostFeaturesForCall(*CB);
83 // Check that the optional is not empty
84 ASSERT_TRUE(Features);
87 // Tests the calculated SROA cost
88 TEST(InlineCostTest, SROACost) {
89 using namespace llvm;
91 const auto *const IR = R"IR(
92 define void @f_savings(ptr %var) {
93 %load = load i32, ptr %var
94 %inc = add i32 %load, 1
95 store i32 %inc, ptr %var
96 ret void
99 define void @g_savings(i32) {
100 %var = alloca i32
101 call void @f_savings(ptr %var)
102 ret void
105 define void @f_losses(ptr %var) {
106 %load = load i32, ptr %var
107 %inc = add i32 %load, 1
108 store i32 %inc, ptr %var
109 call void @prevent_sroa(ptr %var)
110 ret void
113 define void @g_losses(i32) {
114 %var = alloca i32
115 call void @f_losses(ptr %var)
116 ret void
119 declare void @prevent_sroa(ptr)
120 )IR";
122 LLVMContext C;
123 SMDiagnostic Err;
124 std::unique_ptr<Module> M = parseAssemblyString(IR, Err, C);
125 ASSERT_TRUE(M);
127 const int DefaultInstCost = 5;
128 const int DefaultAllocaCost = 0;
130 const char *GName[] = {"g_savings", "g_losses", nullptr};
131 const int Savings[] = {2 * DefaultInstCost + DefaultAllocaCost, 0};
132 const int Losses[] = {0, 2 * DefaultInstCost + DefaultAllocaCost};
134 for (unsigned i = 0; GName[i]; ++i) {
135 auto *G = M->getFunction(GName[i]);
136 ASSERT_TRUE(G);
138 // find the call to f in g
139 CallBase *CB = getCallInFunction(G);
140 ASSERT_TRUE(CB);
142 const auto Features = getInliningCostFeaturesForCall(*CB);
143 ASSERT_TRUE(Features);
145 // Check the predicted SROA cost
146 auto GetFeature = [&](InlineCostFeatureIndex I) {
147 return (*Features)[static_cast<size_t>(I)];
149 ASSERT_EQ(GetFeature(InlineCostFeatureIndex::sroa_savings), Savings[i]);
150 ASSERT_EQ(GetFeature(InlineCostFeatureIndex::sroa_losses), Losses[i]);
154 } // namespace