1 //===- FunctionSpecializationTest.cpp - Cost model unit tests -------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "llvm/Analysis/AssumptionCache.h"
10 #include "llvm/Analysis/BlockFrequencyInfo.h"
11 #include "llvm/Analysis/BranchProbabilityInfo.h"
12 #include "llvm/Analysis/LoopInfo.h"
13 #include "llvm/Analysis/PostDominators.h"
14 #include "llvm/Analysis/TargetLibraryInfo.h"
15 #include "llvm/Analysis/TargetTransformInfo.h"
16 #include "llvm/AsmParser/Parser.h"
17 #include "llvm/IR/Constants.h"
18 #include "llvm/IR/PassInstrumentation.h"
19 #include "llvm/Support/SourceMgr.h"
20 #include "llvm/Transforms/IPO/FunctionSpecialization.h"
21 #include "llvm/Transforms/Utils/SCCPSolver.h"
22 #include "gtest/gtest.h"
27 static void removeSSACopy(Function
&F
) {
28 for (BasicBlock
&BB
: F
) {
29 for (Instruction
&Inst
: llvm::make_early_inc_range(BB
)) {
30 if (auto *II
= dyn_cast
<IntrinsicInst
>(&Inst
)) {
31 if (II
->getIntrinsicID() != Intrinsic::ssa_copy
)
33 Inst
.replaceAllUsesWith(II
->getOperand(0));
34 Inst
.eraseFromParent();
40 class FunctionSpecializationTest
: public testing::Test
{
43 FunctionAnalysisManager FAM
;
44 std::unique_ptr
<Module
> M
;
45 std::unique_ptr
<SCCPSolver
> Solver
;
47 FunctionSpecializationTest() {
48 FAM
.registerPass([&] { return TargetLibraryAnalysis(); });
49 FAM
.registerPass([&] { return TargetIRAnalysis(); });
50 FAM
.registerPass([&] { return BlockFrequencyAnalysis(); });
51 FAM
.registerPass([&] { return BranchProbabilityAnalysis(); });
52 FAM
.registerPass([&] { return LoopAnalysis(); });
53 FAM
.registerPass([&] { return AssumptionAnalysis(); });
54 FAM
.registerPass([&] { return DominatorTreeAnalysis(); });
55 FAM
.registerPass([&] { return PostDominatorTreeAnalysis(); });
56 FAM
.registerPass([&] { return PassInstrumentationAnalysis(); });
59 Module
&parseModule(const char *ModuleString
) {
61 M
= parseAssemblyString(ModuleString
, Err
, Ctx
);
66 FunctionSpecializer
getSpecializerFor(Function
*F
) {
67 auto GetTLI
= [this](Function
&F
) -> const TargetLibraryInfo
& {
68 return FAM
.getResult
<TargetLibraryAnalysis
>(F
);
70 auto GetTTI
= [this](Function
&F
) -> TargetTransformInfo
& {
71 return FAM
.getResult
<TargetIRAnalysis
>(F
);
73 auto GetAC
= [this](Function
&F
) -> AssumptionCache
& {
74 return FAM
.getResult
<AssumptionAnalysis
>(F
);
76 auto GetDT
= [this](Function
&F
) -> DominatorTree
& {
77 return FAM
.getResult
<DominatorTreeAnalysis
>(F
);
79 auto GetBFI
= [this](Function
&F
) -> BlockFrequencyInfo
& {
80 return FAM
.getResult
<BlockFrequencyAnalysis
>(F
);
83 Solver
= std::make_unique
<SCCPSolver
>(M
->getDataLayout(), GetTLI
, Ctx
);
85 DominatorTree
&DT
= GetDT(*F
);
86 AssumptionCache
&AC
= GetAC(*F
);
87 Solver
->addPredicateInfo(*F
, DT
, AC
);
89 Solver
->markBlockExecutable(&F
->front());
90 for (Argument
&Arg
: F
->args())
91 Solver
->markOverdefined(&Arg
);
92 Solver
->solveWhileResolvedUndefsIn(*M
);
96 return FunctionSpecializer(*Solver
, *M
, &FAM
, GetBFI
, GetTLI
, GetTTI
,
100 Bonus
getInstCost(Instruction
&I
, bool SizeOnly
= false) {
101 auto &TTI
= FAM
.getResult
<TargetIRAnalysis
>(*I
.getFunction());
102 auto &BFI
= FAM
.getResult
<BlockFrequencyAnalysis
>(*I
.getFunction());
105 TTI
.getInstructionCost(&I
, TargetTransformInfo::TCK_CodeSize
);
110 : BFI
.getBlockFreq(I
.getParent()).getFrequency() /
111 BFI
.getEntryFreq().getFrequency() *
112 TTI
.getInstructionCost(&I
, TargetTransformInfo::TCK_Latency
);
114 return {CodeSize
, Latency
};
// Make LLVM names (Module, Function, Constant, ...) usable unqualified in the
// TEST_F bodies that follow.
120 using namespace llvm
;
// SwitchInst: checks the bonus that InstCostVisitor reports when each argument
// of @foo is specialized to the i32 constant 1.
//
// Setup: the IR is parsed with parseModule, a FunctionSpecializer is built for
// @foo with getSpecializerFor, and its InstCostVisitor is queried. Blocks are
// picked with `*++FuncIter` starting from F->begin(), so the function's first
// block is skipped and Loop, Case1, Case2, BB1 and BB2 are the blocks that
// follow it, in that order.
//
// Expectations (every bonus must equal its reference and have non-zero
// CodeSize and Latency):
//   arg 0 -> getInstCost(Mul)
//   arg 1 -> getInstCost(And) + getInstCost(Or) + getInstCost(Add)
//   arg 2 -> getInstCost(Switch), plus Sdiv, BrBB2 and BrLoop costed with
//            SizeOnly = true
//
// NOTE(review): the IR literal appears truncated here (only the signature and
// the switch are present) - restore the full function body before relying on
// this test.
122 TEST_F(FunctionSpecializationTest
, SwitchInst
) {
123 const char *ModuleString
= R
"(
124 define void @foo(i32 %a, i32 %b, i32 %i) {
128 switch i32 %i, label %default
129 [ i32 1, label %case1
130 i32 2, label %case2 ]
150 Module
&M
= parseModule(ModuleString
);
151 Function
*F
= M
.getFunction("foo");
152 FunctionSpecializer Specializer
= getSpecializerFor(F
);
153 InstCostVisitor Visitor
= Specializer
.getInstCostVisitorFor(F
);
155 Constant
*One
= ConstantInt::get(IntegerType::getInt32Ty(M
.getContext()), 1);
157 auto FuncIter
= F
->begin();
158 BasicBlock
&Loop
= *++FuncIter
;
159 BasicBlock
&Case1
= *++FuncIter
;
160 BasicBlock
&Case2
= *++FuncIter
;
161 BasicBlock
&BB1
= *++FuncIter
;
162 BasicBlock
&BB2
= *++FuncIter
;
164 Instruction
&Switch
= Loop
.front();
165 Instruction
&Mul
= Case1
.front();
166 Instruction
&And
= Case2
.front();
167 Instruction
&Sdiv
= *++Case2
.begin();
168 Instruction
&BrBB2
= Case2
.back();
169 Instruction
&Add
= BB1
.front();
170 Instruction
&Or
= BB2
.front();
171 Instruction
&BrLoop
= BB2
.back();
174 Bonus Ref
= getInstCost(Mul
);
175 Bonus Test
= Visitor
.getSpecializationBonus(F
->getArg(0), One
);
176 EXPECT_EQ(Test
, Ref
);
177 EXPECT_TRUE(Test
.CodeSize
> 0 && Test
.Latency
> 0);
180 Ref
= getInstCost(And
) + getInstCost(Or
) + getInstCost(Add
);
181 Test
= Visitor
.getSpecializationBonus(F
->getArg(1), One
);
182 EXPECT_EQ(Test
, Ref
);
183 EXPECT_TRUE(Test
.CodeSize
> 0 && Test
.Latency
> 0);
185 // switch + sdiv + br + br
186 Ref
= getInstCost(Switch
) +
187 getInstCost(Sdiv
, /*SizeOnly =*/ true) +
188 getInstCost(BrBB2
, /*SizeOnly =*/ true) +
189 getInstCost(BrLoop
, /*SizeOnly =*/ true);
190 Test
= Visitor
.getSpecializationBonus(F
->getArg(2), One
);
191 EXPECT_EQ(Test
, Ref
);
192 EXPECT_TRUE(Test
.CodeSize
> 0 && Test
.Latency
> 0);
// BranchInst: checks the bonus that InstCostVisitor reports for @foo, where
// the i1 argument %cond feeds two conditional branches.
//
// Setup: the IR is parsed with parseModule, a FunctionSpecializer is built for
// @foo with getSpecializerFor, and its InstCostVisitor is queried. Blocks are
// picked with `*++FuncIter` starting from F->begin(), so the function's first
// block is skipped and Loop, BB0, BB1 and BB2 are the blocks that follow it,
// in that order. Sub and Sdiv are the second instructions of BB0 and BB1.
//
// Expectations (every bonus must equal its reference and have non-zero
// CodeSize and Latency):
//   arg 0 = i32 1 -> getInstCost(Mul)
//   arg 1 = i32 1 -> getInstCost(Add)
//   arg 2 = false -> getInstCost(Branch) + getInstCost(BrBB1BB2), plus Sub,
//                    Sdiv, BrBB2 and BrLoop costed with SizeOnly = true
// The inline comment ahead of the last reference lists five instructions but
// the expression sums six; the trailing BrLoop term is the one not listed.
//
// NOTE(review): the IR literal appears truncated here (only the signature and
// the two conditional branches are present) - restore the full function body
// before relying on this test.
195 TEST_F(FunctionSpecializationTest
, BranchInst
) {
196 const char *ModuleString
= R
"(
197 define void @foo(i32 %a, i32 %b, i1 %cond) {
201 br i1 %cond, label %bb0, label %bb3
205 br i1 %cond, label %bb1, label %bb2
217 Module
&M
= parseModule(ModuleString
);
218 Function
*F
= M
.getFunction("foo");
219 FunctionSpecializer Specializer
= getSpecializerFor(F
);
220 InstCostVisitor Visitor
= Specializer
.getInstCostVisitorFor(F
);
222 Constant
*One
= ConstantInt::get(IntegerType::getInt32Ty(M
.getContext()), 1);
223 Constant
*False
= ConstantInt::getFalse(M
.getContext());
225 auto FuncIter
= F
->begin();
226 BasicBlock
&Loop
= *++FuncIter
;
227 BasicBlock
&BB0
= *++FuncIter
;
228 BasicBlock
&BB1
= *++FuncIter
;
229 BasicBlock
&BB2
= *++FuncIter
;
231 Instruction
&Branch
= Loop
.front();
232 Instruction
&Mul
= BB0
.front();
233 Instruction
&Sub
= *++BB0
.begin();
234 Instruction
&BrBB1BB2
= BB0
.back();
235 Instruction
&Add
= BB1
.front();
236 Instruction
&Sdiv
= *++BB1
.begin();
237 Instruction
&BrBB2
= BB1
.back();
238 Instruction
&BrLoop
= BB2
.front();
241 Bonus Ref
= getInstCost(Mul
);
242 Bonus Test
= Visitor
.getSpecializationBonus(F
->getArg(0), One
);
243 EXPECT_EQ(Test
, Ref
);
244 EXPECT_TRUE(Test
.CodeSize
> 0 && Test
.Latency
> 0);
247 Ref
= getInstCost(Add
);
248 Test
= Visitor
.getSpecializationBonus(F
->getArg(1), One
);
249 EXPECT_EQ(Test
, Ref
);
250 EXPECT_TRUE(Test
.CodeSize
> 0 && Test
.Latency
> 0);
252 // branch + sub + br + sdiv + br
253 Ref
= getInstCost(Branch
) +
254 getInstCost(Sub
, /*SizeOnly =*/ true) +
255 getInstCost(BrBB1BB2
) +
256 getInstCost(Sdiv
, /*SizeOnly =*/ true) +
257 getInstCost(BrBB2
, /*SizeOnly =*/ true) +
258 getInstCost(BrLoop
, /*SizeOnly =*/ true);
259 Test
= Visitor
.getSpecializationBonus(F
->getArg(2), False
);
260 EXPECT_EQ(Test
, Ref
);
261 EXPECT_TRUE(Test
.CodeSize
> 0 && Test
.Latency
> 0);
// Misc: checks the bonus that InstCostVisitor reports for a straight-line
// chain in @foo (icmp, zext, select, getelementptr, load, then an smax
// intrinsic call and a call to @bar).
//
// Setup: the IR is parsed with parseModule, a FunctionSpecializer is built for
// @foo with getSpecializerFor, and its InstCostVisitor is queried. The
// instructions are taken from the function's first block with a
// post-incremented iterator, so Icmp, Zext, Select, Gep, Load, Freeze and Smax
// are its first seven instructions, in that order.
//
// Expectations:
//   arg 0 = i8 1      -> getInstCost(Icmp) + getInstCost(Zext), non-zero
//   arg 1 = true      -> getInstCost(Select), non-zero
//   arg 2 = @g        -> gep + load + freeze + smax per the inline comment,
//                        non-zero
//   arg 3 = undef i32 -> CodeSize and Latency are both zero
//
// NOTE(review): the IR literal and the reference expression for arg 2 (which
// ends in a dangling `+`) appear truncated here - restore the missing lines
// before relying on this test.
264 TEST_F(FunctionSpecializationTest
, Misc
) {
265 const char *ModuleString
= R
"(
266 %struct_t = type { [8 x i16], [8 x i16], i32, i32, i32, ptr, [8 x i8] }
267 @g = constant %struct_t zeroinitializer, align 16
269 declare i32 @llvm.smax.i32(i32, i32)
270 declare i32 @bar(i32)
272 define i32 @foo(i8 %a, i1 %cond, ptr %b, i32 %c) {
273 %cmp = icmp eq i8 %a, 10
274 %ext = zext i1 %cmp to i64
275 %sel = select i1 %cond, i64 %ext, i64 1
276 %gep = getelementptr inbounds %struct_t, ptr %b, i64 %sel, i32 4
277 %ld = load i32, ptr %gep
279 %smax = call i32 @llvm.smax.i32(i32 %fr, i32 1)
280 %call = call i32 @bar(i32 %smax)
282 %add = add i32 %call, %fr2
287 Module
&M
= parseModule(ModuleString
);
288 Function
*F
= M
.getFunction("foo");
289 FunctionSpecializer Specializer
= getSpecializerFor(F
);
290 InstCostVisitor Visitor
= Specializer
.getInstCostVisitorFor(F
);
292 GlobalVariable
*GV
= M
.getGlobalVariable("g");
293 Constant
*One
= ConstantInt::get(IntegerType::getInt8Ty(M
.getContext()), 1);
294 Constant
*True
= ConstantInt::getTrue(M
.getContext());
295 Constant
*Undef
= UndefValue::get(IntegerType::getInt32Ty(M
.getContext()));
297 auto BlockIter
= F
->front().begin();
298 Instruction
&Icmp
= *BlockIter
++;
299 Instruction
&Zext
= *BlockIter
++;
300 Instruction
&Select
= *BlockIter
++;
301 Instruction
&Gep
= *BlockIter
++;
302 Instruction
&Load
= *BlockIter
++;
303 Instruction
&Freeze
= *BlockIter
++;
304 Instruction
&Smax
= *BlockIter
++;
307 Bonus Ref
= getInstCost(Icmp
) + getInstCost(Zext
);
308 Bonus Test
= Visitor
.getSpecializationBonus(F
->getArg(0), One
);
309 EXPECT_EQ(Test
, Ref
);
310 EXPECT_TRUE(Test
.CodeSize
> 0 && Test
.Latency
> 0);
313 Ref
= getInstCost(Select
);
314 Test
= Visitor
.getSpecializationBonus(F
->getArg(1), True
);
315 EXPECT_EQ(Test
, Ref
);
316 EXPECT_TRUE(Test
.CodeSize
> 0 && Test
.Latency
> 0);
318 // gep + load + freeze + smax
319 Ref
= getInstCost(Gep
) + getInstCost(Load
) + getInstCost(Freeze
) +
321 Test
= Visitor
.getSpecializationBonus(F
->getArg(2), GV
);
322 EXPECT_EQ(Test
, Ref
);
323 EXPECT_TRUE(Test
.CodeSize
> 0 && Test
.Latency
> 0);
325 Test
= Visitor
.getSpecializationBonus(F
->getArg(3), Undef
);
326 EXPECT_TRUE(Test
.CodeSize
== 0 && Test
.Latency
== 0);
// PhiNode: checks the bonus that InstCostVisitor reports for @foo, whose
// arguments %a and %b flow into phi nodes and whose argument %i drives a
// switch.
//
// Setup: the IR is parsed with parseModule, a FunctionSpecializer is built for
// @foo with getSpecializerFor, and its InstCostVisitor is queried. Blocks are
// picked with `*++FuncIter` starting from F->begin(), so the function's first
// block is skipped and Loop, Case1, Case2 and BB are the blocks that follow
// it, in that order.
//
// Expectations, with every argument specialized to the i32 constant 1:
//   arg 0 -> CodeSize and Latency are both zero
//   arg 1 -> CodeSize and Latency are both zero
//   arg 2 -> getInstCost(Switch), plus PhiCase2 and BrBB costed with
//            SizeOnly = true; non-zero
// Afterwards getBonusFromPendingPHIs() must equal
//   getInstCost(PhiBB) + getInstCost(PhiLoop) + getInstCost(Add) +
//   getInstCost(Icmp) + getInstCost(Branch)
// and be non-zero.
//
// NOTE(review): the IR literal appears truncated here (block labels and
// several instructions are absent) - restore the full function body before
// relying on this test.
329 TEST_F(FunctionSpecializationTest
, PhiNode
) {
330 const char *ModuleString
= R
"(
331 define void @foo(i32 %a, i32 %b, i32 %i) {
335 %0 = phi i32 [ %a, %entry ], [ %3, %bb ]
336 switch i32 %i, label %default
337 [ i32 1, label %case1
338 i32 2, label %case2 ]
343 %2 = phi i32 [ %a, %entry ], [ %0, %loop ]
346 %3 = phi i32 [ %b, %case1 ], [ %2, %case2 ], [ %3, %bb ]
347 %4 = icmp eq i32 %3, 1
348 br i1 %4, label %bb, label %loop
354 Module
&M
= parseModule(ModuleString
);
355 Function
*F
= M
.getFunction("foo");
356 FunctionSpecializer Specializer
= getSpecializerFor(F
);
357 InstCostVisitor Visitor
= Specializer
.getInstCostVisitorFor(F
);
359 Constant
*One
= ConstantInt::get(IntegerType::getInt32Ty(M
.getContext()), 1);
361 auto FuncIter
= F
->begin();
362 BasicBlock
&Loop
= *++FuncIter
;
363 BasicBlock
&Case1
= *++FuncIter
;
364 BasicBlock
&Case2
= *++FuncIter
;
365 BasicBlock
&BB
= *++FuncIter
;
367 Instruction
&PhiLoop
= Loop
.front();
368 Instruction
&Switch
= Loop
.back();
369 Instruction
&Add
= Case1
.front();
370 Instruction
&PhiCase2
= Case2
.front();
371 Instruction
&BrBB
= Case2
.back();
372 Instruction
&PhiBB
= BB
.front();
373 Instruction
&Icmp
= *++BB
.begin();
374 Instruction
&Branch
= BB
.back();
376 Bonus Test
= Visitor
.getSpecializationBonus(F
->getArg(0), One
);
377 EXPECT_TRUE(Test
.CodeSize
== 0 && Test
.Latency
== 0);
379 Test
= Visitor
.getSpecializationBonus(F
->getArg(1), One
);
380 EXPECT_TRUE(Test
.CodeSize
== 0 && Test
.Latency
== 0);
383 Bonus Ref
= getInstCost(Switch
) +
384 getInstCost(PhiCase2
, /*SizeOnly =*/ true) +
385 getInstCost(BrBB
, /*SizeOnly =*/ true);
386 Test
= Visitor
.getSpecializationBonus(F
->getArg(2), One
);
387 EXPECT_EQ(Test
, Ref
);
388 EXPECT_TRUE(Test
.CodeSize
> 0 && Test
.Latency
> 0);
390 // phi + phi + add + icmp + branch
391 Ref
= getInstCost(PhiBB
) + getInstCost(PhiLoop
) + getInstCost(Add
) +
392 getInstCost(Icmp
) + getInstCost(Branch
);
393 Test
= Visitor
.getBonusFromPendingPHIs();
394 EXPECT_EQ(Test
, Ref
);
395 EXPECT_TRUE(Test
.CodeSize
> 0 && Test
.Latency
> 0);