//===- InlineCostTest.cpp - test for InlineCost ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/InlineCost.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/InlineModelFeatureMaps.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/SourceMgr.h"
#include "gtest/gtest.h"
#include <memory>
#include <optional>
25 CallBase
*getCallInFunction(Function
*F
) {
26 for (auto &I
: instructions(F
)) {
27 if (auto *CB
= dyn_cast
<llvm::CallBase
>(&I
))
33 std::optional
<InlineCostFeatures
> getInliningCostFeaturesForCall(CallBase
&CB
) {
34 ModuleAnalysisManager MAM
;
35 FunctionAnalysisManager FAM
;
36 FAM
.registerPass([&] { return TargetIRAnalysis(); });
37 FAM
.registerPass([&] { return ModuleAnalysisManagerFunctionProxy(MAM
); });
38 FAM
.registerPass([&] { return AssumptionAnalysis(); });
39 MAM
.registerPass([&] { return FunctionAnalysisManagerModuleProxy(FAM
); });
41 MAM
.registerPass([&] { return PassInstrumentationAnalysis(); });
42 FAM
.registerPass([&] { return PassInstrumentationAnalysis(); });
44 ModulePassManager MPM
;
45 MPM
.run(*CB
.getModule(), MAM
);
47 auto GetAssumptionCache
= [&](Function
&F
) -> AssumptionCache
& {
48 return FAM
.getResult
<AssumptionAnalysis
>(F
);
50 auto &TIR
= FAM
.getResult
<TargetIRAnalysis
>(*CB
.getFunction());
52 return getInliningCostFeatures(CB
, TIR
, GetAssumptionCache
);
55 // Tests that we can retrieve the CostFeatures without an error
56 TEST(InlineCostTest
, CostFeatures
) {
57 const auto *const IR
= R
"IR(
63 %2 = call i32 @f(i32 0)
70 std::unique_ptr
<Module
> M
= parseAssemblyString(IR
, Err
, C
);
73 auto *G
= M
->getFunction("g");
76 // find the call to f in g
77 CallBase
*CB
= getCallInFunction(G
);
80 const auto Features
= getInliningCostFeaturesForCall(*CB
);
82 // Check that the optional is not empty
83 ASSERT_TRUE(Features
);
86 // Tests the calculated SROA cost
87 TEST(InlineCostTest
, SROACost
) {
90 const auto *const IR
= R
"IR(
91 define void @f_savings(ptr %var) {
92 %load = load i32, ptr %var
93 %inc = add i32 %load, 1
94 store i32 %inc, ptr %var
98 define void @g_savings(i32) {
100 call void @f_savings(ptr %var)
104 define void @f_losses(ptr %var) {
105 %load = load i32, ptr %var
106 %inc = add i32 %load, 1
107 store i32 %inc, ptr %var
108 call void @prevent_sroa(ptr %var)
112 define void @g_losses(i32) {
114 call void @f_losses(ptr %var)
118 declare void @prevent_sroa(ptr)
123 std::unique_ptr
<Module
> M
= parseAssemblyString(IR
, Err
, C
);
126 const int DefaultInstCost
= 5;
127 const int DefaultAllocaCost
= 0;
129 const char *GName
[] = {"g_savings", "g_losses", nullptr};
130 const int Savings
[] = {2 * DefaultInstCost
+ DefaultAllocaCost
, 0};
131 const int Losses
[] = {0, 2 * DefaultInstCost
+ DefaultAllocaCost
};
133 for (unsigned i
= 0; GName
[i
]; ++i
) {
134 auto *G
= M
->getFunction(GName
[i
]);
137 // find the call to f in g
138 CallBase
*CB
= getCallInFunction(G
);
141 const auto Features
= getInliningCostFeaturesForCall(*CB
);
142 ASSERT_TRUE(Features
);
144 // Check the predicted SROA cost
145 auto GetFeature
= [&](InlineCostFeatureIndex I
) {
146 return (*Features
)[static_cast<size_t>(I
)];
148 ASSERT_EQ(GetFeature(InlineCostFeatureIndex::sroa_savings
), Savings
[i
]);
149 ASSERT_EQ(GetFeature(InlineCostFeatureIndex::sroa_losses
), Losses
[i
]);