1 //===- InlineCostTest.cpp - test for InlineCost ---------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "llvm/Analysis/InlineCost.h"
10 #include "llvm/Analysis/AssumptionCache.h"
11 #include "llvm/Analysis/InlineModelFeatureMaps.h"
12 #include "llvm/Analysis/TargetTransformInfo.h"
13 #include "llvm/AsmParser/Parser.h"
14 #include "llvm/IR/InstIterator.h"
15 #include "llvm/IR/Instructions.h"
16 #include "llvm/IR/LLVMContext.h"
17 #include "llvm/IR/Module.h"
18 #include "llvm/IR/PassInstrumentation.h"
19 #include "llvm/Support/SourceMgr.h"
20 #include "gtest/gtest.h"
26 CallBase
*getCallInFunction(Function
*F
) {
27 for (auto &I
: instructions(F
)) {
28 if (auto *CB
= dyn_cast
<llvm::CallBase
>(&I
))
34 std::optional
<InlineCostFeatures
> getInliningCostFeaturesForCall(CallBase
&CB
) {
35 ModuleAnalysisManager MAM
;
36 FunctionAnalysisManager FAM
;
37 FAM
.registerPass([&] { return TargetIRAnalysis(); });
38 FAM
.registerPass([&] { return ModuleAnalysisManagerFunctionProxy(MAM
); });
39 FAM
.registerPass([&] { return AssumptionAnalysis(); });
40 MAM
.registerPass([&] { return FunctionAnalysisManagerModuleProxy(FAM
); });
42 MAM
.registerPass([&] { return PassInstrumentationAnalysis(); });
43 FAM
.registerPass([&] { return PassInstrumentationAnalysis(); });
45 ModulePassManager MPM
;
46 MPM
.run(*CB
.getModule(), MAM
);
48 auto GetAssumptionCache
= [&](Function
&F
) -> AssumptionCache
& {
49 return FAM
.getResult
<AssumptionAnalysis
>(F
);
51 auto &TIR
= FAM
.getResult
<TargetIRAnalysis
>(*CB
.getFunction());
53 return getInliningCostFeatures(CB
, TIR
, GetAssumptionCache
);
// Tests that the inline cost features of a call can be retrieved without an
// error.
57 TEST(InlineCostTest
, CostFeatures
) {
58 const auto *const IR
= R
"IR(
64 %2 = call i32 @f(i32 0)
71 std::unique_ptr
<Module
> M
= parseAssemblyString(IR
, Err
, C
);
74 auto *G
= M
->getFunction("g");
77 // find the call to f in g
78 CallBase
*CB
= getCallInFunction(G
);
81 const auto Features
= getInliningCostFeaturesForCall(*CB
);
83 // Check that the optional is not empty
84 ASSERT_TRUE(Features
);
// Tests the SROA savings and SROA losses features computed for two call sites.
88 TEST(InlineCostTest
, SROACost
) {
91 const auto *const IR
= R
"IR(
92 define void @f_savings(ptr %var) {
93 %load = load i32, ptr %var
94 %inc = add i32 %load, 1
95 store i32 %inc, ptr %var
99 define void @g_savings(i32) {
101 call void @f_savings(ptr %var)
105 define void @f_losses(ptr %var) {
106 %load = load i32, ptr %var
107 %inc = add i32 %load, 1
108 store i32 %inc, ptr %var
109 call void @prevent_sroa(ptr %var)
113 define void @g_losses(i32) {
115 call void @f_losses(ptr %var)
119 declare void @prevent_sroa(ptr)
124 std::unique_ptr
<Module
> M
= parseAssemblyString(IR
, Err
, C
);
127 const int DefaultInstCost
= 5;
128 const int DefaultAllocaCost
= 0;
130 const char *GName
[] = {"g_savings", "g_losses", nullptr};
131 const int Savings
[] = {2 * DefaultInstCost
+ DefaultAllocaCost
, 0};
132 const int Losses
[] = {0, 2 * DefaultInstCost
+ DefaultAllocaCost
};
134 for (unsigned i
= 0; GName
[i
]; ++i
) {
135 auto *G
= M
->getFunction(GName
[i
]);
138 // find the call to f in g
139 CallBase
*CB
= getCallInFunction(G
);
142 const auto Features
= getInliningCostFeaturesForCall(*CB
);
143 ASSERT_TRUE(Features
);
145 // Check the predicted SROA cost
146 auto GetFeature
= [&](InlineCostFeatureIndex I
) {
147 return (*Features
)[static_cast<size_t>(I
)];
149 ASSERT_EQ(GetFeature(InlineCostFeatureIndex::sroa_savings
), Savings
[i
]);
150 ASSERT_EQ(GetFeature(InlineCostFeatureIndex::sroa_losses
), Losses
[i
]);