1 //===- FunctionSpecializationTest.cpp - Cost model unit tests -------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "llvm/Analysis/AssumptionCache.h"
10 #include "llvm/Analysis/BlockFrequencyInfo.h"
11 #include "llvm/Analysis/BranchProbabilityInfo.h"
12 #include "llvm/Analysis/LoopInfo.h"
13 #include "llvm/Analysis/PostDominators.h"
14 #include "llvm/Analysis/TargetLibraryInfo.h"
15 #include "llvm/Analysis/TargetTransformInfo.h"
16 #include "llvm/AsmParser/Parser.h"
17 #include "llvm/IR/Constants.h"
18 #include "llvm/Support/SourceMgr.h"
19 #include "llvm/Transforms/IPO/FunctionSpecialization.h"
20 #include "llvm/Transforms/Utils/SCCPSolver.h"
21 #include "gtest/gtest.h"
// Helper that walks every instruction of every basic block in F and deals
// with calls to the ssa_copy intrinsic: the uses of such a call are redirected
// to the call's first operand and the call itself is erased.
//
// The inner loop iterates through llvm::make_early_inc_range because the
// current instruction may be erased while iterating.
//
// NOTE(review): the statement controlled by the intrinsic-ID comparison is not
// visible in this view; presumably intrinsics other than ssa_copy are skipped
// - confirm.
26 static void removeSSACopy(Function
&F
) {
27 for (BasicBlock
&BB
: F
) {
28 for (Instruction
&Inst
: llvm::make_early_inc_range(BB
)) {
29 if (auto *II
= dyn_cast
<IntrinsicInst
>(&Inst
)) {
30 if (II
->getIntrinsicID() != Intrinsic::ssa_copy
)
// Forward operand 0 of the intrinsic call to all users before deleting it.
32 Inst
.replaceAllUsesWith(II
->getOperand(0));
33 Inst
.eraseFromParent();
// gtest fixture shared by the FunctionSpecialization cost-model tests below.
// It owns the function analysis manager, the parsed module and the SCCP solver
// that back the FunctionSpecializer handed to each test.
39 class FunctionSpecializationTest
: public testing::Test
{
// Analysis manager queried by the getter lambdas and by getInstCost().
42 FunctionAnalysisManager FAM
;
// Module produced by parseModule().
43 std::unique_ptr
<Module
> M
;
// Solver created by getSpecializerFor().
44 std::unique_ptr
<SCCPSolver
> Solver
;
// Registers with FAM every function analysis listed below: target library
// info, target IR (TTI), block frequency, branch probability, loop info,
// assumption cache, dominator tree, post-dominator tree and pass
// instrumentation.
46 FunctionSpecializationTest() {
47 FAM
.registerPass([&] { return TargetLibraryAnalysis(); });
48 FAM
.registerPass([&] { return TargetIRAnalysis(); });
49 FAM
.registerPass([&] { return BlockFrequencyAnalysis(); });
50 FAM
.registerPass([&] { return BranchProbabilityAnalysis(); });
51 FAM
.registerPass([&] { return LoopAnalysis(); });
52 FAM
.registerPass([&] { return AssumptionAnalysis(); });
53 FAM
.registerPass([&] { return DominatorTreeAnalysis(); });
54 FAM
.registerPass([&] { return PostDominatorTreeAnalysis(); });
55 FAM
.registerPass([&] { return PassInstrumentationAnalysis(); });
// Parses the textual IR in ModuleString with parseAssemblyString() and stores
// the result in M. Err and Ctx are declared outside the lines visible here.
// NOTE(review): the return statement is not visible; presumably it yields *M
// - confirm.
58 Module
&parseModule(const char *ModuleString
) {
60 M
= parseAssemblyString(ModuleString
, Err
, Ctx
);
// Builds a FunctionSpecializer for F.
//
// Steps visible here: create getter lambdas that pull TLI, TTI, the assumption
// cache, the dominator tree and BFI for a function out of FAM; construct a
// fresh SCCPSolver over M's data layout; attach predicate info for F; mark F's
// entry block executable and all of its arguments overdefined; run the solver
// over the module; finally construct the FunctionSpecializer from the solver,
// the module, FAM and the getters.
65 FunctionSpecializer
getSpecializerFor(Function
*F
) {
66 auto GetTLI
= [this](Function
&F
) -> const TargetLibraryInfo
& {
67 return FAM
.getResult
<TargetLibraryAnalysis
>(F
);
69 auto GetTTI
= [this](Function
&F
) -> TargetTransformInfo
& {
70 return FAM
.getResult
<TargetIRAnalysis
>(F
);
72 auto GetAC
= [this](Function
&F
) -> AssumptionCache
& {
73 return FAM
.getResult
<AssumptionAnalysis
>(F
);
75 auto GetDT
= [this](Function
&F
) -> DominatorTree
& {
76 return FAM
.getResult
<DominatorTreeAnalysis
>(F
);
78 auto GetBFI
= [this](Function
&F
) -> BlockFrequencyInfo
& {
79 return FAM
.getResult
<BlockFrequencyAnalysis
>(F
);
82 Solver
= std::make_unique
<SCCPSolver
>(M
->getDataLayout(), GetTLI
, Ctx
);
84 DominatorTree
&DT
= GetDT(*F
);
85 AssumptionCache
&AC
= GetAC(*F
);
86 Solver
->addPredicateInfo(*F
, DT
, AC
);
88 Solver
->markBlockExecutable(&F
->front());
89 for (Argument
&Arg
: F
->args())
90 Solver
->markOverdefined(&Arg
);
91 Solver
->solveWhileResolvedUndefsIn(*M
);
95 return FunctionSpecializer(*Solver
, *M
, &FAM
, GetBFI
, GetTLI
, GetTTI
,
// Computes the reference cost of instruction I as a Bonus {CodeSize, Latency}.
//
// The code-size part comes from TTI's TCK_CodeSize cost of I. The latency
// part, when computed, is the frequency of I's parent block divided by the
// entry frequency, multiplied by TTI's TCK_Latency cost of I.
//
// NOTE(review): the branch selected when SizeOnly is true is not visible in
// this view; presumably the latency contribution is zero in that case
// - confirm.
99 Bonus
getInstCost(Instruction
&I
, bool SizeOnly
= false) {
100 auto &TTI
= FAM
.getResult
<TargetIRAnalysis
>(*I
.getFunction());
101 auto &BFI
= FAM
.getResult
<BlockFrequencyAnalysis
>(*I
.getFunction());
104 TTI
.getInstructionCost(&I
, TargetTransformInfo::TCK_CodeSize
);
109 : BFI
.getBlockFreq(I
.getParent()).getFrequency() /
110 BFI
.getEntryFreq().getFrequency() *
111 TTI
.getInstructionCost(&I
, TargetTransformInfo::TCK_Latency
);
113 return {CodeSize
, Latency
};
119 using namespace llvm
;
// Verifies the specialization bonus for a function @foo(i32 %a, i32 %b, i32 %i)
// whose loop block begins with a switch on %i.
//
// Each argument is specialized on the i32 constant 1 and the bonus returned by
// InstCostVisitor::getSpecializationBonus is compared with a reference built
// from getInstCost():
//   - argument 0: the first instruction of case1 (Mul);
//   - argument 1: And + Or + Add;
//   - argument 2: the switch, plus size-only costs of Sdiv, the last
//     instruction of case2 (BrBB2) and the last instruction of BB2 (BrLoop).
// Every case additionally expects both CodeSize and Latency to be positive.
121 TEST_F(FunctionSpecializationTest
, SwitchInst
) {
122 const char *ModuleString
= R
"(
123 define void @foo(i32 %a, i32 %b, i32 %i) {
127 switch i32 %i, label %default
128 [ i32 1, label %case1
129 i32 2, label %case2 ]
149 Module
&M
= parseModule(ModuleString
);
150 Function
*F
= M
.getFunction("foo");
151 FunctionSpecializer Specializer
= getSpecializerFor(F
);
152 InstCostVisitor Visitor
= Specializer
.getInstCostVisitorFor(F
);
154 Constant
*One
= ConstantInt::get(IntegerType::getInt32Ty(M
.getContext()), 1);
156 auto FuncIter
= F
->begin();
157 BasicBlock
&Loop
= *++FuncIter
;
158 BasicBlock
&Case1
= *++FuncIter
;
159 BasicBlock
&Case2
= *++FuncIter
;
160 BasicBlock
&BB1
= *++FuncIter
;
161 BasicBlock
&BB2
= *++FuncIter
;
163 Instruction
&Switch
= Loop
.front();
164 Instruction
&Mul
= Case1
.front();
165 Instruction
&And
= Case2
.front();
166 Instruction
&Sdiv
= *++Case2
.begin();
167 Instruction
&BrBB2
= Case2
.back();
168 Instruction
&Add
= BB1
.front();
169 Instruction
&Or
= BB2
.front();
170 Instruction
&BrLoop
= BB2
.back();
173 Bonus Ref
= getInstCost(Mul
);
174 Bonus Test
= Visitor
.getSpecializationBonus(F
->getArg(0), One
);
175 EXPECT_EQ(Test
, Ref
);
176 EXPECT_TRUE(Test
.CodeSize
> 0 && Test
.Latency
> 0);
179 Ref
= getInstCost(And
) + getInstCost(Or
) + getInstCost(Add
);
180 Test
= Visitor
.getSpecializationBonus(F
->getArg(1), One
);
181 EXPECT_EQ(Test
, Ref
);
182 EXPECT_TRUE(Test
.CodeSize
> 0 && Test
.Latency
> 0);
184 // switch + sdiv + br + br
185 Ref
= getInstCost(Switch
) +
186 getInstCost(Sdiv
, /*SizeOnly =*/ true) +
187 getInstCost(BrBB2
, /*SizeOnly =*/ true) +
188 getInstCost(BrLoop
, /*SizeOnly =*/ true);
189 Test
= Visitor
.getSpecializationBonus(F
->getArg(2), One
);
190 EXPECT_EQ(Test
, Ref
);
191 EXPECT_TRUE(Test
.CodeSize
> 0 && Test
.Latency
> 0);
// Verifies the specialization bonus for a function
// @foo(i32 %a, i32 %b, i1 %cond) containing conditional branches on %cond.
//
// The bonus returned by InstCostVisitor::getSpecializationBonus is compared
// with a reference built from getInstCost():
//   - argument 0 specialized on i32 1: the first instruction of BB0 (Mul);
//   - argument 1 specialized on i32 1: the first instruction of BB1 (Add);
//   - argument 2 specialized on false: the branch in the loop block, the full
//     cost of BB0's terminator (BrBB1BB2), and size-only costs of Sub, Sdiv,
//     BrBB2 and BrLoop.
// Every case additionally expects both CodeSize and Latency to be positive.
194 TEST_F(FunctionSpecializationTest
, BranchInst
) {
195 const char *ModuleString
= R
"(
196 define void @foo(i32 %a, i32 %b, i1 %cond) {
200 br i1 %cond, label %bb0, label %bb3
204 br i1 %cond, label %bb1, label %bb2
216 Module
&M
= parseModule(ModuleString
);
217 Function
*F
= M
.getFunction("foo");
218 FunctionSpecializer Specializer
= getSpecializerFor(F
);
219 InstCostVisitor Visitor
= Specializer
.getInstCostVisitorFor(F
);
221 Constant
*One
= ConstantInt::get(IntegerType::getInt32Ty(M
.getContext()), 1);
222 Constant
*False
= ConstantInt::getFalse(M
.getContext());
224 auto FuncIter
= F
->begin();
225 BasicBlock
&Loop
= *++FuncIter
;
226 BasicBlock
&BB0
= *++FuncIter
;
227 BasicBlock
&BB1
= *++FuncIter
;
228 BasicBlock
&BB2
= *++FuncIter
;
230 Instruction
&Branch
= Loop
.front();
231 Instruction
&Mul
= BB0
.front();
232 Instruction
&Sub
= *++BB0
.begin();
233 Instruction
&BrBB1BB2
= BB0
.back();
234 Instruction
&Add
= BB1
.front();
235 Instruction
&Sdiv
= *++BB1
.begin();
236 Instruction
&BrBB2
= BB1
.back();
237 Instruction
&BrLoop
= BB2
.front();
240 Bonus Ref
= getInstCost(Mul
);
241 Bonus Test
= Visitor
.getSpecializationBonus(F
->getArg(0), One
);
242 EXPECT_EQ(Test
, Ref
);
243 EXPECT_TRUE(Test
.CodeSize
> 0 && Test
.Latency
> 0);
246 Ref
= getInstCost(Add
);
247 Test
= Visitor
.getSpecializationBonus(F
->getArg(1), One
);
248 EXPECT_EQ(Test
, Ref
);
249 EXPECT_TRUE(Test
.CodeSize
> 0 && Test
.Latency
> 0);
251 // branch + sub + br + sdiv + br
252 Ref
= getInstCost(Branch
) +
253 getInstCost(Sub
, /*SizeOnly =*/ true) +
254 getInstCost(BrBB1BB2
) +
255 getInstCost(Sdiv
, /*SizeOnly =*/ true) +
256 getInstCost(BrBB2
, /*SizeOnly =*/ true) +
257 getInstCost(BrLoop
, /*SizeOnly =*/ true);
258 Test
= Visitor
.getSpecializationBonus(F
->getArg(2), False
);
259 EXPECT_EQ(Test
, Ref
);
260 EXPECT_TRUE(Test
.CodeSize
> 0 && Test
.Latency
> 0);
// Verifies the specialization bonus for a single-block function
// @foo(i8 %a, i1 %cond, ptr %b, i32 %c) that mixes icmp, zext, select, gep,
// load, an smax intrinsic call and an ordinary call.
//
// The bonus returned by InstCostVisitor::getSpecializationBonus is compared
// with a reference built from getInstCost():
//   - argument 0 specialized on i8 1: Icmp + Zext;
//   - argument 1 specialized on true: Select;
//   - argument 2 specialized on the global @g: Gep + Load + Freeze plus a
//     further term that is not visible in this view (the existing comment
//     names it as smax);
//   - argument 3 specialized on an i32 undef: expects a zero bonus.
// The first three cases additionally expect CodeSize and Latency to be
// positive.
263 TEST_F(FunctionSpecializationTest
, Misc
) {
264 const char *ModuleString
= R
"(
265 %struct_t = type { [8 x i16], [8 x i16], i32, i32, i32, ptr, [8 x i8] }
266 @g = constant %struct_t zeroinitializer, align 16
268 declare i32 @llvm.smax.i32(i32, i32)
269 declare i32 @bar(i32)
271 define i32 @foo(i8 %a, i1 %cond, ptr %b, i32 %c) {
272 %cmp = icmp eq i8 %a, 10
273 %ext = zext i1 %cmp to i64
274 %sel = select i1 %cond, i64 %ext, i64 1
275 %gep = getelementptr inbounds %struct_t, ptr %b, i64 %sel, i32 4
276 %ld = load i32, ptr %gep
278 %smax = call i32 @llvm.smax.i32(i32 %fr, i32 1)
279 %call = call i32 @bar(i32 %smax)
281 %add = add i32 %call, %fr2
286 Module
&M
= parseModule(ModuleString
);
287 Function
*F
= M
.getFunction("foo");
288 FunctionSpecializer Specializer
= getSpecializerFor(F
);
289 InstCostVisitor Visitor
= Specializer
.getInstCostVisitorFor(F
);
291 GlobalVariable
*GV
= M
.getGlobalVariable("g");
292 Constant
*One
= ConstantInt::get(IntegerType::getInt8Ty(M
.getContext()), 1);
293 Constant
*True
= ConstantInt::getTrue(M
.getContext());
294 Constant
*Undef
= UndefValue::get(IntegerType::getInt32Ty(M
.getContext()));
296 auto BlockIter
= F
->front().begin();
297 Instruction
&Icmp
= *BlockIter
++;
298 Instruction
&Zext
= *BlockIter
++;
299 Instruction
&Select
= *BlockIter
++;
300 Instruction
&Gep
= *BlockIter
++;
301 Instruction
&Load
= *BlockIter
++;
302 Instruction
&Freeze
= *BlockIter
++;
303 Instruction
&Smax
= *BlockIter
++;
306 Bonus Ref
= getInstCost(Icmp
) + getInstCost(Zext
);
307 Bonus Test
= Visitor
.getSpecializationBonus(F
->getArg(0), One
);
308 EXPECT_EQ(Test
, Ref
);
309 EXPECT_TRUE(Test
.CodeSize
> 0 && Test
.Latency
> 0);
312 Ref
= getInstCost(Select
);
313 Test
= Visitor
.getSpecializationBonus(F
->getArg(1), True
);
314 EXPECT_EQ(Test
, Ref
);
315 EXPECT_TRUE(Test
.CodeSize
> 0 && Test
.Latency
> 0);
317 // gep + load + freeze + smax
318 Ref
= getInstCost(Gep
) + getInstCost(Load
) + getInstCost(Freeze
) +
320 Test
= Visitor
.getSpecializationBonus(F
->getArg(2), GV
);
321 EXPECT_EQ(Test
, Ref
);
322 EXPECT_TRUE(Test
.CodeSize
> 0 && Test
.Latency
> 0);
324 Test
= Visitor
.getSpecializationBonus(F
->getArg(3), Undef
);
325 EXPECT_TRUE(Test
.CodeSize
== 0 && Test
.Latency
== 0);
// Verifies how PHI nodes affect the specialization bonus for a function
// @foo(i32 %a, i32 %b, i32 %i) with PHIs in the loop, case2 and bb blocks.
//
// Checks performed, with every argument specialized on the i32 constant 1:
//   - arguments 0 and 1: getSpecializationBonus is expected to be zero in both
//     CodeSize and Latency;
//   - argument 2: the bonus must equal the switch cost plus size-only costs of
//     the PHI in case2 and of case2's terminator (BrBB), and be positive;
//   - afterwards, getBonusFromPendingPHIs() must equal
//     PhiBB + PhiLoop + Add + Icmp + Branch, and be positive.
328 TEST_F(FunctionSpecializationTest
, PhiNode
) {
329 const char *ModuleString
= R
"(
330 define void @foo(i32 %a, i32 %b, i32 %i) {
334 %0 = phi i32 [ %a, %entry ], [ %3, %bb ]
335 switch i32 %i, label %default
336 [ i32 1, label %case1
337 i32 2, label %case2 ]
342 %2 = phi i32 [ %a, %entry ], [ %0, %loop ]
345 %3 = phi i32 [ %b, %case1 ], [ %2, %case2 ], [ %3, %bb ]
346 %4 = icmp eq i32 %3, 1
347 br i1 %4, label %bb, label %loop
353 Module
&M
= parseModule(ModuleString
);
354 Function
*F
= M
.getFunction("foo");
355 FunctionSpecializer Specializer
= getSpecializerFor(F
);
356 InstCostVisitor Visitor
= Specializer
.getInstCostVisitorFor(F
);
358 Constant
*One
= ConstantInt::get(IntegerType::getInt32Ty(M
.getContext()), 1);
360 auto FuncIter
= F
->begin();
361 BasicBlock
&Loop
= *++FuncIter
;
362 BasicBlock
&Case1
= *++FuncIter
;
363 BasicBlock
&Case2
= *++FuncIter
;
364 BasicBlock
&BB
= *++FuncIter
;
366 Instruction
&PhiLoop
= Loop
.front();
367 Instruction
&Switch
= Loop
.back();
368 Instruction
&Add
= Case1
.front();
369 Instruction
&PhiCase2
= Case2
.front();
370 Instruction
&BrBB
= Case2
.back();
371 Instruction
&PhiBB
= BB
.front();
372 Instruction
&Icmp
= *++BB
.begin();
373 Instruction
&Branch
= BB
.back();
375 Bonus Test
= Visitor
.getSpecializationBonus(F
->getArg(0), One
);
376 EXPECT_TRUE(Test
.CodeSize
== 0 && Test
.Latency
== 0);
378 Test
= Visitor
.getSpecializationBonus(F
->getArg(1), One
);
379 EXPECT_TRUE(Test
.CodeSize
== 0 && Test
.Latency
== 0);
382 Bonus Ref
= getInstCost(Switch
) +
383 getInstCost(PhiCase2
, /*SizeOnly =*/ true) +
384 getInstCost(BrBB
, /*SizeOnly =*/ true);
385 Test
= Visitor
.getSpecializationBonus(F
->getArg(2), One
);
386 EXPECT_EQ(Test
, Ref
);
387 EXPECT_TRUE(Test
.CodeSize
> 0 && Test
.Latency
> 0);
389 // phi + phi + add + icmp + branch
390 Ref
= getInstCost(PhiBB
) + getInstCost(PhiLoop
) + getInstCost(Add
) +
391 getInstCost(Icmp
) + getInstCost(Branch
);
392 Test
= Visitor
.getBonusFromPendingPHIs();
393 EXPECT_EQ(Test
, Ref
);
394 EXPECT_TRUE(Test
.CodeSize
> 0 && Test
.Latency
> 0);